diff --git a/dag/graph.go b/dag/graph.go index 4d4450cc4..b1cbd84d4 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -26,6 +26,9 @@ type NamedVertex interface { Name() string } +// WalkFunc is the callback used for walking the graph. +type WalkFunc func(Vertex) + // Vertices returns the list of all the vertices in the graph. func (g *Graph) Vertices() []Vertex { return g.vertices @@ -78,6 +81,10 @@ func (g *Graph) Connect(edge Edge) { s.Add(source) } +// Walk walks the graph, calling your callback as each node is visited. +func (g *Graph) Walk(cb WalkFunc) { +} + // String outputs some human-friendly output for the graph structure. func (g *Graph) String() string { var buf bytes.Buffer diff --git a/dag/tarjan.go b/dag/tarjan.go new file mode 100644 index 000000000..b5d5b5c8a --- /dev/null +++ b/dag/tarjan.go @@ -0,0 +1,76 @@ +package dag + +// StronglyConnected returns the list of strongly connected components +// within the Graph g. This information is primarily used by this package +// for cycle detection, but strongly connected components have widespread +// use. +func StronglyConnected(g *Graph) [][]Vertex { + vs := g.Vertices() + data := tarjanData{ + index: make(map[interface{}]int), + stack: make([]*tarjanVertex, 0, len(vs)), + vertices: make([]*tarjanVertex, 0, len(vs)), + } + + for _, v := range vs { + if _, ok := data.index[v]; !ok { + strongConnect(g, v, &data) + } + } + + return data.result +} + +type tarjanData struct { + index map[interface{}]int + result [][]Vertex + stack []*tarjanVertex + vertices []*tarjanVertex +} + +type tarjanVertex struct { + V Vertex + Lowlink int + Index int + Stack bool +} + +func strongConnect(g *Graph, v Vertex, data *tarjanData) *tarjanVertex { + index := len(data.index) + data.index[v] = index + tv := &tarjanVertex{V: v, Lowlink: index, Index: index, Stack: true} + data.stack = append(data.stack, tv) + data.vertices = append(data.vertices, tv) + + for _, raw := range g.downEdges[v].List() { + target := raw.(Vertex) + + if idx, ok := data.index[target]; !ok { + if tv2 := strongConnect(g, target, data); tv2.Lowlink < tv.Lowlink { + tv.Lowlink = tv2.Lowlink + } + } else if data.vertices[idx].Stack { + if idx < tv.Lowlink { + tv.Lowlink = idx + } + } + } + + if tv.Lowlink == index { + vs := make([]Vertex, 0, 2) + for i := len(data.stack) - 1; ; i-- { + v := data.stack[i] + data.stack[i] = nil + data.stack = data.stack[:i] + data.vertices[data.index[v]].Stack = false + vs = append(vs, v.V) + if data.index[v] == i { + break + } + } + + data.result = append(data.result, vs) + } + + return tv +} diff --git a/dag/tarjan_test.go b/dag/tarjan_test.go new file mode 100644 index 000000000..b5dcd3f0f --- /dev/null +++ b/dag/tarjan_test.go @@ -0,0 +1,85 @@ +package dag + +import ( + "bytes" + "fmt" + "strings" + "testing" +) + +func TestGraphStronglyConnected(t *testing.T) { + var g Graph + g.Add(1) + g.Add(2) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 1)) + + actual := strings.TrimSpace(testSCCStr(StronglyConnected(&g))) + expected := strings.TrimSpace(testGraphStronglyConnectedStr) + if actual != expected { + t.Fatalf("bad: %s", actual) + } +} + +func TestGraphStronglyConnected_two(t *testing.T) { + var g Graph + g.Add(1) + g.Add(2) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 1)) + g.Add(3) + + actual := strings.TrimSpace(testSCCStr(StronglyConnected(&g))) + expected := strings.TrimSpace(testGraphStronglyConnectedTwoStr) + if actual != expected { + t.Fatalf("bad: %s", actual) + } +} + +func TestGraphStronglyConnected_three(t *testing.T) { + var g Graph + g.Add(1) + g.Add(2) + g.Connect(BasicEdge(1, 2)) + g.Connect(BasicEdge(2, 1)) + g.Add(3) + g.Add(4) + g.Add(5) + g.Add(6) + g.Connect(BasicEdge(4, 5)) + g.Connect(BasicEdge(5, 6)) + g.Connect(BasicEdge(6, 4)) + + actual := strings.TrimSpace(testSCCStr(StronglyConnected(&g))) + expected := strings.TrimSpace(testGraphStronglyConnectedThreeStr) + if actual != expected { + t.Fatalf("bad: %s", actual) + } +} + +func testSCCStr(list [][]Vertex) string { + var buf bytes.Buffer + for _, vs := range list { + result := make([]string, len(vs)) + for i, v := range vs { + result[i] = vertName(v) + } + + buf.WriteString(fmt.Sprintf("%s\n", strings.Join(result, ","))) + } + + return buf.String() +} + +const testGraphStronglyConnectedStr = `2,1` + +const testGraphStronglyConnectedTwoStr = ` +2,1 +3 +` + +const testGraphStronglyConnectedThreeStr = ` +2,1 +3 +6,5,4 +`