diff --git a/dag/dag.go b/dag/dag.go index f2716257b..22b279524 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -2,6 +2,7 @@ package dag import ( "fmt" + "strings" "sync" ) @@ -51,7 +52,17 @@ func (g *AcyclicGraph) Validate() error { } } if len(cycles) > 0 { - return fmt.Errorf("cycles: %#v", cycles) + cyclesStr := make([]string, len(cycles)) + for i, cycle := range cycles { + cycleStr := make([]string, len(cycle)) + for j, vertex := range cycle { + cycleStr[j] = VertexName(vertex) + } + + cyclesStr[i] = strings.Join(cycleStr, ", ") + } + + return fmt.Errorf("cycles: %s", cyclesStr) } return nil @@ -60,46 +71,49 @@ func (g *AcyclicGraph) Validate() error { // Walk walks the graph, calling your callback as each node is visited. // This will walk nodes in parallel if it can. func (g *AcyclicGraph) Walk(cb WalkFunc) error { - // We require a root to walk. - root, err := g.Root() - if err != nil { - return err - } + // Cache the vertices since we use it multiple times + vertices := g.Vertices() // Build the waitgroup that signals when we're done var wg sync.WaitGroup - wg.Add(g.vertices.Len()) + wg.Add(len(vertices)) doneCh := make(chan struct{}) go func() { defer close(doneCh) wg.Wait() }() - // Start walking! - visitCh := make(chan Vertex, g.vertices.Len()) - visitCh <- root - for { - select { - case v := <-visitCh: - go g.walkVertex(v, cb, visitCh, &wg) - case <-doneCh: - goto WALKDONE + // The map of channels to watch to wait for vertices to finish + vertMap := make(map[Vertex]chan struct{}) + for _, v := range vertices { + vertMap[v] = make(chan struct{}) + } + for _, v := range vertices { + // Get the list of channels to wait on + deps := g.DownEdges(v).List() + depChs := make([]<-chan struct{}, len(deps)) + for i, dep := range deps { + depChs[i] = vertMap[dep.(Vertex)] } + + // Get our channel + ourCh := vertMap[v] + + // Start the goroutine + go func(v Vertex, doneCh chan<- struct{}, chs []<-chan struct{}) { + defer close(doneCh) + defer wg.Done() + + // Wait on all our dependencies + for _, ch := range chs { + <-ch + } + + // Call our callback + cb(v) + }(v, ourCh, depChs) } -WALKDONE: + <-doneCh return nil } - -func (g *AcyclicGraph) walkVertex( - v Vertex, cb WalkFunc, nextCh chan<- Vertex, wg *sync.WaitGroup) { - defer wg.Done() - - // Call the callback on this vertex - cb(v) - - // Walk all the children in parallel - for _, v := range g.DownEdges(v).List() { - nextCh <- v.(Vertex) - } -} diff --git a/dag/dag_test.go b/dag/dag_test.go index e607c7ba1..fac5d37f4 100644 --- a/dag/dag_test.go +++ b/dag/dag_test.go @@ -95,8 +95,8 @@ func TestAcyclicGraphWalk(t *testing.T) { } expected := [][]Vertex{ - {3, 1, 2}, - {3, 2, 1}, + {1, 2, 3}, + {2, 1, 3}, } for _, e := range expected { if reflect.DeepEqual(visits, e) { diff --git a/dag/tarjan.go b/dag/tarjan.go index 3475dda25..9d8b25ce2 100644 --- a/dag/tarjan.go +++ b/dag/tarjan.go @@ -6,71 +6,102 @@ package dag // use. func StronglyConnected(g *Graph) [][]Vertex { vs := g.Vertices() - data := tarjanData{ - index: make(map[interface{}]int), - stack: make([]*tarjanVertex, 0, len(vs)), - vertices: make([]*tarjanVertex, 0, len(vs)), + acct := sccAcct{ + NextIndex: 1, + VertexIndex: make(map[Vertex]int, len(vs)), } - for _, v := range vs { - if _, ok := data.index[v]; !ok { - strongConnect(g, v, &data) + // Recurse on any non-visited nodes + if acct.VertexIndex[v] == 0 { + stronglyConnected(&acct, g, v) } } - - return data.result + return acct.SCC } -type tarjanData struct { - index map[interface{}]int - result [][]Vertex - stack []*tarjanVertex - vertices []*tarjanVertex -} +func stronglyConnected(acct *sccAcct, g *Graph, v Vertex) int { + // Initial vertex visit + index := acct.visit(v) + minIdx := index -type tarjanVertex struct { - V Vertex - Lowlink int - Index int - Stack bool -} - -func strongConnect(g *Graph, v Vertex, data *tarjanData) *tarjanVertex { - index := len(data.index) - data.index[v] = index - tv := &tarjanVertex{V: v, Lowlink: index, Index: index, Stack: true} - data.stack = append(data.stack, tv) - data.vertices = append(data.vertices, tv) - - for _, raw := range g.downEdges[v].List() { + for _, raw := range g.DownEdges(v).List() { target := raw.(Vertex) + targetIdx := acct.VertexIndex[target] - if idx, ok := data.index[target]; !ok { - if tv2 := strongConnect(g, target, data); tv2.Lowlink < tv.Lowlink { - tv.Lowlink = tv2.Lowlink - } - } else if data.vertices[idx].Stack { - if idx < tv.Lowlink { - tv.Lowlink = idx - } + // Recurse on successor if not yet visited + if targetIdx == 0 { + minIdx = min(minIdx, stronglyConnected(acct, g, target)) + } else if acct.inStack(target) { + // Check if the vertex is in the stack + minIdx = min(minIdx, targetIdx) } } - if tv.Lowlink == index { - vs := make([]Vertex, 0, 2) - for i := len(data.stack) - 1; i >= 0; i-- { - v := data.stack[i] - data.stack[i] = nil - data.stack = data.stack[:i] - data.vertices[data.index[v]].Stack = false - vs = append(vs, v.V) - if data.index[v] == i { + // Pop the strongly connected components off the stack if + // this is a root vertex + if index == minIdx { + var scc []Vertex + for { + v2 := acct.pop() + scc = append(scc, v2) + if v2 == v { break } } - data.result = append(data.result, vs) + acct.SCC = append(acct.SCC, scc) } - return tv + return minIdx +} + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +// sccAcct is used ot pass around accounting information for +// the StronglyConnectedComponents algorithm +type sccAcct struct { + NextIndex int + VertexIndex map[Vertex]int + Stack []Vertex + SCC [][]Vertex +} + +// visit assigns an index and pushes a vertex onto the stack +func (s *sccAcct) visit(v Vertex) int { + idx := s.NextIndex + s.VertexIndex[v] = idx + s.NextIndex++ + s.push(v) + return idx +} + +// push adds a vertex to the stack +func (s *sccAcct) push(n Vertex) { + s.Stack = append(s.Stack, n) +} + +// pop removes a vertex from the stack +func (s *sccAcct) pop() Vertex { + n := len(s.Stack) + if n == 0 { + return nil + } + vertex := s.Stack[n-1] + s.Stack = s.Stack[:n-1] + return vertex +} + +// inStack checks if a vertex is in the stack +func (s *sccAcct) inStack(needle Vertex) bool { + for _, n := range s.Stack { + if n == needle { + return true + } + } + return false }