diff --git a/dag/dag.go b/dag/dag.go index 22b279524..33f8571a8 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" "sync" + + "github.com/hashicorp/go-multierror" ) // AcyclicGraph is a specialization of Graph that cannot have cycles. With @@ -45,6 +47,8 @@ func (g *AcyclicGraph) Validate() error { return err } + // Look for cycles of more than 1 component + var err error var cycles [][]Vertex for _, cycle := range StronglyConnected(&g.Graph) { if len(cycle) > 1 { @@ -52,20 +56,26 @@ func (g *AcyclicGraph) Validate() error { } } if len(cycles) > 0 { - cyclesStr := make([]string, len(cycles)) - for i, cycle := range cycles { + for _, cycle := range cycles { cycleStr := make([]string, len(cycle)) for j, vertex := range cycle { cycleStr[j] = VertexName(vertex) } - cyclesStr[i] = strings.Join(cycleStr, ", ") + err = multierror.Append(err, fmt.Errorf( + "Cycle: %s", strings.Join(cycleStr, ", "))) } - - return fmt.Errorf("cycles: %s", cyclesStr) } - return nil + // Look for cycles to self + for _, e := range g.Edges() { + if e.Source() == e.Target() { + err = multierror.Append(err, fmt.Errorf( + "Self reference: %s", VertexName(e.Source()))) + } + } + + return err } // Walk walks the graph, calling your callback as each node is visited. diff --git a/dag/dag_test.go b/dag/dag_test.go index fac5d37f4..77f548f4f 100644 --- a/dag/dag_test.go +++ b/dag/dag_test.go @@ -75,6 +75,17 @@ func TestAcyclicGraphValidate_cycle(t *testing.T) { } } +func TestAcyclicGraphValidate_cycleSelf(t *testing.T) { + var g AcyclicGraph + g.Add(1) + g.Add(2) + g.Connect(BasicEdge(1, 1)) + + if err := g.Validate(); err == nil { + t.Fatal("should error") + } +} + func TestAcyclicGraphWalk(t *testing.T) { var g AcyclicGraph g.Add(1) diff --git a/dag/graph.go b/dag/graph.go index c9b0607aa..263802d21 100644 --- a/dag/graph.go +++ b/dag/graph.go @@ -39,7 +39,7 @@ func (g *Graph) Vertices() []Vertex { // Edges returns the list of all the edges in the graph. func (g *Graph) Edges() []Edge { - list := g.vertices.List() + list := g.edges.List() result := make([]Edge, len(list)) for i, v := range list { result[i] = v.(Edge)