make dag.Set a simple map

This allows iteration directly over the set, removing the need to
allocate and copy a new slice for every iteration.
This commit is contained in:
James Bardin 2020-01-07 15:38:41 -05:00
parent 991ee37cee
commit 32ae3b5452
6 changed files with 68 additions and 72 deletions

View File

@ -29,8 +29,8 @@ func (g *AcyclicGraph) DirectedGraph() Grapher {
// Returns a Set that includes every Vertex yielded by walking down from the // Returns a Set that includes every Vertex yielded by walking down from the
// provided starting Vertex v. // provided starting Vertex v.
func (g *AcyclicGraph) Ancestors(v Vertex) (*Set, error) { func (g *AcyclicGraph) Ancestors(v Vertex) (Set, error) {
s := new(Set) s := make(Set)
start := AsVertexList(g.DownEdges(v)) start := AsVertexList(g.DownEdges(v))
memoFunc := func(v Vertex, d int) error { memoFunc := func(v Vertex, d int) error {
s.Add(v) s.Add(v)
@ -46,8 +46,8 @@ func (g *AcyclicGraph) Ancestors(v Vertex) (*Set, error) {
// Returns a Set that includes every Vertex yielded by walking up from the // Returns a Set that includes every Vertex yielded by walking up from the
// provided starting Vertex v. // provided starting Vertex v.
func (g *AcyclicGraph) Descendents(v Vertex) (*Set, error) { func (g *AcyclicGraph) Descendents(v Vertex) (Set, error) {
s := new(Set) s := make(Set)
start := AsVertexList(g.UpEdges(v)) start := AsVertexList(g.UpEdges(v))
memoFunc := func(v Vertex, d int) error { memoFunc := func(v Vertex, d int) error {
s.Add(v) s.Add(v)
@ -174,7 +174,7 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) tfdiags.Diagnostics {
} }
// simple convenience helper for converting a dag.Set to a []Vertex // simple convenience helper for converting a dag.Set to a []Vertex
func AsVertexList(s *Set) []Vertex { func AsVertexList(s Set) []Vertex {
rawList := s.List() rawList := s.List()
vertexList := make([]Vertex, len(rawList)) vertexList := make([]Vertex, len(rawList))
for i, raw := range rawList { for i, raw := range rawList {

View File

@ -10,10 +10,10 @@ import (
// Graph is used to represent a dependency graph. // Graph is used to represent a dependency graph.
type Graph struct { type Graph struct {
vertices *Set vertices Set
edges *Set edges Set
downEdges map[interface{}]*Set downEdges map[interface{}]Set
upEdges map[interface{}]*Set upEdges map[interface{}]Set
// JSON encoder for recording debug information // JSON encoder for recording debug information
debug *encoder debug *encoder
@ -179,13 +179,13 @@ func (g *Graph) RemoveEdge(edge Edge) {
} }
// DownEdges returns the outward edges from the source Vertex v. // DownEdges returns the outward edges from the source Vertex v.
func (g *Graph) DownEdges(v Vertex) *Set { func (g *Graph) DownEdges(v Vertex) Set {
g.init() g.init()
return g.downEdges[hashcode(v)] return g.downEdges[hashcode(v)]
} }
// UpEdges returns the inward edges to the destination Vertex v. // UpEdges returns the inward edges to the destination Vertex v.
func (g *Graph) UpEdges(v Vertex) *Set { func (g *Graph) UpEdges(v Vertex) Set {
g.init() g.init()
return g.upEdges[hashcode(v)] return g.upEdges[hashcode(v)]
} }
@ -214,7 +214,7 @@ func (g *Graph) Connect(edge Edge) {
// Add the down edge // Add the down edge
s, ok := g.downEdges[sourceCode] s, ok := g.downEdges[sourceCode]
if !ok { if !ok {
s = new(Set) s = make(Set)
g.downEdges[sourceCode] = s g.downEdges[sourceCode] = s
} }
s.Add(target) s.Add(target)
@ -222,7 +222,7 @@ func (g *Graph) Connect(edge Edge) {
// Add the up edge // Add the up edge
s, ok = g.upEdges[targetCode] s, ok = g.upEdges[targetCode]
if !ok { if !ok {
s = new(Set) s = make(Set)
g.upEdges[targetCode] = s g.upEdges[targetCode] = s
} }
s.Add(source) s.Add(source)
@ -311,16 +311,16 @@ func (g *Graph) String() string {
func (g *Graph) init() { func (g *Graph) init() {
if g.vertices == nil { if g.vertices == nil {
g.vertices = new(Set) g.vertices = make(Set)
} }
if g.edges == nil { if g.edges == nil {
g.edges = new(Set) g.edges = make(Set)
} }
if g.downEdges == nil { if g.downEdges == nil {
g.downEdges = make(map[interface{}]*Set) g.downEdges = make(map[interface{}]Set)
} }
if g.upEdges == nil { if g.upEdges == nil {
g.upEdges = make(map[interface{}]*Set) g.upEdges = make(map[interface{}]Set)
} }
} }

View File

@ -134,15 +134,15 @@ func TestGraphEdgesFrom(t *testing.T) {
edges := g.EdgesFrom(1) edges := g.EdgesFrom(1)
var expected Set expected := make(Set)
expected.Add(BasicEdge(1, 3)) expected.Add(BasicEdge(1, 3))
var s Set s := make(Set)
for _, e := range edges { for _, e := range edges {
s.Add(e) s.Add(e)
} }
if s.Intersection(&expected).Len() != expected.Len() { if s.Intersection(expected).Len() != expected.Len() {
t.Fatalf("bad: %#v", edges) t.Fatalf("bad: %#v", edges)
} }
} }
@ -157,15 +157,15 @@ func TestGraphEdgesTo(t *testing.T) {
edges := g.EdgesTo(3) edges := g.EdgesTo(3)
var expected Set expected := make(Set)
expected.Add(BasicEdge(1, 3)) expected.Add(BasicEdge(1, 3))
var s Set s := make(Set)
for _, e := range edges { for _, e := range edges {
s.Add(e) s.Add(e)
} }
if s.Intersection(&expected).Len() != expected.Len() { if s.Intersection(expected).Len() != expected.Len() {
t.Fatalf("bad: %#v", edges) t.Fatalf("bad: %#v", edges)
} }
} }

View File

@ -1,14 +1,7 @@
package dag package dag
import (
"sync"
)
// Set is a set data structure. // Set is a set data structure.
type Set struct { type Set map[interface{}]interface{}
m map[interface{}]interface{}
once sync.Once
}
// Hashable is the interface used by set to get the hash code of a value. // Hashable is the interface used by set to get the hash code of a value.
// If this isn't given, then the value of the item being added to the set // If this isn't given, then the value of the item being added to the set
@ -27,32 +20,29 @@ func hashcode(v interface{}) interface{} {
} }
// Add adds an item to the set // Add adds an item to the set
func (s *Set) Add(v interface{}) { func (s Set) Add(v interface{}) {
s.once.Do(s.init) s[hashcode(v)] = v
s.m[hashcode(v)] = v
} }
// Delete removes an item from the set. // Delete removes an item from the set.
func (s *Set) Delete(v interface{}) { func (s Set) Delete(v interface{}) {
s.once.Do(s.init) delete(s, hashcode(v))
delete(s.m, hashcode(v))
} }
// Include returns true/false of whether a value is in the set. // Include returns true/false of whether a value is in the set.
func (s *Set) Include(v interface{}) bool { func (s Set) Include(v interface{}) bool {
s.once.Do(s.init) _, ok := s[hashcode(v)]
_, ok := s.m[hashcode(v)]
return ok return ok
} }
// Intersection computes the set intersection with other. // Intersection computes the set intersection with other.
func (s *Set) Intersection(other *Set) *Set { func (s Set) Intersection(other Set) Set {
result := new(Set) result := make(Set)
if s == nil { if s == nil {
return result return result
} }
if other != nil { if other != nil {
for _, v := range s.m { for _, v := range s {
if other.Include(v) { if other.Include(v) {
result.Add(v) result.Add(v)
} }
@ -64,13 +54,13 @@ func (s *Set) Intersection(other *Set) *Set {
// Difference returns a set with the elements that s has but // Difference returns a set with the elements that s has but
// other doesn't. // other doesn't.
func (s *Set) Difference(other *Set) *Set { func (s Set) Difference(other Set) Set {
result := new(Set) result := make(Set)
if s != nil { if s != nil {
for k, v := range s.m { for k, v := range s {
var ok bool var ok bool
if other != nil { if other != nil {
_, ok = other.m[k] _, ok = other[k]
} }
if !ok { if !ok {
result.Add(v) result.Add(v)
@ -83,10 +73,10 @@ func (s *Set) Difference(other *Set) *Set {
// Filter returns a set that contains the elements from the receiver // Filter returns a set that contains the elements from the receiver
// where the given callback returns true. // where the given callback returns true.
func (s *Set) Filter(cb func(interface{}) bool) *Set { func (s Set) Filter(cb func(interface{}) bool) Set {
result := new(Set) result := make(Set)
for _, v := range s.m { for _, v := range s {
if cb(v) { if cb(v) {
result.Add(v) result.Add(v)
} }
@ -96,28 +86,20 @@ func (s *Set) Filter(cb func(interface{}) bool) *Set {
} }
// Len is the number of items in the set. // Len is the number of items in the set.
func (s *Set) Len() int { func (s Set) Len() int {
if s == nil { return len(s)
return 0
}
return len(s.m)
} }
// List returns the list of set elements. // List returns the list of set elements.
func (s *Set) List() []interface{} { func (s Set) List() []interface{} {
if s == nil { if s == nil {
return nil return nil
} }
r := make([]interface{}, 0, len(s.m)) r := make([]interface{}, 0, len(s))
for _, v := range s.m { for _, v := range s {
r = append(r, v) r = append(r, v)
} }
return r return r
} }
func (s *Set) init() {
s.m = make(map[interface{}]interface{})
}

View File

@ -35,7 +35,9 @@ func TestSetDifference(t *testing.T) {
for i, tc := range cases { for i, tc := range cases {
t.Run(fmt.Sprintf("%d-%s", i, tc.Name), func(t *testing.T) { t.Run(fmt.Sprintf("%d-%s", i, tc.Name), func(t *testing.T) {
var one, two, expected Set one := make(Set)
two := make(Set)
expected := make(Set)
for _, v := range tc.A { for _, v := range tc.A {
one.Add(v) one.Add(v)
} }
@ -46,8 +48,8 @@ func TestSetDifference(t *testing.T) {
expected.Add(v) expected.Add(v)
} }
actual := one.Difference(&two) actual := one.Difference(two)
match := actual.Intersection(&expected) match := actual.Intersection(expected)
if match.Len() != expected.Len() { if match.Len() != expected.Len() {
t.Fatalf("bad: %#v", actual.List()) t.Fatalf("bad: %#v", actual.List())
} }
@ -78,7 +80,8 @@ func TestSetFilter(t *testing.T) {
for i, tc := range cases { for i, tc := range cases {
t.Run(fmt.Sprintf("%d-%#v", i, tc.Input), func(t *testing.T) { t.Run(fmt.Sprintf("%d-%#v", i, tc.Input), func(t *testing.T) {
var input, expected Set input := make(Set)
expected := make(Set)
for _, v := range tc.Input { for _, v := range tc.Input {
input.Add(v) input.Add(v)
} }
@ -89,7 +92,7 @@ func TestSetFilter(t *testing.T) {
actual := input.Filter(func(v interface{}) bool { actual := input.Filter(func(v interface{}) bool {
return v.(int) < 5 return v.(int) < 5
}) })
match := actual.Intersection(&expected) match := actual.Intersection(expected)
if match.Len() != expected.Len() { if match.Len() != expected.Len() {
t.Fatalf("bad: %#v", actual.List()) t.Fatalf("bad: %#v", actual.List())
} }

View File

@ -64,6 +64,15 @@ type Walker struct {
diagsLock sync.Mutex diagsLock sync.Mutex
} }
func (w *Walker) init() {
if w.vertices == nil {
w.vertices = make(Set)
}
if w.edges == nil {
w.edges = make(Set)
}
}
type walkerVertex struct { type walkerVertex struct {
// These should only be set once on initialization and never written again. // These should only be set once on initialization and never written again.
// They are not protected by a lock since they don't need to be since // They are not protected by a lock since they don't need to be since
@ -140,7 +149,9 @@ func (w *Walker) Wait() tfdiags.Diagnostics {
// time during a walk. // time during a walk.
func (w *Walker) Update(g *AcyclicGraph) { func (w *Walker) Update(g *AcyclicGraph) {
log.Print("[TRACE] dag/walk: updating graph") log.Print("[TRACE] dag/walk: updating graph")
var v, e *Set w.init()
v := make(Set)
e := make(Set)
if g != nil { if g != nil {
v, e = g.vertices, g.edges v, e = g.vertices, g.edges
} }
@ -157,9 +168,9 @@ func (w *Walker) Update(g *AcyclicGraph) {
} }
// Calculate all our sets // Calculate all our sets
newEdges := e.Difference(&w.edges) newEdges := e.Difference(w.edges)
oldEdges := w.edges.Difference(e) oldEdges := w.edges.Difference(e)
newVerts := v.Difference(&w.vertices) newVerts := v.Difference(w.vertices)
oldVerts := w.vertices.Difference(v) oldVerts := w.vertices.Difference(v)
// Add the new vertices // Add the new vertices
@ -207,7 +218,7 @@ func (w *Walker) Update(g *AcyclicGraph) {
} }
// Add the new edges // Add the new edges
var changedDeps Set changedDeps := make(Set)
for _, raw := range newEdges.List() { for _, raw := range newEdges.List() {
edge := raw.(Edge) edge := raw.(Edge)
waiter, dep := w.edgeParts(edge) waiter, dep := w.edgeParts(edge)