dag: find root of AcyclicGraph

This commit is contained in:
Mitchell Hashimoto 2015-02-04 10:10:32 -05:00
parent 9f70c6fad5
commit cfa3d89265
2 changed files with 74 additions and 1 deletions

View File

@ -1,14 +1,41 @@
package dag package dag
import (
"fmt"
)
// AcyclicGraph is a specialization of Graph that cannot have cycles. With // AcyclicGraph is a specialization of Graph that cannot have cycles. With
// this property, we get the property of sane graph traversal. // this property, we get the property of sane graph traversal.
type AcyclicGraph struct { type AcyclicGraph struct {
*Graph Graph
} }
// WalkFunc is the callback used for walking the graph. // WalkFunc is the callback used for walking the graph.
type WalkFunc func(Vertex) type WalkFunc func(Vertex)
// Root returns the root of the DAG, or an error.
//
// Complexity: O(V)
func (g *AcyclicGraph) Root() (Vertex, error) {
roots := make([]Vertex, 0, 1)
for _, v := range g.Vertices() {
if g.UpEdges(v).Len() == 0 {
roots = append(roots, v)
}
}
if len(roots) > 1 {
// TODO(mitchellh): make this error message a lot better
return nil, fmt.Errorf("multiple roots: %#v", roots)
}
if len(roots) == 0 {
return nil, fmt.Errorf("no roots found")
}
return roots[0], nil
}
// Walk walks the graph, calling your callback as each node is visited. // Walk walks the graph, calling your callback as each node is visited.
func (g *AcyclicGraph) Walk(cb WalkFunc) { func (g *AcyclicGraph) Walk(cb WalkFunc) {
} }

46
dag/dag_test.go Normal file
View File

@ -0,0 +1,46 @@
package dag
import (
"testing"
)
func TestAcyclicGraphRoot(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
g.Connect(BasicEdge(3, 1))
if root, err := g.Root(); err != nil {
t.Fatalf("err: %s", err)
} else if root != 3 {
t.Fatalf("bad: %#v", root)
}
}
func TestAcyclicGraphRoot_cycle(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 2))
g.Connect(BasicEdge(2, 3))
g.Connect(BasicEdge(3, 1))
if _, err := g.Root(); err == nil {
t.Fatal("should error")
}
}
func TestAcyclicGraphRoot_multiple(t *testing.T) {
var g AcyclicGraph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(3, 2))
if _, err := g.Root(); err == nil {
t.Fatal("should error")
}
}