dag: Remove, RemoveEdge, DownEdges, UpEdges

This commit is contained in:
Mitchell Hashimoto 2015-01-30 12:56:03 -08:00
parent ae4d20f8ce
commit 9dfce9c93a
6 changed files with 160 additions and 38 deletions

View File

@ -1,9 +1,15 @@
package dag
import (
"fmt"
)
// Edge represents an edge in the graph, with a source and target vertex.
type Edge interface {
Source() Vertex
Target() Vertex
hashable
}
// BasicEdge returns an Edge implementation that simply tracks the source
@ -18,6 +24,10 @@ type basicEdge struct {
S, T Vertex
}
func (e *basicEdge) Hashcode() interface{} {
return fmt.Sprintf("%p-%p", e.S, e.T)
}
func (e *basicEdge) Source() Vertex {
return e.S
}

26
dag/edge_test.go Normal file
View File

@ -0,0 +1,26 @@
package dag
import (
"testing"
)
func TestBasicEdgeHashcode(t *testing.T) {
e1 := BasicEdge(1, 2)
e2 := BasicEdge(1, 2)
if e1.Hashcode() != e2.Hashcode() {
t.Fatalf("bad")
}
}
func TestBasicEdgeHashcode_pointer(t *testing.T) {
type test struct {
Value string
}
v1, v2 := &test{"foo"}, &test{"bar"}
e1 := BasicEdge(v1, v2)
e2 := BasicEdge(v1, v2)
if e1.Hashcode() != e2.Hashcode() {
t.Fatalf("bad")
}
}

View File

