dag: Walk
This commit is contained in:
parent
cfa3d89265
commit
e94c43e0dc
46
dag/dag.go
46
dag/dag.go
|
@ -2,6 +2,7 @@ package dag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AcyclicGraph is a specialization of Graph that cannot have cycles. With
|
// AcyclicGraph is a specialization of Graph that cannot have cycles. With
|
||||||
|
@ -37,5 +38,48 @@ func (g *AcyclicGraph) Root() (Vertex, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk walks the graph, calling your callback as each node is visited.
|
// Walk walks the graph, calling your callback as each node is visited.
|
||||||
func (g *AcyclicGraph) Walk(cb WalkFunc) {
|
// This will walk nodes in parallel if it can.
|
||||||
|
func (g *AcyclicGraph) Walk(cb WalkFunc) error {
|
||||||
|
// We require a root to walk.
|
||||||
|
root, err := g.Root()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the waitgroup that signals when we're done
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(g.vertices.Len())
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(doneCh)
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start walking!
|
||||||
|
visitCh := make(chan Vertex, g.vertices.Len())
|
||||||
|
visitCh <- root
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case v := <-visitCh:
|
||||||
|
go g.walkVertex(v, cb, visitCh, &wg)
|
||||||
|
case <-doneCh:
|
||||||
|
goto WALKDONE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
WALKDONE:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *AcyclicGraph) walkVertex(
|
||||||
|
v Vertex, cb WalkFunc, nextCh chan<- Vertex, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Call the callback on this vertex
|
||||||
|
cb(v)
|
||||||
|
|
||||||
|
// Walk all the children in parallel
|
||||||
|
for _, v := range g.DownEdges(v).List() {
|
||||||
|
nextCh <- v.(Vertex)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package dag
|
package dag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,3 +46,35 @@ func TestAcyclicGraphRoot_multiple(t *testing.T) {
|
||||||
t.Fatal("should error")
|
t.Fatal("should error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAcyclicGraphWalk(t *testing.T) {
|
||||||
|
var g AcyclicGraph
|
||||||
|
g.Add(1)
|
||||||
|
g.Add(2)
|
||||||
|
g.Add(3)
|
||||||
|
g.Connect(BasicEdge(3, 2))
|
||||||
|
g.Connect(BasicEdge(3, 1))
|
||||||
|
|
||||||
|
var visits []Vertex
|
||||||
|
var lock sync.Mutex
|
||||||
|
err := g.Walk(func(v Vertex) {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
visits = append(visits, v)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := [][]Vertex{
|
||||||
|
{3, 1, 2},
|
||||||
|
{3, 2, 1},
|
||||||
|
}
|
||||||
|
for _, e := range expected {
|
||||||
|
if reflect.DeepEqual(visits, e) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatalf("bad: %#v", visits)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue