Don't use a global logger (#423)
This commit is contained in:
31
firewall.go
31
firewall.go
@ -70,6 +70,7 @@ type Firewall struct {
|
||||
|
||||
trackTCPRTT bool
|
||||
metricTCPRTT metrics.Histogram
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
type FirewallConntrack struct {
|
||||
@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
var min, max time.Duration
|
||||
|
||||
@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
|
||||
DefaultTimeout: defaultTimeout,
|
||||
localIps: localIps,
|
||||
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
|
||||
fw := NewFirewall(
|
||||
l,
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
|
||||
//TODO: max_connections
|
||||
)
|
||||
|
||||
err := AddFirewallRulesFromConfig(false, c, fw)
|
||||
err := AddFirewallRulesFromConfig(l, false, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = AddFirewallRulesFromConfig(true, c, fw)
|
||||
err = AddFirewallRulesFromConfig(l, true, c, fw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
||||
if !incoming {
|
||||
direction = "outgoing"
|
||||
}
|
||||
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
|
||||
Info("Firewall rule added")
|
||||
|
||||
var (
|
||||
@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string {
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
table = "firewall.inbound"
|
||||
@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
|
||||
|
||||
for i, t := range rs {
|
||||
var groups []string
|
||||
r, err := convertRule(t, table, i)
|
||||
r, err := convertRule(l, t, table, i)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s rule #%v; %s", table, i, err)
|
||||
}
|
||||
@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
|
||||
return false
|
||||
}
|
||||
|
||||
if l.Level >= logrus.DebugLevel {
|
||||
h.logger().
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
h.logger(f.l).
|
||||
WithField("fwPacket", fp).
|
||||
WithField("incoming", c.incoming).
|
||||
WithField("rulesVersion", f.rulesVersion).
|
||||
@ -795,7 +798,7 @@ type rule struct {
|
||||
CASha string
|
||||
}
|
||||
|
||||
func convertRule(p interface{}, table string, i int) (rule, error) {
|
||||
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
|
||||
r := rule{}
|
||||
|
||||
m, ok := p.(map[interface{}]interface{})
|
||||
@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
|
||||
|
||||
// Get checks if the cache ticker has moved to the next version before returning
|
||||
// the map. If it has moved, we reset the map.
|
||||
func (c *ConntrackCacheTicker) Get() ConntrackCache {
|
||||
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
|
||||
c.cacheV = tick
|
||||
if ll := len(c.cache); ll > 0 {
|
||||
if l.GetLevel() == logrus.DebugLevel {
|
||||
if l.Level == logrus.DebugLevel {
|
||||
l.WithField("len", ll).Debug("resetting conntrack cache")
|
||||
}
|
||||
c.cache = make(ConntrackCache, ll)
|
||||
|
Reference in New Issue
Block a user