diff --git a/dag/dag.go b/dag/dag.go index 77c67eff9..705f041d4 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -29,8 +29,8 @@ func (g *AcyclicGraph) DirectedGraph() Grapher { // Returns a Set that includes every Vertex yielded by walking down from the // provided starting Vertex v. -func (g *AcyclicGraph) Ancestors(v Vertex) (*Set, error) { - s := new(Set) +func (g *AcyclicGraph) Ancestors(v Vertex) (Set, error) { + s := make(Set) start := AsVertexList(g.DownEdges(v)) memoFunc := func(v Vertex, d int) error { s.Add(v) @@ -46,8 +46,8 @@ func (g *AcyclicGraph) Ancestors(v Vertex) (*Set, error) { // Returns a Set that includes every Vertex yielded by walking up from the // provided starting Vertex v. -func (g *AcyclicGraph) Descendents(v Vertex) (*Set, error) { - s := new(Set) +func (g *AcyclicGraph) Descendents(v Vertex) (Set, error) { + s := make(Set) start := AsVertexList(g.UpEdges(v)) memoFunc := func(v Vertex, d int) error { s.Add(v) @@ -174,7 +174,7 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) tfdiags.Diagnostics { } // simple convenience helper for converting a dag.Set to a []Vertex -func AsVertexList(s *Set) []Vertex { +func AsVertexList(s Set) []Vertex { rawList := s.List() vertexList := make([]Vertex, len(rawList)) for i, raw := range rawList { diff --git a/dag/graph.go b/dag/graph.go index e7517a206..1b50524b4 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -10,10 +10,10 @@ import ( // Graph is used to represent a dependency graph. type Graph struct { - vertices *Set - edges *Set - downEdges map[interface{}]*Set - upEdges map[interface{}]*Set + vertices Set + edges Set + downEdges map[interface{}]Set + upEdges map[interface{}]Set // JSON encoder for recording debug information debug *encoder @@ -179,13 +179,13 @@ func (g *Graph) RemoveEdge(edge Edge) { } // DownEdges returns the outward edges from the source Vertex v. -func (g *Graph) DownEdges(v Vertex) *Set { +func (g *Graph) DownEdges(v Vertex) Set { g.init() return g.downEdges[hashcode(v)] } // UpEdges returns the inward edges to the destination Vertex v. -func (g *Graph) UpEdges(v Vertex) *Set { +func (g *Graph) UpEdges(v Vertex) Set { g.init() return g.upEdges[hashcode(v)] } @@ -214,7 +214,7 @@ func (g *Graph) Connect(edge Edge) { // Add the down edge s, ok := g.downEdges[sourceCode] if !ok { - s = new(Set) + s = make(Set) g.downEdges[sourceCode] = s } s.Add(target) @@ -222,7 +222,7 @@ func (g *Graph) Connect(edge Edge) { // Add the up edge s, ok = g.upEdges[targetCode] if !ok { - s = new(Set) + s = make(Set) g.upEdges[targetCode] = s } s.Add(source) @@ -311,16 +311,16 @@ func (g *Graph) String() string { func (g *Graph) init() { if g.vertices == nil { - g.vertices = new(Set) + g.vertices = make(Set) } if g.edges == nil { - g.edges = new(Set) + g.edges = make(Set) } if g.downEdges == nil { - g.downEdges = make(map[interface{}]*Set) + g.downEdges = make(map[interface{}]Set) } if g.upEdges == nil { - g.upEdges = make(map[interface{}]*Set) + g.upEdges = make(map[interface{}]Set) } } diff --git a/dag/graph_test.go b/dag/graph_test.go index 02c4debd5..297974431 100644 --- a/dag/graph_test.go +++ b/dag/graph_test.go @@ -134,15 +134,15 @@ func TestGraphEdgesFrom(t *testing.T) { edges := g.EdgesFrom(1) - var expected Set + expected := make(Set) expected.Add(BasicEdge(1, 3)) - var s Set + s := make(Set) for _, e := range edges { s.Add(e) } - if s.Intersection(&expected).Len() != expected.Len() { + if s.Intersection(expected).Len() != expected.Len() { t.Fatalf("bad: %#v", edges) } } @@ -157,15 +157,15 @@ func TestGraphEdgesTo(t *testing.T) { edges := g.EdgesTo(3) - var expected Set + expected := make(Set) expected.Add(BasicEdge(1, 3)) - var s Set + s := make(Set) for _, e := range edges { s.Add(e) } - if s.Intersection(&expected).Len() != expected.Len() { + if s.Intersection(expected).Len() != expected.Len() { t.Fatalf("bad: %#v", edges) } } diff --git a/dag/set.go b/dag/set.go index 92b42151d..f3fd704ba 100644 --- a/dag/set.go +++ b/dag/set.go @@ -1,14 +1,7 @@ package dag -import ( - "sync" -) - // Set is a set data structure. -type Set struct { - m map[interface{}]interface{} - once sync.Once -} +type Set map[interface{}]interface{} // Hashable is the interface used by set to get the hash code of a value. // If this isn't given, then the value of the item being added to the set @@ -27,32 +20,29 @@ func hashcode(v interface{}) interface{} { } // Add adds an item to the set -func (s *Set) Add(v interface{}) { - s.once.Do(s.init) - s.m[hashcode(v)] = v +func (s Set) Add(v interface{}) { + s[hashcode(v)] = v } // Delete removes an item from the set. -func (s *Set) Delete(v interface{}) { - s.once.Do(s.init) - delete(s.m, hashcode(v)) +func (s Set) Delete(v interface{}) { + delete(s, hashcode(v)) } // Include returns true/false of whether a value is in the set. -func (s *Set) Include(v interface{}) bool { - s.once.Do(s.init) - _, ok := s.m[hashcode(v)] +func (s Set) Include(v interface{}) bool { + _, ok := s[hashcode(v)] return ok } // Intersection computes the set intersection with other. -func (s *Set) Intersection(other *Set) *Set { - result := new(Set) +func (s Set) Intersection(other Set) Set { + result := make(Set) if s == nil { return result } if other != nil { - for _, v := range s.m { + for _, v := range s { if other.Include(v) { result.Add(v) } @@ -64,13 +54,13 @@ func (s *Set) Intersection(other *Set) *Set { // Difference returns a set with the elements that s has but // other doesn't. -func (s *Set) Difference(other *Set) *Set { - result := new(Set) +func (s Set) Difference(other Set) Set { + result := make(Set) if s != nil { - for k, v := range s.m { + for k, v := range s { var ok bool if other != nil { - _, ok = other.m[k] + _, ok = other[k] } if !ok { result.Add(v) @@ -83,10 +73,10 @@ func (s *Set) Difference(other *Set) *Set { // Filter returns a set that contains the elements from the receiver // where the given callback returns true. -func (s *Set) Filter(cb func(interface{}) bool) *Set { - result := new(Set) +func (s Set) Filter(cb func(interface{}) bool) Set { + result := make(Set) - for _, v := range s.m { + for _, v := range s { if cb(v) { result.Add(v) } @@ -96,28 +86,20 @@ func (s *Set) Filter(cb func(interface{}) bool) *Set { } // Len is the number of items in the set. -func (s *Set) Len() int { - if s == nil { - return 0 - } - - return len(s.m) +func (s Set) Len() int { + return len(s) } // List returns the list of set elements. -func (s *Set) List() []interface{} { +func (s Set) List() []interface{} { if s == nil { return nil } - r := make([]interface{}, 0, len(s.m)) - for _, v := range s.m { + r := make([]interface{}, 0, len(s)) + for _, v := range s { r = append(r, v) } return r } - -func (s *Set) init() { - s.m = make(map[interface{}]interface{}) -} diff --git a/dag/set_test.go b/dag/set_test.go index c70da475e..63b72e323 100644 --- a/dag/set_test.go +++ b/dag/set_test.go @@ -35,7 +35,9 @@ func TestSetDifference(t *testing.T) { for i, tc := range cases { t.Run(fmt.Sprintf("%d-%s", i, tc.Name), func(t *testing.T) { - var one, two, expected Set + one := make(Set) + two := make(Set) + expected := make(Set) for _, v := range tc.A { one.Add(v) } @@ -46,8 +48,8 @@ func TestSetDifference(t *testing.T) { expected.Add(v) } - actual := one.Difference(&two) - match := actual.Intersection(&expected) + actual := one.Difference(two) + match := actual.Intersection(expected) if match.Len() != expected.Len() { t.Fatalf("bad: %#v", actual.List()) } @@ -78,7 +80,8 @@ func TestSetFilter(t *testing.T) { for i, tc := range cases { t.Run(fmt.Sprintf("%d-%#v", i, tc.Input), func(t *testing.T) { - var input, expected Set + input := make(Set) + expected := make(Set) for _, v := range tc.Input { input.Add(v) } @@ -89,7 +92,7 @@ func TestSetFilter(t *testing.T) { actual := input.Filter(func(v interface{}) bool { return v.(int) < 5 }) - match := actual.Intersection(&expected) + match := actual.Intersection(expected) if match.Len() != expected.Len() { t.Fatalf("bad: %#v", actual.List()) } diff --git a/dag/walk.go b/dag/walk.go index 509d76a3e..4fd41ed86 100644 --- a/dag/walk.go +++ b/dag/walk.go @@ -64,6 +64,15 @@ type Walker struct { diagsLock sync.Mutex } +func (w *Walker) init() { + if w.vertices == nil { + w.vertices = make(Set) + } + if w.edges == nil { + w.edges = make(Set) + } +} + type walkerVertex struct { // These should only be set once on initialization and never written again. // They are not protected by a lock since they don't need to be since @@ -140,7 +149,9 @@ func (w *Walker) Wait() tfdiags.Diagnostics { // time during a walk. func (w *Walker) Update(g *AcyclicGraph) { log.Print("[TRACE] dag/walk: updating graph") - var v, e *Set + w.init() + v := make(Set) + e := make(Set) if g != nil { v, e = g.vertices, g.edges } @@ -157,9 +168,9 @@ func (w *Walker) Update(g *AcyclicGraph) { } // Calculate all our sets - newEdges := e.Difference(&w.edges) + newEdges := e.Difference(w.edges) oldEdges := w.edges.Difference(e) - newVerts := v.Difference(&w.vertices) + newVerts := v.Difference(w.vertices) oldVerts := w.vertices.Difference(v) // Add the new vertices @@ -207,7 +218,7 @@ func (w *Walker) Update(g *AcyclicGraph) { } // Add the new edges - var changedDeps Set + changedDeps := make(Set) for _, raw := range newEdges.List() { edge := raw.(Edge) waiter, dep := w.edgeParts(edge)