@ -9,10 +9,10 @@ import (
// Graph is used to represent a dependency graph.
type Graph struct {
vertices *set
edges []Edge
downEdges map[Vertex]*set
upEdges map[Vertex]*set
vertices *Set
edges *Set
downEdges map[Vertex]*Set
upEdges map[Vertex]*Set
once sync.Once
}
@ -39,7 +39,13 @@ func (g *Graph) Vertices() []Vertex {
// Edges returns the list of all the edges in the graph.
func (g *Graph) Edges() []Edge {
return g.edges
list := g.vertices.List()
result := make([]Edge, len(list))
for i, v := range list {
result[i] = v.(Edge)
}
return result
}
// Add adds a vertex to the graph. This is safe to call multiple time with
@ -50,6 +56,51 @@ func (g *Graph) Add(v Vertex) Vertex {
return v
}
// Remove removes a vertex from the graph. This will also remove any
// edges with this vertex as a source or target.
func (g *Graph) Remove(v Vertex) Vertex {
// Delete the vertex itself
g.vertices.Delete(v)
// Delete the edges to non-existent things
for _, target := range g.DownEdges(v).List() {
g.RemoveEdge(BasicEdge(v, target))
}
for _, source := range g.UpEdges(v).List() {
g.RemoveEdge(BasicEdge(source, v))
}
return nil
}
// RemoveEdge removes an edge from the graph.
func (g *Graph) RemoveEdge(edge Edge) {
g.once.Do(g.init)
// Delete the edge from the set
g.edges.Delete(edge)
// Delete the up/down edges
if s, ok := g.downEdges[edge.Source()]; ok {
s.Delete(edge.Target())
}
if s, ok := g.upEdges[edge.Target()]; ok {
s.Delete(edge.Source())
}
}
// DownEdges returns the outward edges from the source Vertex v.
func (g *Graph) DownEdges(v Vertex) *Set {
g.once.Do(g.init)
return g.downEdges[v]
}
// UpEdges returns the inward edges to the destination Vertex v.
func (g *Graph) UpEdges(v Vertex) *Set {
g.once.Do(g.init)
return g.upEdges[v]
}
// Connect adds an edge with the given source and target. This is safe to
// call multiple times with the same value. Note that the same value is
// verified through pointer equality of the vertices, not through the
@ -65,13 +116,13 @@ func (g *Graph) Connect(edge Edge) {
return
}
// TODO: add all edges
g.edges = append(g.edges, edge)
// Add the edge to the set
g.edges.Add(edge)
// Add the down edge
s, ok := g.downEdges[source]
if !ok {
s = new(set)
s = new(Set)
g.downEdges[source] = s
}
s.Add(target)
@ -79,7 +130,7 @@ func (g *Graph) Connect(edge Edge) {
// Add the up edge
s, ok = g.upEdges[target]
if !ok {
s = new(set)
s = new(Set)
g.upEdges[target] = s
}
s.Add(source)
@ -125,10 +176,10 @@ func (g *Graph) String() string {
}
func (g *Graph) init() {
g.vertices = new(set)
g.edges = make([]Edge, 0, 2)
g.downEdges = make(map[Vertex]*set)
g.upEdges = make(map[Vertex]*set)
g.vertices = new(Set)
g.edges = new(Set)
g.downEdges = make(map[Vertex]*Set)
g.upEdges = make(map[Vertex]*Set)
}
// VertexName returns the name of a vertex.

View File

@ -32,6 +32,21 @@ func TestGraph_basic(t *testing.T) {
}
}
func TestGraph_remove(t *testing.T) {
var g Graph
g.Add(1)
g.Add(2)
g.Add(3)
g.Connect(BasicEdge(1, 3))
g.Remove(3)
actual := strings.TrimSpace(g.String())
expected := strings.TrimSpace(testGraphRemoveStr)
if actual != expected {
t.Fatalf("bad: %s", actual)
}
}
const testGraphBasicStr = `
1
3
@ -43,3 +58,8 @@ const testGraphEmptyStr = `
2
3
`
const testGraphRemoveStr = `
1
2
`

View File

@ -4,34 +4,40 @@ import (
"sync"
)
// set is an internal Set data structure that is based on simply using
// pointers as the hash key into a map.
type set struct {
m map[interface{}]struct{}
// Set is a set data structure.
type Set struct {
m map[interface{}]interface{}
once sync.Once
}
// hashable is the interface used by set to get the hash code of a value.
// If this isn't given, then the value of the item being added to the set
// itself is used as the comparison value.
type hashable interface {
Hashcode() interface{}
}
// Add adds an item to the set
func (s *set) Add(v interface{}) {
func (s *Set) Add(v interface{}) {
s.once.Do(s.init)
s.m[v] = struct{}{}
s.m[s.code(v)] = v
}
// Delete removes an item from the set.
func (s *set) Delete(v interface{}) {
func (s *Set) Delete(v interface{}) {
s.once.Do(s.init)
delete(s.m, v)
delete(s.m, s.code(v))
}
// Include returns true/false of whether a value is in the set.
func (s *set) Include(v interface{}) bool {
func (s *Set) Include(v interface{}) bool {
s.once.Do(s.init)
_, ok := s.m[v]
_, ok := s.m[s.code(v)]
return ok
}
// Len is the number of items in the set.
func (s *set) Len() int {
func (s *Set) Len() int {
if s == nil {
return 0
}
@ -40,19 +46,27 @@ func (s *set) Len() int {
}
// List returns the list of set elements.
func (s *set) List() []interface{} {
func (s *Set) List() []interface{} {
if s == nil {
return nil
}
r := make([]interface{}, 0, len(s.m))
for k, _ := range s.m {
r = append(r, k)
for _, v := range s.m {
r = append(r, v)
}
return r
}
func (s *set) init() {
s.m = make(map[interface{}]struct{})
func (s *Set) code(v interface{}) interface{} {
if h, ok := v.(hashable); ok {
return h.Hashcode()
}
return v
}
func (s *Set) init() {
s.m = make(map[interface{}]interface{})
}

View File

@ -1,8 +1,7 @@
package dag
import (
"bytes"
"fmt"
"sort"
"strings"
"testing"
)
@ -58,28 +57,30 @@ func TestGraphStronglyConnected_three(t *testing.T) {
}
func testSCCStr(list [][]Vertex) string {
var buf bytes.Buffer
var lines []string
for _, vs := range list {
result := make([]string, len(vs))
for i, v := range vs {
result[i] = VertexName(v)
}
buf.WriteString(fmt.Sprintf("%s\n", strings.Join(result, ",")))
sort.Strings(result)
lines = append(lines, strings.Join(result, ","))
}
return buf.String()
sort.Strings(lines)
return strings.Join(lines, "\n")
}
const testGraphStronglyConnectedStr = `2,1`
const testGraphStronglyConnectedStr = `1,2`
const testGraphStronglyConnectedTwoStr = `
2,1
1,2
3
`
const testGraphStronglyConnectedThreeStr = `
2,1
1,2
3
6,5,4
4,5,6
`