dag: walk should be able to be halted
This commit is contained in:
parent
28a23a45f4
commit
54fd742ef6
68
dag/dag.go
68
dag/dag.go
|
@ -15,7 +15,7 @@ type AcyclicGraph struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WalkFunc is the callback used for walking the graph.
|
// WalkFunc is the callback used for walking the graph.
|
||||||
type WalkFunc func(Vertex)
|
type WalkFunc func(Vertex) error
|
||||||
|
|
||||||
// Root returns the root of the DAG, or an error.
|
// Root returns the root of the DAG, or an error.
|
||||||
//
|
//
|
||||||
|
@ -79,7 +79,8 @@ func (g *AcyclicGraph) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk walks the graph, calling your callback as each node is visited.
|
// Walk walks the graph, calling your callback as each node is visited.
|
||||||
// This will walk nodes in parallel if it can.
|
// This will walk nodes in parallel if it can. Because the walk is done
|
||||||
|
// in parallel, the error returned will be a multierror.
|
||||||
func (g *AcyclicGraph) Walk(cb WalkFunc) error {
|
func (g *AcyclicGraph) Walk(cb WalkFunc) error {
|
||||||
// Cache the vertices since we use it multiple times
|
// Cache the vertices since we use it multiple times
|
||||||
vertices := g.Vertices()
|
vertices := g.Vertices()
|
||||||
|
@ -98,32 +99,65 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) error {
|
||||||
for _, v := range vertices {
|
for _, v := range vertices {
|
||||||
vertMap[v] = make(chan struct{})
|
vertMap[v] = make(chan struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The map of whether a vertex errored or not during the walk
|
||||||
|
var errLock sync.Mutex
|
||||||
|
var errs error
|
||||||
|
errMap := make(map[Vertex]bool)
|
||||||
for _, v := range vertices {
|
for _, v := range vertices {
|
||||||
// Get the list of channels to wait on
|
// Build our list of dependencies and the list of channels to
|
||||||
deps := g.DownEdges(v).List()
|
// wait on until we start executing for this vertex.
|
||||||
|
depsRaw := g.DownEdges(v).List()
|
||||||
|
deps := make([]Vertex, len(depsRaw))
|
||||||
depChs := make([]<-chan struct{}, len(deps))
|
depChs := make([]<-chan struct{}, len(deps))
|
||||||
for i, dep := range deps {
|
for i, raw := range depsRaw {
|
||||||
depChs[i] = vertMap[dep.(Vertex)]
|
deps[i] = raw.(Vertex)
|
||||||
|
depChs[i] = vertMap[deps[i]]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get our channel
|
// Get our channel so that we can close it when we're done
|
||||||
ourCh := vertMap[v]
|
ourCh := vertMap[v]
|
||||||
|
|
||||||
// Start the goroutine
|
// Start the goroutine to wait for our dependencies
|
||||||
go func(v Vertex, doneCh chan<- struct{}, chs []<-chan struct{}) {
|
readyCh := make(chan bool)
|
||||||
defer close(doneCh)
|
go func(deps []Vertex, chs []<-chan struct{}, readyCh chan<- bool) {
|
||||||
defer wg.Done()
|
// First wait for all the dependencies
|
||||||
|
|
||||||
// Wait on all our dependencies
|
|
||||||
for _, ch := range chs {
|
for _, ch := range chs {
|
||||||
<-ch
|
<-ch
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call our callback
|
// Then, check the map to see if any of our dependencies failed
|
||||||
cb(v)
|
errLock.Lock()
|
||||||
}(v, ourCh, depChs)
|
defer errLock.Unlock()
|
||||||
|
for _, dep := range deps {
|
||||||
|
if errMap[dep] {
|
||||||
|
readyCh <- false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
readyCh <- true
|
||||||
|
}(deps, depChs, readyCh)
|
||||||
|
|
||||||
|
// Start the goroutine that executes
|
||||||
|
go func(v Vertex, doneCh chan<- struct{}, readyCh <-chan bool) {
|
||||||
|
defer close(doneCh)
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if ready := <-readyCh; ready {
|
||||||
|
err = cb(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
errLock.Lock()
|
||||||
|
defer errLock.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
errMap[v] = true
|
||||||
|
errs = multierror.Append(errs, err)
|
||||||
|
}
|
||||||
|
}(v, ourCh, readyCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
<-doneCh
|
<-doneCh
|
||||||
return nil
|
return errs
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package dag
|
package dag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -96,10 +97,11 @@ func TestAcyclicGraphWalk(t *testing.T) {
|
||||||
|
|
||||||
var visits []Vertex
|
var visits []Vertex
|
||||||
var lock sync.Mutex
|
var lock sync.Mutex
|
||||||
err := g.Walk(func(v Vertex) {
|
err := g.Walk(func(v Vertex) error {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
visits = append(visits, v)
|
visits = append(visits, v)
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
|
@ -117,3 +119,40 @@ func TestAcyclicGraphWalk(t *testing.T) {
|
||||||
|
|
||||||
t.Fatalf("bad: %#v", visits)
|
t.Fatalf("bad: %#v", visits)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAcyclicGraphWalk_error(t *testing.T) {
|
||||||
|
var g AcyclicGraph
|
||||||
|
g.Add(1)
|
||||||
|
g.Add(2)
|
||||||
|
g.Add(3)
|
||||||
|
g.Connect(BasicEdge(3, 2))
|
||||||
|
g.Connect(BasicEdge(3, 1))
|
||||||
|
|
||||||
|
var visits []Vertex
|
||||||
|
var lock sync.Mutex
|
||||||
|
err := g.Walk(func(v Vertex) error {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
if v == 2 {
|
||||||
|
return fmt.Errorf("error")
|
||||||
|
}
|
||||||
|
|
||||||
|
visits = append(visits, v)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := [][]Vertex{
|
||||||
|
{1},
|
||||||
|
}
|
||||||
|
for _, e := range expected {
|
||||||
|
if reflect.DeepEqual(visits, e) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatalf("bad: %#v", visits)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue