diff --git a/dag/dag.go b/dag/dag.go index 8ca4e910e..f16a459f6 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -36,7 +36,7 @@ func (g *AcyclicGraph) Ancestors(v Vertex) (Set, error) { return nil } - if err := g.DepthFirstWalk(g.DownEdges(v), memoFunc); err != nil { + if err := g.DepthFirstWalk(g.downEdgesNoCopy(v), memoFunc); err != nil { return nil, err } @@ -52,7 +52,7 @@ func (g *AcyclicGraph) Descendents(v Vertex) (Set, error) { return nil } - if err := g.ReverseDepthFirstWalk(g.UpEdges(v), memoFunc); err != nil { + if err := g.ReverseDepthFirstWalk(g.upEdgesNoCopy(v), memoFunc); err != nil { return nil, err } @@ -65,7 +65,7 @@ func (g *AcyclicGraph) Descendents(v Vertex) (Set, error) { func (g *AcyclicGraph) Root() (Vertex, error) { roots := make([]Vertex, 0, 1) for _, v := range g.Vertices() { - if g.UpEdges(v).Len() == 0 { + if g.upEdgesNoCopy(v).Len() == 0 { roots = append(roots, v) } } @@ -101,10 +101,10 @@ func (g *AcyclicGraph) TransitiveReduction() { // // For each v-prime reachable from v, remove the edge (u, v-prime). for _, u := range g.Vertices() { - uTargets := g.DownEdges(u) + uTargets := g.downEdgesNoCopy(u) - g.DepthFirstWalk(g.DownEdges(u), func(v Vertex, d int) error { - shared := uTargets.Intersection(g.DownEdges(v)) + g.DepthFirstWalk(g.downEdgesNoCopy(u), func(v Vertex, d int) error { + shared := uTargets.Intersection(g.downEdgesNoCopy(v)) for _, vPrime := range shared { g.RemoveEdge(BasicEdge(u, vPrime)) } @@ -208,7 +208,7 @@ func (g *AcyclicGraph) DepthFirstWalk(start Set, f DepthWalkFunc) error { return err } - for _, v := range g.DownEdges(current.Vertex) { + for _, v := range g.downEdgesNoCopy(current.Vertex) { frontier = append(frontier, &vertexAtDepth{ Vertex: v, Depth: current.Depth + 1, @@ -248,7 +248,7 @@ func (g *AcyclicGraph) SortedDepthFirstWalk(start []Vertex, f DepthWalkFunc) err } // Visit targets of this in a consistent order. - targets := AsVertexList(g.DownEdges(current.Vertex)) + targets := AsVertexList(g.downEdgesNoCopy(current.Vertex)) sort.Sort(byVertexName(targets)) for _, t := range targets { @@ -285,7 +285,7 @@ func (g *AcyclicGraph) ReverseDepthFirstWalk(start Set, f DepthWalkFunc) error { } seen[current.Vertex] = struct{}{} - for _, t := range g.UpEdges(current.Vertex) { + for _, t := range g.upEdgesNoCopy(current.Vertex) { frontier = append(frontier, &vertexAtDepth{ Vertex: t, Depth: current.Depth + 1, @@ -325,7 +325,7 @@ func (g *AcyclicGraph) SortedReverseDepthFirstWalk(start []Vertex, f DepthWalkFu seen[current.Vertex] = struct{}{} // Add next set of targets in a consistent order. - targets := AsVertexList(g.UpEdges(current.Vertex)) + targets := AsVertexList(g.upEdgesNoCopy(current.Vertex)) sort.Sort(byVertexName(targets)) for _, t := range targets { frontier = append(frontier, &vertexAtDepth{ diff --git a/dag/graph.go b/dag/graph.go index 4ce0dbccb..1d0544354 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -111,10 +111,10 @@ func (g *Graph) Remove(v Vertex) Vertex { g.vertices.Delete(v) // Delete the edges to non-existent things - for _, target := range g.DownEdges(v) { + for _, target := range g.downEdgesNoCopy(v) { g.RemoveEdge(BasicEdge(v, target)) } - for _, source := range g.UpEdges(v) { + for _, source := range g.upEdgesNoCopy(v) { g.RemoveEdge(BasicEdge(source, v)) } @@ -137,10 +137,10 @@ func (g *Graph) Replace(original, replacement Vertex) bool { // Add our new vertex, then copy all the edges g.Add(replacement) - for _, target := range g.DownEdges(original) { + for _, target := range g.downEdgesNoCopy(original) { g.Connect(BasicEdge(replacement, target)) } - for _, source := range g.UpEdges(original) { + for _, source := range g.upEdgesNoCopy(original) { g.Connect(BasicEdge(source, replacement)) } @@ -166,14 +166,29 @@ func (g *Graph) RemoveEdge(edge Edge) { } } -// DownEdges returns the outward edges from the source Vertex v. +// UpEdges returns the vertices connected to the outward edges from the source +// Vertex v. +func (g *Graph) UpEdges(v Vertex) Set { + return g.upEdgesNoCopy(v).Copy() +} + +// DownEdges returns the vertices connected from the inward edges to Vertex v. func (g *Graph) DownEdges(v Vertex) Set { + return g.downEdgesNoCopy(v).Copy() +} + +// downEdgesNoCopy returns the outward edges from the source Vertex v as a Set. +// This Set is the same as used internally bu the Graph to prevent a copy, and +// must not be modified by the caller. +func (g *Graph) downEdgesNoCopy(v Vertex) Set { g.init() return g.downEdges[hashcode(v)] } -// UpEdges returns the inward edges to the destination Vertex v. -func (g *Graph) UpEdges(v Vertex) Set { +// upEdgesNoCopy returns the inward edges to the destination Vertex v as a Set. +// This Set is the same as used internally bu the Graph to prevent a copy, and +// must not be modified by the caller. +func (g *Graph) upEdgesNoCopy(v Vertex) Set { g.init() return g.upEdges[hashcode(v)] } diff --git a/dag/graph_test.go b/dag/graph_test.go index 297974431..76c47641d 100644 --- a/dag/graph_test.go +++ b/dag/graph_test.go @@ -170,6 +170,42 @@ func TestGraphEdgesTo(t *testing.T) { } } +func TestGraphUpdownEdges(t *testing.T) { + // Verify that we can't inadvertently modify the internal graph sets + var g Graph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 3)) + + up := g.UpEdges(2) + if up.Len() != 1 || !up.Include(1) { + t.Fatalf("expected only an up edge of '1', got %#v", up) + } + // modify the up set + up.Add(9) + + orig := g.UpEdges(2) + diff := up.Difference(orig) + if diff.Len() != 1 || !diff.Include(9) { + t.Fatalf("expected a diff of only '9', got %#v", diff) + } + + down := g.DownEdges(2) + if down.Len() != 1 || !down.Include(3) { + t.Fatalf("expected only a down edge of '3', got %#v", down) + } + // modify the down set + down.Add(8) + + orig = g.DownEdges(2) + diff = down.Difference(orig) + if diff.Len() != 1 || !diff.Include(8) { + t.Fatalf("expected a diff of only '8', got %#v", diff) + } +} + type hashVertex struct { code interface{} } diff --git a/dag/set.go b/dag/set.go index f3fd704ba..c5c1af120 100644 --- a/dag/set.go +++ b/dag/set.go @@ -103,3 +103,12 @@ func (s Set) List() []interface{} { return r } + +// Copy returns a shallow copy of the set. +func (s Set) Copy() Set { + c := make(Set) + for k, v := range s { + c[k] = v + } + return c +} diff --git a/dag/set_test.go b/dag/set_test.go index 63b72e323..36bd6a65b 100644 --- a/dag/set_test.go +++ b/dag/set_test.go @@ -99,3 +99,23 @@ func TestSetFilter(t *testing.T) { }) } } + +func TestSetCopy(t *testing.T) { + a := make(Set) + a.Add(1) + a.Add(2) + + b := a.Copy() + b.Add(3) + + diff := b.Difference(a) + + if diff.Len() != 1 { + t.Fatalf("expected single diff value, got %#v", diff) + } + + if !diff.Include(3) { + t.Fatalf("diff does not contain 3, got %#v", diff) + } + +} diff --git a/dag/tarjan.go b/dag/tarjan.go index 330abd589..fb4d4a773 100644 --- a/dag/tarjan.go +++ b/dag/tarjan.go @@ -24,7 +24,7 @@ func stronglyConnected(acct *sccAcct, g *Graph, v Vertex) int { index := acct.visit(v) minIdx := index - for _, raw := range g.DownEdges(v) { + for _, raw := range g.downEdgesNoCopy(v) { target := raw.(Vertex) targetIdx := acct.VertexIndex[target] diff --git a/terraform/transform_output.go b/terraform/transform_output.go index 4d51dabd6..b926b2fd1 100644 --- a/terraform/transform_output.go +++ b/terraform/transform_output.go @@ -86,10 +86,7 @@ func (t *destroyRootOutputTransformer) Transform(g *Graph) error { log.Printf("[TRACE] creating %s", node.Name()) g.Add(node) - deps, err := g.Descendents(v) - if err != nil { - return err - } + deps := g.UpEdges(v) // the destroy node must depend on the eval node deps.Add(v)