Merge pull request #113 from slackhq/fw-ca
Fixes the issues with caSha and caName
This commit is contained in:
commit
e465b13045
|
@ -141,7 +141,7 @@ firewall:
|
||||||
|
|
||||||
# The firewall is default deny. There is no way to write a deny rule.
|
# The firewall is default deny. There is no way to write a deny rule.
|
||||||
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
|
||||||
# Logical evaluation is roughly: port AND proto AND ca_sha AND ca_name AND (host OR group OR groups OR cidr)
|
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr)
|
||||||
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
|
||||||
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
|
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
|
||||||
# proto: `any`, `tcp`, `udp`, or `icmp`
|
# proto: `any`, `tcp`, `udp`, or `icmp`
|
||||||
|
|
125
firewall.go
125
firewall.go
|
@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FirewallCA struct {
|
||||||
|
Any *FirewallRule
|
||||||
|
CANames map[string]*FirewallRule
|
||||||
|
CAShas map[string]*FirewallRule
|
||||||
|
}
|
||||||
|
|
||||||
type FirewallRule struct {
|
type FirewallRule struct {
|
||||||
// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked
|
// Any makes Hosts, Groups, and CIDR irrelevant
|
||||||
Any bool
|
Any bool
|
||||||
Hosts map[string]struct{}
|
Hosts map[string]struct{}
|
||||||
Groups [][]string
|
Groups [][]string
|
||||||
CIDR *CIDRTree
|
CIDR *CIDRTree
|
||||||
CANames map[string]struct{}
|
|
||||||
CAShas map[string]struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Even though ports are uint16, int32 maps are faster for lookup
|
// Even though ports are uint16, int32 maps are faster for lookup
|
||||||
// Plus we can use `-1` for fragment rules
|
// Plus we can use `-1` for fragment rules
|
||||||
type firewallPort map[int32]*FirewallRule
|
type firewallPort map[int32]*FirewallCA
|
||||||
|
|
||||||
type FirewallPacket struct {
|
type FirewallPacket struct {
|
||||||
LocalIP uint32
|
LocalIP uint32
|
||||||
|
@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
||||||
|
|
||||||
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||||
fw := NewFirewall(
|
fw := NewFirewall(
|
||||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)),
|
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)),
|
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||||
c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)),
|
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||||
nc,
|
nc,
|
||||||
//TODO: max_connections
|
//TODO: max_connections
|
||||||
)
|
)
|
||||||
|
@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
||||||
|
|
||||||
for i := startPort; i <= endPort; i++ {
|
for i := startPort; i <= endPort; i++ {
|
||||||
if _, ok := fp[i]; !ok {
|
if _, ok := fp[i]; !ok {
|
||||||
fp[i] = &FirewallRule{
|
fp[i] = &FirewallCA{
|
||||||
Groups: make([][]string, 0),
|
CANames: make(map[string]*FirewallRule),
|
||||||
Hosts: make(map[string]struct{}),
|
CAShas: make(map[string]*FirewallRule),
|
||||||
CIDR: NewCIDRTree(),
|
|
||||||
CANames: make(map[string]struct{}),
|
|
||||||
CAShas: make(map[string]struct{}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -539,15 +540,70 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
|
||||||
return fp[fwPortAny].match(p, c, caPool)
|
return fp[fwPortAny].match(p, c, caPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
|
func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
|
||||||
if caName != "" {
|
fr := func() *FirewallRule {
|
||||||
fr.CANames[caName] = struct{}{}
|
return &FirewallRule{
|
||||||
|
Hosts: make(map[string]struct{}),
|
||||||
|
Groups: make([][]string, 0),
|
||||||
|
CIDR: NewCIDRTree(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if caSha == "" && caName == "" {
|
||||||
|
if fc.Any == nil {
|
||||||
|
fc.Any = fr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.Any.addRule(groups, host, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
if caSha != "" {
|
if caSha != "" {
|
||||||
fr.CAShas[caSha] = struct{}{}
|
if _, ok := fc.CAShas[caSha]; !ok {
|
||||||
|
fc.CAShas[caSha] = fr()
|
||||||
|
}
|
||||||
|
err := fc.CAShas[caSha].addRule(groups, host, ip)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if caName != "" {
|
||||||
|
if _, ok := fc.CANames[caName]; !ok {
|
||||||
|
fc.CANames[caName] = fr()
|
||||||
|
}
|
||||||
|
err := fc.CANames[caName].addRule(groups, host, ip)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
||||||
|
if fc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if fc.Any.match(p, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if t, ok := fc.CAShas[c.Details.Issuer]; ok {
|
||||||
|
if t.match(p, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := caPool.GetCAForCert(c)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return fc.CANames[s.Details.Name].match(p, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
|
||||||
if fr.Any {
|
if fr.Any {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -576,6 +632,10 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caN
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
|
||||||
|
if len(groups) == 0 && host == "" && ip == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
if group == "any" {
|
if group == "any" {
|
||||||
return true
|
return true
|
||||||
|
@ -593,28 +653,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
|
func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
|
||||||
if fr == nil {
|
if fr == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// CASha and CAName always need to be checked
|
|
||||||
if len(fr.CAShas) > 0 {
|
|
||||||
if _, ok := fr.CAShas[c.Details.Issuer]; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(fr.CANames) > 0 {
|
|
||||||
s, err := caPool.GetCAForCert(c)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if _, ok := fr.CANames[s.Details.Name]; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shortcut path for if groups, hosts, or cidr contained an `any`
|
// Shortcut path for if groups, hosts, or cidr contained an `any`
|
||||||
if fr.Any {
|
if fr.Any {
|
||||||
return true
|
return true
|
||||||
|
@ -773,7 +816,7 @@ func setTCPRTTTracking(c *conn, p []byte) {
|
||||||
ihl := int(p[0]&0x0f) << 2
|
ihl := int(p[0]&0x0f) << 2
|
||||||
|
|
||||||
// Don't track FIN packets
|
// Don't track FIN packets
|
||||||
if uint8(p[ihl+13])&tcpFIN != 0 {
|
if p[ihl+13]&tcpFIN != 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -787,7 +830,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
ihl := int(p[0]&0x0f) << 2
|
ihl := int(p[0]&0x0f) << 2
|
||||||
if uint8(p[ihl+13])&tcpACK == 0 {
|
if p[ihl+13]&tcpACK == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
226
firewall_test.go
226
firewall_test.go
|
@ -51,6 +51,11 @@ func TestNewFirewall(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_AddRule(t *testing.T) {
|
func TestFirewall_AddRule(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
out := l.Out
|
||||||
|
l.SetOutput(ob)
|
||||||
|
defer l.SetOutput(out)
|
||||||
|
|
||||||
c := &cert.NebulaCertificate{}
|
c := &cert.NebulaCertificate{}
|
||||||
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.NotNil(t, fw.InRules)
|
assert.NotNil(t, fw.InRules)
|
||||||
|
@ -59,39 +64,38 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
|
_, ti, _ := net.ParseCIDR("1.2.3.4/32")
|
||||||
|
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
|
||||||
// Make sure an empty rule creates structure but doesn't allow anything to flow
|
// An empty rule is any
|
||||||
//TODO: ideally an empty rule would return an error
|
assert.True(t, fw.InRules.TCP[1].Any.Any)
|
||||||
assert.False(t, fw.InRules.TCP[1].Any)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Groups)
|
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||||
assert.Empty(t, fw.InRules.TCP[1].Hosts)
|
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
|
||||||
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
|
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
|
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
|
||||||
assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
|
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
|
||||||
assert.False(t, fw.InRules.UDP[1].Any)
|
assert.False(t, fw.InRules.UDP[1].Any.Any)
|
||||||
assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
|
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
|
||||||
assert.Empty(t, fw.InRules.UDP[1].Hosts)
|
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||||
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
|
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
|
||||||
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
|
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
|
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
|
||||||
assert.False(t, fw.InRules.ICMP[1].Any)
|
assert.False(t, fw.InRules.ICMP[1].Any.Any)
|
||||||
assert.Empty(t, fw.InRules.ICMP[1].Groups)
|
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
|
||||||
assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
|
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
|
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
|
||||||
assert.False(t, fw.OutRules.AnyProto[1].Any)
|
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
|
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
|
||||||
|
@ -104,28 +108,29 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
// Set any and clear fields
|
// Set any and clear fields
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
|
||||||
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
|
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
|
||||||
assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
|
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
|
||||||
assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
|
assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
|
||||||
|
|
||||||
// run twice just to make sure
|
// run twice just to make sure
|
||||||
|
//TODO: these ANY rules should clear the CA firewall portion
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
|
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
|
||||||
assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
|
assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
|
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
|
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
|
||||||
assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
|
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
|
||||||
assert.True(t, fw.OutRules.AnyProto[0].Any)
|
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
|
||||||
|
|
||||||
// Test error conditions
|
// Test error conditions
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
|
||||||
|
@ -134,6 +139,11 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop(t *testing.T) {
|
func TestFirewall_Drop(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
out := l.Out
|
||||||
|
l.SetOutput(ob)
|
||||||
|
defer l.SetOutput(out)
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
@ -150,10 +160,11 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
|
|
||||||
c := cert.NebulaCertificate{
|
c := cert.NebulaCertificate{
|
||||||
Details: cert.NebulaCertificateDetails{
|
Details: cert.NebulaCertificateDetails{
|
||||||
Name: "host1",
|
Name: "host1",
|
||||||
Ips: []*net.IPNet{&ipNet},
|
Ips: []*net.IPNet{&ipNet},
|
||||||
Groups: []string{"default-group"},
|
Groups: []string{"default-group"},
|
||||||
Issuer: "signer-shasum",
|
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||||
|
Issuer: "signer-shasum",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h := HostInfo{
|
h := HostInfo{
|
||||||
|
@ -170,6 +181,7 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
|
resetConntrack(fw)
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
|
assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
|
@ -180,27 +192,31 @@ func TestFirewall_Drop(t *testing.T) {
|
||||||
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
|
||||||
p.RemoteIP = oldRemote
|
p.RemoteIP = oldRemote
|
||||||
|
|
||||||
// test caSha assertions true
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
|
||||||
|
|
||||||
// test caSha assertions false
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
|
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
// test caName true
|
// test caSha doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
// test caName false
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
|
|
||||||
|
// test caName doesn't drop on match
|
||||||
|
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
|
||||||
|
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
|
||||||
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
|
@ -209,11 +225,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
_, n, _ := net.ParseCIDR("172.1.1.1/32")
|
||||||
ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
|
||||||
ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
|
||||||
ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
|
||||||
ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
|
||||||
ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
b.Run("fail on proto", func(b *testing.B) {
|
b.Run("fail on proto", func(b *testing.B) {
|
||||||
|
@ -281,7 +297,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
|
||||||
|
|
||||||
b.Run("pass on ip with any port", func(b *testing.B) {
|
b.Run("pass on ip with any port", func(b *testing.B) {
|
||||||
ip := ip2int(net.IPv4(172, 1, 1, 1))
|
ip := ip2int(net.IPv4(172, 1, 1, 1))
|
||||||
|
@ -298,6 +314,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop2(t *testing.T) {
|
func TestFirewall_Drop2(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
out := l.Out
|
||||||
|
l.SetOutput(ob)
|
||||||
|
defer l.SetOutput(out)
|
||||||
|
|
||||||
p := FirewallPacket{
|
p := FirewallPacket{
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
ip2int(net.IPv4(1, 2, 3, 4)),
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
@ -347,9 +368,94 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
||||||
// c has the proper groups
|
// c has the proper groups
|
||||||
|
resetConntrack(fw)
|
||||||
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
|
ob := &bytes.Buffer{}
|
||||||
|
out := l.Out
|
||||||
|
l.SetOutput(ob)
|
||||||
|
defer l.SetOutput(out)
|
||||||
|
|
||||||
|
p := FirewallPacket{
|
||||||
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
ip2int(net.IPv4(1, 2, 3, 4)),
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
fwProtoUDP,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet := net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
c := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "host-owner",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c1 := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "host1",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Issuer: "signer-sha-bad",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h1 := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h1.CreateRemoteCIDR(&c1)
|
||||||
|
|
||||||
|
c2 := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "host2",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Issuer: "signer-sha",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h2 := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h2.CreateRemoteCIDR(&c2)
|
||||||
|
|
||||||
|
c3 := cert.NebulaCertificate{
|
||||||
|
Details: cert.NebulaCertificateDetails{
|
||||||
|
Name: "host3",
|
||||||
|
Ips: []*net.IPNet{&ipNet},
|
||||||
|
Issuer: "signer-sha-bad",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h3 := HostInfo{
|
||||||
|
ConnectionState: &ConnectionState{
|
||||||
|
peerCert: &c3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h3.CreateRemoteCIDR(&c3)
|
||||||
|
|
||||||
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
|
||||||
|
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
|
||||||
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
|
// c1 should pass because host match
|
||||||
|
assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
|
||||||
|
// c2 should pass because ca sha match
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
|
||||||
|
// c3 should fail because no match
|
||||||
|
resetConntrack(fw)
|
||||||
|
assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkLookup(b *testing.B) {
|
func BenchmarkLookup(b *testing.B) {
|
||||||
ml := func(m map[string]struct{}, a [][]string) {
|
ml := func(m map[string]struct{}, a [][]string) {
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
|
@ -748,3 +854,9 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
||||||
mf.nextCallReturn = nil
|
mf.nextCallReturn = nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resetConntrack(fw *Firewall) {
|
||||||
|
fw.connMutex.Lock()
|
||||||
|
fw.Conns = map[FirewallPacket]*conn{}
|
||||||
|
fw.connMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue