dag: walk should be able to be halted

This commit is contained in:
Mitchell Hashimoto 2015-02-08 17:06:17 -08:00
parent 28a23a45f4
commit 54fd742ef6
2 changed files with 91 additions and 18 deletions

View File

@ -15,7 +15,7 @@ type AcyclicGraph struct {
} }
// WalkFunc is the callback used for walking the graph. // WalkFunc is the callback used for walking the graph.
type WalkFunc func(Vertex) type WalkFunc func(Vertex) error
// Root returns the root of the DAG, or an error. // Root returns the root of the DAG, or an error.
// //
@ -79,7 +79,8 @@ func (g *AcyclicGraph) Validate() error {
} }
// Walk walks the graph, calling your callback as each node is visited. // Walk walks the graph, calling your callback as each node is visited.
// This will walk nodes in parallel if it can. // This will walk nodes in parallel if it can. Because the walk is done
// in parallel, the error returned will be a multierror.
func (g *AcyclicGraph) Walk(cb WalkFunc) error { func (g *AcyclicGraph) Walk(cb WalkFunc) error {
// Cache the vertices since we use it multiple times // Cache the vertices since we use it multiple times
vertices := g.Vertices() vertices := g.Vertices()
@ -98,32 +99,65 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) error {
for _, v := range vertices { for _, v := range vertices {
vertMap[v] = make(chan struct{}) vertMap[v] = make(chan struct{})
} }
// The map of whether a vertex errored or not during the walk
var errLock sync.Mutex
var errs error
errMap := make(map[Vertex]bool)
for _, v := range vertices { for _, v := range vertices {
// Get the list of channels to wait on // Build our list of dependencies and the list of channels to
deps := g.DownEdges(v).List() // wait on until we start executing for this vertex.
depsRaw := g.DownEdges(v).List()
deps := make([]Vertex, len(depsRaw))
depChs := make([]<-chan struct{}, len(deps)) depChs := make([]<-chan struct{}, len(deps))
for i, dep := range deps { for i, raw := range depsRaw {
depChs[i] = vertMap[dep.(Vertex)] deps[i] = raw.(Vertex)
depChs[i] = vertMap[deps[i]]
} }
// Get our channel // Get our channel so that we can close it when we're done
ourCh := vertMap[v] ourCh := vertMap[v]
// Start the goroutine // Start the goroutine to wait for our dependencies
go func(v Vertex, doneCh chan<- struct{}, chs []<-chan struct{}) { readyCh := make(chan bool)
defer close(doneCh) go func(deps []Vertex, chs []<-chan struct{}, readyCh chan<- bool) {
defer wg.Done() // First wait for all the dependencies
// Wait on all our dependencies
for _, ch := range chs { for _, ch := range chs {
<-ch <-ch
} }
// Call our callback // Then, check the map to see if any of our dependencies failed
cb(v) errLock.Lock()
}(v, ourCh, depChs) defer errLock.Unlock()
for _, dep := range deps {
if errMap[dep] {
readyCh <- false
return
}
}
readyCh <- true
}(deps, depChs, readyCh)
// Start the goroutine that executes
go func(v Vertex, doneCh chan<- struct{}, readyCh <-chan bool) {
defer close(doneCh)
defer wg.Done()
var err error
if ready := <-readyCh; ready {
err = cb(v)
}
errLock.Lock()
defer errLock.Unlock()
if err != nil {
errMap[v] = true
errs = multierror.Append(errs, err)
}
}(v, ourCh, readyCh)
} }
<-doneCh <-doneCh
return nil return errs
} }

View File

@ -1,6 +1,7 @@
package dag package dag
import ( import (
"fmt"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
@ -96,10 +97,11 @@ func TestAcyclicGraphWalk(t *testing.T) {
var visits []Vertex var visits []Vertex
var lock sync.Mutex var lock sync.Mutex
err := g.Walk(func(v Vertex) { err := g.Walk(func(v Vertex) error {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
visits = append(visits, v) visits = append(visits, v)
return nil
}) })
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
@ -117,3 +119,40 @@ func TestAcyclicGraphWalk(t *testing.T) {
t.Fatalf("bad: %#v", visits) t.Fatalf("bad: %#v", visits)
} }
func TestAcyclicGraphWalk_error(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
g.Connect(BasicEdge(3, 1))
var visits []Vertex
var lock sync.Mutex
err := g.Walk(func(v Vertex) error {
lock.Lock()
defer lock.Unlock()
if v == 2 {
return fmt.Errorf("error")
}
visits = append(visits, v)
return nil
})
if err == nil {
t.Fatal("should error")
}
expected := [][]Vertex{
{1},
}
for _, e := range expected {
if reflect.DeepEqual(visits, e) {
return
}
}
t.Fatalf("bad: %#v", visits)
}