dag: replace dag.Walk with our walker
This commit is contained in:
parent
b1aa6fd598
commit
28fff99ea8
94
dag/dag.go
94
dag/dag.go
|
@ -2,11 +2,8 @@ package dag
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
@ -169,94 +166,9 @@ func (g *AcyclicGraph) Cycles() [][]Vertex {
|
|||
func (g *AcyclicGraph) Walk(cb WalkFunc) error {
|
||||
defer g.debug.BeginOperation(typeWalk, "").End("")
|
||||
|
||||
// Cache the vertices since we use it multiple times
|
||||
vertices := g.Vertices()
|
||||
|
||||
// Build the waitgroup that signals when we're done
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(vertices))
|
||||
doneCh := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneCh)
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
// The map of channels to watch to wait for vertices to finish
|
||||
vertMap := make(map[Vertex]chan struct{})
|
||||
for _, v := range vertices {
|
||||
vertMap[v] = make(chan struct{})
|
||||
}
|
||||
|
||||
// The map of whether a vertex errored or not during the walk
|
||||
var errLock sync.Mutex
|
||||
var errs error
|
||||
errMap := make(map[Vertex]bool)
|
||||
for _, v := range vertices {
|
||||
// Build our list of dependencies and the list of channels to
|
||||
// wait on until we start executing for this vertex.
|
||||
deps := AsVertexList(g.DownEdges(v))
|
||||
depChs := make([]<-chan struct{}, len(deps))
|
||||
for i, dep := range deps {
|
||||
depChs[i] = vertMap[dep]
|
||||
}
|
||||
|
||||
// Get our channel so that we can close it when we're done
|
||||
ourCh := vertMap[v]
|
||||
|
||||
// Start the goroutine to wait for our dependencies
|
||||
readyCh := make(chan bool)
|
||||
go func(v Vertex, deps []Vertex, chs []<-chan struct{}, readyCh chan<- bool) {
|
||||
// First wait for all the dependencies
|
||||
for i, ch := range chs {
|
||||
DepSatisfied:
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
break DepSatisfied
|
||||
case <-time.After(time.Second * 5):
|
||||
log.Printf("[DEBUG] vertex %q, waiting for: %q",
|
||||
VertexName(v), VertexName(deps[i]))
|
||||
}
|
||||
}
|
||||
log.Printf("[DEBUG] vertex %q, got dep: %q",
|
||||
VertexName(v), VertexName(deps[i]))
|
||||
}
|
||||
|
||||
// Then, check the map to see if any of our dependencies failed
|
||||
errLock.Lock()
|
||||
defer errLock.Unlock()
|
||||
for _, dep := range deps {
|
||||
if errMap[dep] {
|
||||
errMap[v] = true
|
||||
readyCh <- false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
readyCh <- true
|
||||
}(v, deps, depChs, readyCh)
|
||||
|
||||
// Start the goroutine that executes
|
||||
go func(v Vertex, doneCh chan<- struct{}, readyCh <-chan bool) {
|
||||
defer close(doneCh)
|
||||
defer wg.Done()
|
||||
|
||||
var err error
|
||||
if ready := <-readyCh; ready {
|
||||
err = cb(v)
|
||||
}
|
||||
|
||||
errLock.Lock()
|
||||
defer errLock.Unlock()
|
||||
if err != nil {
|
||||
errMap[v] = true
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
}(v, ourCh, readyCh)
|
||||
}
|
||||
|
||||
<-doneCh
|
||||
return errs
|
||||
w := &walker{Callback: cb, Reverse: true}
|
||||
w.Update(g.vertices, g.edges)
|
||||
return w.Wait()
|
||||
}
|
||||
|
||||
// simple convenience helper for converting a dag.Set to a []Vertex
|
||||
|
|
118
dag/walk.go
118
dag/walk.go
|
@ -1,6 +1,7 @@
|
|||
package dag
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
@ -12,6 +13,9 @@ import (
|
|||
// walker performs a graph walk and supports walk-time changing of vertices
|
||||
// and edges.
|
||||
//
|
||||
// The walk is depth first by default. This can be changed with the Reverse
|
||||
// option.
|
||||
//
|
||||
// A single walker is only valid for one graph walk. After the walk is complete
|
||||
// you must construct a new walker to walk again. State for the walk is never
|
||||
// deleted in case vertices or edges are changed.
|
||||
|
@ -19,6 +23,10 @@ type walker struct {
|
|||
// Callback is what is called for each vertex
|
||||
Callback WalkFunc
|
||||
|
||||
// Reverse, if true, causes the source of an edge to depend on a target.
|
||||
// When false (default), the target depends on the source.
|
||||
Reverse bool
|
||||
|
||||
// changeLock must be held to modify any of the fields below. Only Update
|
||||
// should modify these fields. Modifying them outside of Update can cause
|
||||
// serious problems.
|
||||
|
@ -44,7 +52,7 @@ type walkerVertex struct {
|
|||
|
||||
// Dependency information. Any changes to any of these fields requires
|
||||
// holding DepsLock.
|
||||
DepsCh chan struct{}
|
||||
DepsCh chan bool
|
||||
DepsUpdateCh chan struct{}
|
||||
DepsLock sync.Mutex
|
||||
|
||||
|
@ -54,6 +62,11 @@ type walkerVertex struct {
|
|||
depsCancelCh chan struct{}
|
||||
}
|
||||
|
||||
// errWalkUpstream is used in the errMap of a walk to note that an upstream
|
||||
// dependency failed so this vertex wasn't run. This is not shown in the final
|
||||
// user-returned error.
|
||||
var errWalkUpstream = errors.New("upstream dependency failed")
|
||||
|
||||
// Wait waits for the completion of the walk and returns any errors (
|
||||
// in the form of a multierror) that occurred. Update should be called
|
||||
// to populate the walk with vertices and edges prior to calling this.
|
||||
|
@ -72,8 +85,10 @@ func (w *walker) Wait() error {
|
|||
// Build the error
|
||||
var result error
|
||||
for v, err := range w.errMap {
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"%s: %s", VertexName(v), err))
|
||||
if err != nil && err != errWalkUpstream {
|
||||
result = multierror.Append(result, fmt.Errorf(
|
||||
"%s: %s", VertexName(v), err))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
@ -116,12 +131,12 @@ func (w *walker) Update(v, e *Set) {
|
|||
info := &walkerVertex{
|
||||
DoneCh: make(chan struct{}),
|
||||
CancelCh: make(chan struct{}),
|
||||
DepsCh: make(chan struct{}),
|
||||
DepsCh: make(chan bool, 1),
|
||||
deps: make(map[Vertex]chan struct{}),
|
||||
}
|
||||
|
||||
// Close the deps channel immediately so it passes
|
||||
close(info.DepsCh)
|
||||
// Pass dependencies immediately assuming we have no edges
|
||||
info.DepsCh <- true
|
||||
|
||||
// Add it to the map and kick off the walk
|
||||
w.vertexMap[v] = info
|
||||
|
@ -153,12 +168,7 @@ func (w *walker) Update(v, e *Set) {
|
|||
var changedDeps Set
|
||||
for _, raw := range newEdges.List() {
|
||||
edge := raw.(Edge)
|
||||
|
||||
// waiter is the vertex that is "waiting" on this edge
|
||||
waiter := edge.Target()
|
||||
|
||||
// dep is the dependency we're waiting on
|
||||
dep := edge.Source()
|
||||
waiter, dep := w.edgeParts(edge)
|
||||
|
||||
// Get the info for the waiter
|
||||
waiterInfo, ok := w.vertexMap[waiter]
|
||||
|
@ -189,12 +199,7 @@ func (w *walker) Update(v, e *Set) {
|
|||
// Process reoved edges
|
||||
for _, raw := range oldEdges.List() {
|
||||
edge := raw.(Edge)
|
||||
|
||||
// waiter is the vertex that is "waiting" on this edge
|
||||
waiter := edge.Target()
|
||||
|
||||
// dep is the dependency we're waiting on
|
||||
dep := edge.Source()
|
||||
waiter, dep := w.edgeParts(edge)
|
||||
|
||||
// Get the info for the waiter
|
||||
waiterInfo, ok := w.vertexMap[waiter]
|
||||
|
@ -226,7 +231,7 @@ func (w *walker) Update(v, e *Set) {
|
|||
}
|
||||
|
||||
// Create a new done channel
|
||||
doneCh := make(chan struct{})
|
||||
doneCh := make(chan bool, 1)
|
||||
|
||||
// Create the channel we close for cancellation
|
||||
cancelCh := make(chan struct{})
|
||||
|
@ -252,6 +257,10 @@ func (w *walker) Update(v, e *Set) {
|
|||
}
|
||||
info.depsCancelCh = cancelCh
|
||||
|
||||
log.Printf(
|
||||
"[DEBUG] dag/walk: dependencies changed for %q, sending new deps",
|
||||
VertexName(v))
|
||||
|
||||
// Start the waiter
|
||||
go w.waitDeps(v, deps, doneCh, cancelCh)
|
||||
}
|
||||
|
@ -264,6 +273,16 @@ func (w *walker) Update(v, e *Set) {
|
|||
}
|
||||
}
|
||||
|
||||
// edgeParts returns the waiter and the dependency, in that order.
|
||||
// The waiter is waiting on the dependency.
|
||||
func (w *walker) edgeParts(e Edge) (Vertex, Vertex) {
|
||||
if w.Reverse {
|
||||
return e.Source(), e.Target()
|
||||
}
|
||||
|
||||
return e.Target(), e.Source()
|
||||
}
|
||||
|
||||
// walkVertex walks a single vertex, waiting for any dependencies before
|
||||
// executing the callback.
|
||||
func (w *walker) walkVertex(v Vertex, info *walkerVertex) {
|
||||
|
@ -273,16 +292,20 @@ func (w *walker) walkVertex(v Vertex, info *walkerVertex) {
|
|||
// When we're done, always close our done channel
|
||||
defer close(info.DoneCh)
|
||||
|
||||
// Wait for our dependencies
|
||||
depsCh := info.DepsCh
|
||||
// Wait for our dependencies. We create a [closed] deps channel so
|
||||
// that we can immediately fall through to load our actual DepsCh.
|
||||
var depsSuccess bool
|
||||
depsCh := make(chan bool, 1)
|
||||
depsCh <- true
|
||||
close(depsCh)
|
||||
for {
|
||||
select {
|
||||
case <-info.CancelCh:
|
||||
// Cancel
|
||||
return
|
||||
|
||||
case <-depsCh:
|
||||
// Deps complete!
|
||||
case depsSuccess = <-depsCh:
|
||||
// Deps complete! Mark as nil to trigger completion handling.
|
||||
depsCh = nil
|
||||
|
||||
case <-info.DepsUpdateCh:
|
||||
|
@ -306,9 +329,27 @@ func (w *walker) walkVertex(v Vertex, info *walkerVertex) {
|
|||
}
|
||||
}
|
||||
|
||||
// Call our callback
|
||||
log.Printf("[DEBUG] dag/walk: walking %q", VertexName(v))
|
||||
if err := w.Callback(v); err != nil {
|
||||
// If we passed dependencies, we just want to check once more that
|
||||
// we're not cancelled, since this can happen just as dependencies pass.
|
||||
select {
|
||||
case <-info.CancelCh:
|
||||
// Cancelled during an update while dependencies completed.
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Run our callback or note that our upstream failed
|
||||
var err error
|
||||
if depsSuccess {
|
||||
log.Printf("[DEBUG] dag/walk: walking %q", VertexName(v))
|
||||
err = w.Callback(v)
|
||||
} else {
|
||||
log.Printf("[DEBUG] dag/walk: upstream errored, not walking %q", VertexName(v))
|
||||
err = errWalkUpstream
|
||||
}
|
||||
|
||||
// Record the error
|
||||
if err != nil {
|
||||
w.errLock.Lock()
|
||||
defer w.errLock.Unlock()
|
||||
|
||||
|
@ -322,11 +363,8 @@ func (w *walker) walkVertex(v Vertex, info *walkerVertex) {
|
|||
func (w *walker) waitDeps(
|
||||
v Vertex,
|
||||
deps map[Vertex]<-chan struct{},
|
||||
doneCh chan<- struct{},
|
||||
doneCh chan<- bool,
|
||||
cancelCh <-chan struct{}) {
|
||||
// Whenever we return, mark ourselves as complete
|
||||
defer close(doneCh)
|
||||
|
||||
// For each dependency given to us, wait for it to complete
|
||||
for dep, depCh := range deps {
|
||||
DepSatisfied:
|
||||
|
@ -337,13 +375,29 @@ func (w *walker) waitDeps(
|
|||
break DepSatisfied
|
||||
|
||||
case <-cancelCh:
|
||||
// Wait cancelled
|
||||
// Wait cancelled. Note that we didn't satisfy dependencies
|
||||
// so that anything waiting on us also doesn't run.
|
||||
doneCh <- false
|
||||
return
|
||||
|
||||
case <-time.After(time.Second * 5):
|
||||
log.Printf("[DEBUG] vertex %q, waiting for: %q",
|
||||
log.Printf("[DEBUG] dag/walk: vertex %q, waiting for: %q",
|
||||
VertexName(v), VertexName(dep))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Dependencies satisfied! We need to check if any errored
|
||||
w.errLock.Lock()
|
||||
defer w.errLock.Unlock()
|
||||
for dep, _ := range deps {
|
||||
if w.errMap[dep] != nil {
|
||||
// One of our dependencies failed, so return false
|
||||
doneCh <- false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// All dependencies satisfied and successful
|
||||
doneCh <- true
|
||||
}
|
||||
|
|
|
@ -33,6 +33,44 @@ func TestWalker_basic(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWalker_error(t *testing.T) {
|
||||
var g Graph
|
||||
g.Add(1)
|
||||
g.Add(2)
|
||||
g.Add(3)
|
||||
g.Add(4)
|
||||
g.Connect(BasicEdge(1, 2))
|
||||
g.Connect(BasicEdge(2, 3))
|
||||
g.Connect(BasicEdge(3, 4))
|
||||
|
||||
// Record function
|
||||
var order []interface{}
|
||||
recordF := walkCbRecord(&order)
|
||||
|
||||
// Build a callback that delays until we close a channel
|
||||
cb := func(v Vertex) error {
|
||||
if v == 2 {
|
||||
return fmt.Errorf("error!")
|
||||
}
|
||||
|
||||
return recordF(v)
|
||||
}
|
||||
|
||||
w := &walker{Callback: cb}
|
||||
w.Update(g.vertices, g.edges)
|
||||
|
||||
// Wait
|
||||
if err := w.Wait(); err == nil {
|
||||
t.Fatal("expect error")
|
||||
}
|
||||
|
||||
// Check
|
||||
expected := []interface{}{1}
|
||||
if !reflect.DeepEqual(order, expected) {
|
||||
t.Fatalf("bad: %#v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalker_newVertex(t *testing.T) {
|
||||
// Run it a bunch of times since it is timing dependent
|
||||
for i := 0; i < 50; i++ {
|
||||
|
@ -82,26 +120,20 @@ func TestWalker_removeVertex(t *testing.T) {
|
|||
recordF := walkCbRecord(&order)
|
||||
|
||||
// Build a callback that delays until we close a channel
|
||||
gateCh := make(chan struct{})
|
||||
var w *walker
|
||||
cb := func(v Vertex) error {
|
||||
if v == 1 {
|
||||
<-gateCh
|
||||
g.Remove(2)
|
||||
w.Update(g.vertices, g.edges)
|
||||
}
|
||||
|
||||
return recordF(v)
|
||||
}
|
||||
|
||||
// Add the initial vertices
|
||||
w := &walker{Callback: cb}
|
||||
w = &walker{Callback: cb}
|
||||
w.Update(g.vertices, g.edges)
|
||||
|
||||
// Remove a vertex
|
||||
g.Remove(2)
|
||||
w.Update(g.vertices, g.edges)
|
||||
|
||||
// Open gate
|
||||
close(gateCh)
|
||||
|
||||
// Wait
|
||||
if err := w.Wait(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
|
|
Loading…
Reference in New Issue