dag: fix walk order issue, scc issues

This commit is contained in:
Mitchell Hashimoto 2015-02-04 19:38:38 -05:00
parent d9a964f44c
commit e86698c50d
3 changed files with 126 additions and 81 deletions

View File

@ -2,6 +2,7 @@ package dag
import ( import (
"fmt" "fmt"
"strings"
"sync" "sync"
) )
@ -51,7 +52,17 @@ func (g *AcyclicGraph) Validate() error {
} }
} }
if len(cycles) > 0 { if len(cycles) > 0 {
return fmt.Errorf("cycles: %#v", cycles) cyclesStr := make([]string, len(cycles))
for i, cycle := range cycles {
cycleStr := make([]string, len(cycle))
for j, vertex := range cycle {
cycleStr[j] = VertexName(vertex)
}
cyclesStr[i] = strings.Join(cycleStr, ", ")
}
return fmt.Errorf("cycles: %s", cyclesStr)
} }
return nil return nil
@ -60,46 +71,49 @@ func (g *AcyclicGraph) Validate() error {
// Walk walks the graph, calling your callback as each node is visited. // Walk walks the graph, calling your callback as each node is visited.
// This will walk nodes in parallel if it can. // This will walk nodes in parallel if it can.
func (g *AcyclicGraph) Walk(cb WalkFunc) error { func (g *AcyclicGraph) Walk(cb WalkFunc) error {
// We require a root to walk. // Cache the vertices since we use it multiple times
root, err := g.Root() vertices := g.Vertices()
if err != nil {
return err
}
// Build the waitgroup that signals when we're done // Build the waitgroup that signals when we're done
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(g.vertices.Len()) wg.Add(len(vertices))
doneCh := make(chan struct{}) doneCh := make(chan struct{})
go func() { go func() {
defer close(doneCh) defer close(doneCh)
wg.Wait() wg.Wait()
}() }()
// Start walking! // The map of channels to watch to wait for vertices to finish
visitCh := make(chan Vertex, g.vertices.Len()) vertMap := make(map[Vertex]chan struct{})
visitCh <- root for _, v := range vertices {
for { vertMap[v] = make(chan struct{})
select { }
case v := <-visitCh: for _, v := range vertices {
go g.walkVertex(v, cb, visitCh, &wg) // Get the list of channels to wait on
case <-doneCh: deps := g.DownEdges(v).List()
goto WALKDONE depChs := make([]<-chan struct{}, len(deps))
for i, dep := range deps {
depChs[i] = vertMap[dep.(Vertex)]
} }
// Get our channel
ourCh := vertMap[v]
// Start the goroutine
go func(v Vertex, doneCh chan<- struct{}, chs []<-chan struct{}) {
defer close(doneCh)
defer wg.Done()
// Wait on all our dependencies
for _, ch := range chs {
<-ch
}
// Call our callback
cb(v)
}(v, ourCh, depChs)
} }
WALKDONE: <-doneCh
return nil return nil
} }
func (g *AcyclicGraph) walkVertex(
v Vertex, cb WalkFunc, nextCh chan<- Vertex, wg *sync.WaitGroup) {
defer wg.Done()
// Call the callback on this vertex
cb(v)
// Walk all the children in parallel
for _, v := range g.DownEdges(v).List() {
nextCh <- v.(Vertex)
}
}

View File

@ -95,8 +95,8 @@ func TestAcyclicGraphWalk(t *testing.T) {
} }
expected := [][]Vertex{ expected := [][]Vertex{
{3, 1, 2}, {1, 2, 3},
{3, 2, 1}, {2, 1, 3},
} }
for _, e := range expected { for _, e := range expected {
if reflect.DeepEqual(visits, e) { if reflect.DeepEqual(visits, e) {

View File

@ -6,71 +6,102 @@ package dag
// use. // use.
func StronglyConnected(g *Graph) [][]Vertex { func StronglyConnected(g *Graph) [][]Vertex {
vs := g.Vertices() vs := g.Vertices()
data := tarjanData{ acct := sccAcct{
index: make(map[interface{}]int), NextIndex: 1,
stack: make([]*tarjanVertex, 0, len(vs)), VertexIndex: make(map[Vertex]int, len(vs)),
vertices: make([]*tarjanVertex, 0, len(vs)),
} }
for _, v := range vs { for _, v := range vs {
if _, ok := data.index[v]; !ok { // Recurse on any non-visited nodes
strongConnect(g, v, &data) if acct.VertexIndex[v] == 0 {
stronglyConnected(&acct, g, v)
} }
} }
return acct.SCC
return data.result
} }
type tarjanData struct { func stronglyConnected(acct *sccAcct, g *Graph, v Vertex) int {
index map[interface{}]int // Initial vertex visit
result [][]Vertex index := acct.visit(v)
stack []*tarjanVertex minIdx := index
vertices []*tarjanVertex
}
type tarjanVertex struct { for _, raw := range g.DownEdges(v).List() {
V Vertex
Lowlink int
Index int
Stack bool
}
func strongConnect(g *Graph, v Vertex, data *tarjanData) *tarjanVertex {
index := len(data.index)
data.index[v] = index
tv := &tarjanVertex{V: v, Lowlink: index, Index: index, Stack: true}
data.stack = append(data.stack, tv)
data.vertices = append(data.vertices, tv)
for _, raw := range g.downEdges[v].List() {
target := raw.(Vertex) target := raw.(Vertex)
targetIdx := acct.VertexIndex[target]
if idx, ok := data.index[target]; !ok { // Recurse on successor if not yet visited
if tv2 := strongConnect(g, target, data); tv2.Lowlink < tv.Lowlink { if targetIdx == 0 {
tv.Lowlink = tv2.Lowlink minIdx = min(minIdx, stronglyConnected(acct, g, target))
} } else if acct.inStack(target) {
} else if data.vertices[idx].Stack { // Check if the vertex is in the stack
if idx < tv.Lowlink { minIdx = min(minIdx, targetIdx)
tv.Lowlink = idx
}
} }
} }
if tv.Lowlink == index { // Pop the strongly connected components off the stack if
vs := make([]Vertex, 0, 2) // this is a root vertex
for i := len(data.stack) - 1; i >= 0; i-- { if index == minIdx {
v := data.stack[i] var scc []Vertex
data.stack[i] = nil for {
data.stack = data.stack[:i] v2 := acct.pop()
data.vertices[data.index[v]].Stack = false scc = append(scc, v2)
vs = append(vs, v.V) if v2 == v {
if data.index[v] == i {
break break
} }
} }
data.result = append(data.result, vs) acct.SCC = append(acct.SCC, scc)
} }
return tv return minIdx
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}
// sccAcct is used ot pass around accounting information for
// the StronglyConnectedComponents algorithm
type sccAcct struct {
NextIndex int
VertexIndex map[Vertex]int
Stack []Vertex
SCC [][]Vertex
}
// visit assigns an index and pushes a vertex onto the stack
func (s *sccAcct) visit(v Vertex) int {
idx := s.NextIndex
s.VertexIndex[v] = idx
s.NextIndex++
s.push(v)
return idx
}
// push adds a vertex to the stack
func (s *sccAcct) push(n Vertex) {
s.Stack = append(s.Stack, n)
}
// pop removes a vertex from the stack
func (s *sccAcct) pop() Vertex {
n := len(s.Stack)
if n == 0 {
return nil
}
vertex := s.Stack[n-1]
s.Stack = s.Stack[:n-1]
return vertex
}
// inStack checks if a vertex is in the stack
func (s *sccAcct) inStack(needle Vertex) bool {
for _, n := range s.Stack {
if n == needle {
return true
}
}
return false
} }