diff --git a/dag/graph.go b/dag/graph.go index 263802d21..7fac2a534 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -73,6 +73,30 @@ func (g *Graph) Remove(v Vertex) Vertex { return nil } +// Replace replaces the original Vertex with replacement. If the original +// does not exist within the graph, then false is returned. Otherwise, true +// is returned. +func (g *Graph) Replace(original, replacement Vertex) bool { + // If we don't have the original, we can't do anything + if !g.vertices.Include(original) { + return false + } + + // Add our new vertex, then copy all the edges + g.Add(replacement) + for _, target := range g.DownEdges(original).List() { + g.Connect(BasicEdge(replacement, target)) + } + for _, source := range g.UpEdges(original).List() { + g.Connect(BasicEdge(source, replacement)) + } + + // Remove our old vertex, which will also remove all the edges + g.Remove(original) + + return true +} + // RemoveEdge removes an edge from the graph. func (g *Graph) RemoveEdge(edge Edge) { g.once.Do(g.init) diff --git a/dag/graph_test.go b/dag/graph_test.go index b7a9ae537..e67535659 100644 --- a/dag/graph_test.go +++ b/dag/graph_test.go @@ -47,12 +47,29 @@ func TestGraph_remove(t *testing.T) { } } +func TestGraph_replace(t *testing.T) { + var g Graph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 3)) + g.Replace(2, 42) + + actual := strings.TrimSpace(g.String()) + expected := strings.TrimSpace(testGraphReplaceStr) + if actual != expected { + t.Fatalf("bad: %s", actual) + } +} + const testGraphBasicStr = ` 1 3 2 3 ` + const testGraphEmptyStr = ` 1 2 @@ -63,3 +80,11 @@ const testGraphRemoveStr = ` 1 2 ` + +const testGraphReplaceStr = ` +1 + 42 +3 +42 + 3 +`