diff --git a/cidr_radix.go b/cidr_radix.go index 7726b9a..a1b5750 100644 --- a/cidr_radix.go +++ b/cidr_radix.go @@ -99,6 +99,29 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) { return value } +// Finds the most specific match +func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) { + bit := startbit + node := tree.root + + for node != nil { + if node.value != nil { + value = node.value + } + + if ip&bit != 0 { + node = node.right + } else { + node = node.left + } + + bit >>= 1 + + } + + return value +} + // Finds the most specific match func (tree *CIDRTree) Match(ip uint32) (value interface{}) { bit := startbit diff --git a/cidr_radix_test.go b/cidr_radix_test.go index e7461bd..1e3fad1 100644 --- a/cidr_radix_test.go +++ b/cidr_radix_test.go @@ -45,6 +45,45 @@ func TestCIDRTree_Contains(t *testing.T) { assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255")))) } +func TestCIDRTree_MostSpecificContains(t *testing.T) { + tree := NewCIDRTree() + tree.AddCIDR(getCIDR("1.0.0.0/8"), "1") + tree.AddCIDR(getCIDR("2.1.0.0/16"), "2") + tree.AddCIDR(getCIDR("3.1.1.0/24"), "3") + tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a") + tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b") + tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c") + tree.AddCIDR(getCIDR("254.0.0.0/4"), "5") + + tests := []struct { + Result interface{} + IP string + }{ + {"1", "1.0.0.0"}, + {"1", "1.255.255.255"}, + {"2", "2.1.0.0"}, + {"2", "2.1.255.255"}, + {"3", "3.1.1.0"}, + {"3", "3.1.1.255"}, + {"4a", "4.1.1.255"}, + {"4b", "4.1.1.2"}, + {"4c", "4.1.1.1"}, + {"5", "240.0.0.0"}, + {"5", "255.255.255.255"}, + {nil, "239.0.0.0"}, + {nil, "4.1.2.2"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP)))) + } + + tree = NewCIDRTree() + tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool") + assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0")))) + assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255")))) +} + func TestCIDRTree_Match(t *testing.T) { tree := NewCIDRTree() tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")