From dcc575402bea9d47e5e4d9c90c4c1e9c212d2395 Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Thu, 30 Jan 2020 10:25:02 +0300
Subject: [PATCH] Merge: * clients: update runtime clients of type DHCP by
 event from DHCP module Close #1378

Squashed commit of the following:

commit e45e2d0e2768fe0677eee43538d381b3eaba39ca
Merge: bea8f79d 5e9c21b0
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Jan 29 20:08:20 2020 +0300

    Merge remote-tracking branch 'origin/master' into 1378-dhcp-clients

commit bea8f79dd6f8f3eae87649d853917b503df29616
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Jan 29 20:08:06 2020 +0300

    minor

commit 6f1da9c6ea9db5bf80acf234ffe322a4cd2d8d92
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Jan 29 19:31:08 2020 +0300

    fix

commit a88b46c1ded2b460ef7f0bfbcf1b80a066edf1c1
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Wed Jan 29 12:53:22 2020 +0300

    minor

commit d2897fe0a9b726fcd97a04906e3be3d21f6b42d7
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Jan 28 19:55:10 2020 +0300

    * clients: update runtime clients of type DHCP by event from DHCP module

commit 3aa352ed2372141617d77363b2f2aeaf3a7e47a0
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Jan 28 19:52:08 2020 +0300

    * minor

commit f5c2291e39df4d13b9baf9aa773284890494bb0a
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Jan 28 19:08:23 2020 +0300

    * clients: remove old entries of source type /etc/hosts or ARP
---
 dhcpd/db.go              |  9 +++++
 dhcpd/dhcp_http.go       |  4 +--
 dhcpd/dhcpd.go           | 74 +++++++++++++++++++++++++++++-----------
 dhcpd/dhcpd_test.go      |  2 +-
 dnsforward/dnsforward.go |  3 +-
 home/clients.go          | 66 ++++++++++++++++++++++++++++-------
 6 files changed, 121 insertions(+), 37 deletions(-)

diff --git a/dhcpd/db.go b/dhcpd/db.go
index 8de5ba79..fdf94059 100644
--- a/dhcpd/db.go
+++ b/dhcpd/db.go
@@ -23,6 +23,14 @@ type leaseJSON struct {
 	Expiry   int64  `json:"exp"`
 }
 
+func normalizeIP(ip net.IP) net.IP {
+	ip4 := ip.To4()
+	if ip4 != nil {
+		return ip4
+	}
+	return ip
+}
+
 // Safe version of dhcp4.IPInRange()
 func ipInRange(start, stop, ip net.IP) bool {
 	if len(start) != len(stop) ||
@@ -56,6 +64,7 @@ func (s *Server) dbLoad() {
 
 	numLeases := len(obj)
 	for i := range obj {
+		obj[i].IP = normalizeIP(obj[i].IP)
 
 		if obj[i].Expiry != leaseExpireStatic &&
 			!ipInRange(s.leaseStart, s.leaseStop, obj[i].IP) {
diff --git a/dhcpd/dhcp_http.go b/dhcpd/dhcp_http.go
index 4d9ef538..e1b3d4fb 100644
--- a/dhcpd/dhcp_http.go
+++ b/dhcpd/dhcp_http.go
@@ -43,8 +43,8 @@ func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string
 }
 
 func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
-	leases := convertLeases(s.Leases(), true)
-	staticLeases := convertLeases(s.StaticLeases(), false)
+	leases := convertLeases(s.Leases(LeasesDynamic), true)
+	staticLeases := convertLeases(s.Leases(LeasesStatic), false)
 	status := map[string]interface{}{
 		"config":        s.conf,
 		"leases":        leases,
diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go
index 06bb40ea..b7af5dca 100644
--- a/dhcpd/dhcpd.go
+++ b/dhcpd/dhcpd.go
@@ -55,6 +55,16 @@ type ServerConfig struct {
 	HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `json:"-" yaml:"-"`
 }
 
+type onLeaseChangedT func(flags int)
+
+// flags for onLeaseChanged()
+const (
+	LeaseChangedAdded = iota
+	LeaseChangedAddedStatic
+	LeaseChangedRemovedStatic
+	LeaseChangedBlacklisted
+)
+
 // Server - the current state of the DHCP server
 type Server struct {
 	conn *filterConn // listening UDP socket
@@ -78,6 +88,9 @@ type Server struct {
 	IPpool map[[4]byte]net.HardwareAddr
 
 	conf ServerConfig
+
+	// Called when the leases DB is modified
+	onLeaseChanged onLeaseChangedT
 }
 
 // Print information about the available network interfaces
@@ -101,6 +114,13 @@ func Create(config ServerConfig) *Server {
 	s := Server{}
 	s.conf = config
 	s.conf.DBFilePath = filepath.Join(config.WorkDir, dbFilename)
+	if s.conf.Enabled {
+		err := s.setConfig(config)
+		if err != nil {
+			log.Error("DHCP: %s", err)
+			return nil
+		}
+	}
 	if s.conf.HTTPRegister != nil {
 		s.registerHandlers()
 	}
@@ -120,6 +140,18 @@ func (s *Server) Init(config ServerConfig) error {
 	return nil
 }
 
+// SetOnLeaseChanged - set callback
+func (s *Server) SetOnLeaseChanged(onLeaseChanged onLeaseChangedT) {
+	s.onLeaseChanged = onLeaseChanged
+}
+
+func (s *Server) notify(flags int) {
+	if s.onLeaseChanged == nil {
+		return
+	}
+	s.onLeaseChanged(flags)
+}
+
 // WriteDiskConfig - write configuration
 func (s *Server) WriteDiskConfig(c *ServerConfig) {
 	*c = s.conf
@@ -285,7 +317,6 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
 			s.leases[i].IP, hwaddr, s.leases[i].HWAddr, s.leases[i].Expiry)
 		lease.IP = s.leases[i].IP
 		s.leases[i] = lease
-		s.dbStore()
 
 		s.reserveIP(lease.IP, hwaddr)
 		return lease, nil
@@ -294,7 +325,6 @@ func (s *Server) reserveLease(p dhcp4.Packet) (*Lease, error) {
 	log.Tracef("Assigning to %s IP address %s", hwaddr, ip.String())
 	lease.IP = ip
 	s.leases = append(s.leases, lease)
-	s.dbStore()
 	return lease, nil
 }
 
@@ -449,6 +479,7 @@ func (s *Server) blacklistLease(lease *Lease) {
 	lease.Expiry = time.Now().Add(s.leaseTime)
 	s.dbStore()
 	s.leasesLock.Unlock()
+	s.notify(LeaseChangedBlacklisted)
 }
 
 // Return TRUE if DHCP packet is correct
@@ -538,6 +569,10 @@ func (s *Server) handleDHCP4Request(p dhcp4.Packet, options dhcp4.Options) dhcp4
 
 	if lease.Expiry.Unix() != leaseExpireStatic {
 		lease.Expiry = time.Now().Add(s.leaseTime)
+		s.leasesLock.Lock()
+		s.dbStore()
+		s.leasesLock.Unlock()
+		s.notify(LeaseChangedAdded) // Note: maybe we shouldn't call this function if only expiration time is updated
 	}
 	log.Tracef("Replying with ACK.  IP: %s  HW: %s  Expire: %s",
 		lease.IP, lease.HWAddr, lease.Expiry)
@@ -578,17 +613,19 @@ func (s *Server) AddStaticLease(l Lease) error {
 	l.Expiry = time.Unix(leaseExpireStatic, 0)
 
 	s.leasesLock.Lock()
-	defer s.leasesLock.Unlock()
 
 	if s.findReservedHWaddr(l.IP) != nil {
 		err := s.rmDynamicLeaseWithIP(l.IP)
 		if err != nil {
+			s.leasesLock.Unlock()
 			return err
 		}
 	}
 	s.leases = append(s.leases, &l)
 	s.reserveIP(l.IP, l.HWAddr)
 	s.dbStore()
+	s.leasesLock.Unlock()
+	s.notify(LeaseChangedAddedStatic)
 	return nil
 }
 
@@ -637,27 +674,38 @@ func (s *Server) RemoveStaticLease(l Lease) error {
 	}
 
 	s.leasesLock.Lock()
-	defer s.leasesLock.Unlock()
 
 	if s.findReservedHWaddr(l.IP) == nil {
+		s.leasesLock.Unlock()
 		return fmt.Errorf("Lease not found")
 	}
 
 	err := s.rmLease(l)
 	if err != nil {
+		s.leasesLock.Unlock()
 		return err
 	}
 	s.dbStore()
+	s.leasesLock.Unlock()
+	s.notify(LeaseChangedRemovedStatic)
 	return nil
 }
 
+// flags for Leases() function
+const (
+	LeasesDynamic = 1
+	LeasesStatic  = 2
+	LeasesAll     = LeasesDynamic | LeasesStatic
+)
+
 // Leases returns the list of current DHCP leases (thread-safe)
-func (s *Server) Leases() []Lease {
+func (s *Server) Leases(flags int) []Lease {
 	var result []Lease
 	now := time.Now().Unix()
 	s.leasesLock.RLock()
 	for _, lease := range s.leases {
-		if lease.Expiry.Unix() > now {
+		if ((flags&LeasesDynamic) != 0 && lease.Expiry.Unix() > now) ||
+			((flags&LeasesStatic) != 0 && lease.Expiry.Unix() == leaseExpireStatic) {
 			result = append(result, *lease)
 		}
 	}
@@ -666,20 +714,6 @@ func (s *Server) Leases() []Lease {
 	return result
 }
 
-// StaticLeases returns the list of statically-configured DHCP leases (thread-safe)
-func (s *Server) StaticLeases() []Lease {
-	s.leasesLock.Lock()
-	defer s.leasesLock.Unlock()
-
-	var result []Lease
-	for _, lease := range s.leases {
-		if lease.Expiry.Unix() == 1 {
-			result = append(result, *lease)
-		}
-	}
-	return result
-}
-
 // Print information about the current leases
 func (s *Server) printLeases() {
 	log.Tracef("Leases:")
diff --git a/dhcpd/dhcpd_test.go b/dhcpd/dhcpd_test.go
index 95522f3e..b41121a3 100644
--- a/dhcpd/dhcpd_test.go
+++ b/dhcpd/dhcpd_test.go
@@ -130,7 +130,7 @@ func testStaticLeases(t *testing.T, s *Server) {
 	err = s.AddStaticLease(l)
 	check(t, err == nil, "AddStaticLease")
 
-	ll := s.StaticLeases()
+	ll := s.Leases(LeasesStatic)
 	check(t, len(ll) != 0 && bytes.Equal(ll[0].IP, []byte{1, 1, 1, 1}), "StaticLeases")
 
 	err = s.RemoveStaticLease(l)
diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index 6ecc7442..1afdc262 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -685,6 +685,7 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
 		processFilteringBeforeRequest,
 		processUpstream,
 		processFilteringAfterResponse,
+		processQueryLogsAndStats,
 	}
 	for _, process := range mods {
 		r := process(ctx)
@@ -699,8 +700,6 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
 	if d.Res != nil {
 		d.Res.Compress = true // some devices require DNS message compression
 	}
-
-	_ = processQueryLogsAndStats(ctx)
 	return nil
 }
 
diff --git a/home/clients.go b/home/clients.go
index fb96fd76..0de73af7 100644
--- a/home/clients.go
+++ b/home/clients.go
@@ -100,6 +100,9 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
 	if !clients.testing {
 		go clients.periodicUpdate()
 
+		clients.addFromDHCP()
+		clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
+
 		clients.registerWebHandlers()
 	}
 }
@@ -186,11 +189,19 @@ func (clients *clientsContainer) periodicUpdate() {
 	for {
 		clients.addFromHostsFile()
 		clients.addFromSystemARP()
-		clients.addFromDHCP()
 		time.Sleep(clientsUpdatePeriod)
 	}
 }
 
+func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
+	switch flags {
+	case dhcpd.LeaseChangedAdded,
+		dhcpd.LeaseChangedAddedStatic,
+		dhcpd.LeaseChangedRemovedStatic:
+		clients.addFromDHCP()
+	}
+}
+
 // Exists checks if client with this IP already exists
 func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
 	clients.lock.Lock()
@@ -412,7 +423,7 @@ func (clients *clientsContainer) Add(c Client) (bool, error) {
 		clients.idIndex[id] = &c
 	}
 
-	log.Tracef("'%s': ID:%v [%d]", c.Name, c.IDs, len(clients.list))
+	log.Debug("Clients: added '%s': ID:%v [%d]", c.Name, c.IDs, len(clients.list))
 	return true, nil
 }
 
@@ -535,8 +546,12 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
 //  so we overwrite existing entries with an equal or higher priority
 func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
 	clients.lock.Lock()
-	defer clients.lock.Unlock()
+	b, e := clients.addHost(ip, host, source)
+	clients.lock.Unlock()
+	return b, e
+}
 
+func (clients *clientsContainer) addHost(ip, host string, source clientSource) (bool, error) {
 	// check auto-clients index
 	ch, ok := clients.ipHost[ip]
 	if ok && ch.Source > source {
@@ -550,10 +565,23 @@ func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (
 		}
 		clients.ipHost[ip] = ch
 	}
-	log.Tracef("'%s' -> '%s' [%d]", ip, host, len(clients.ipHost))
+	log.Debug("Clients: added '%s' -> '%s' [%d]", ip, host, len(clients.ipHost))
 	return true, nil
 }
 
+// Remove all entries that match the specified source
+func (clients *clientsContainer) rmHosts(source clientSource) int {
+	n := 0
+	for k, v := range clients.ipHost {
+		if v.Source == source {
+			delete(clients.ipHost, k)
+			n++
+		}
+	}
+	log.Debug("Clients: removed %d client aliases", n)
+	return n
+}
+
 // Parse system 'hosts' file and fill clients array
 func (clients *clientsContainer) addFromHostsFile() {
 	hostsFn := "/etc/hosts"
@@ -567,6 +595,10 @@ func (clients *clientsContainer) addFromHostsFile() {
 		return
 	}
 
+	clients.lock.Lock()
+	defer clients.lock.Unlock()
+	_ = clients.rmHosts(ClientSourceHostsFile)
+
 	lines := strings.Split(string(d), "\n")
 	n := 0
 	for _, ln := range lines {
@@ -580,7 +612,7 @@ func (clients *clientsContainer) addFromHostsFile() {
 			continue
 		}
 
-		ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
+		ok, e := clients.addHost(fields[0], fields[1], ClientSourceHostsFile)
 		if e != nil {
 			log.Tracef("%s", e)
 		}
@@ -589,7 +621,7 @@ func (clients *clientsContainer) addFromHostsFile() {
 		}
 	}
 
-	log.Debug("Added %d client aliases from %s", n, hostsFn)
+	log.Debug("Clients: added %d client aliases from %s", n, hostsFn)
 }
 
 // Add IP -> Host pairs from the system's `arp -a` command output
@@ -609,6 +641,10 @@ func (clients *clientsContainer) addFromSystemARP() {
 		return
 	}
 
+	clients.lock.Lock()
+	defer clients.lock.Unlock()
+	_ = clients.rmHosts(ClientSourceARP)
+
 	n := 0
 	lines := strings.Split(string(data), "\n")
 	for _, ln := range lines {
@@ -625,7 +661,7 @@ func (clients *clientsContainer) addFromSystemARP() {
 			continue
 		}
 
-		ok, e := clients.AddHost(ip, host, ClientSourceARP)
+		ok, e := clients.addHost(ip, host, ClientSourceARP)
 		if e != nil {
 			log.Tracef("%s", e)
 		}
@@ -634,24 +670,30 @@ func (clients *clientsContainer) addFromSystemARP() {
 		}
 	}
 
-	log.Debug("Added %d client aliases from 'arp -a' command output", n)
+	log.Debug("Clients: added %d client aliases from 'arp -a' command output", n)
 }
 
-// add clients from DHCP that have non-empty Hostname property
+// Add clients from DHCP that have non-empty Hostname property
 func (clients *clientsContainer) addFromDHCP() {
 	if clients.dhcpServer == nil {
 		return
 	}
-	leases := clients.dhcpServer.Leases()
+
+	clients.lock.Lock()
+	defer clients.lock.Unlock()
+
+	_ = clients.rmHosts(ClientSourceDHCP)
+
+	leases := clients.dhcpServer.Leases(dhcpd.LeasesAll)
 	n := 0
 	for _, l := range leases {
 		if len(l.Hostname) == 0 {
 			continue
 		}
-		ok, _ := clients.AddHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
+		ok, _ := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP)
 		if ok {
 			n++
 		}
 	}
-	log.Debug("Added %d client aliases from DHCP", n)
+	log.Debug("Clients: added %d client aliases from DHCP", n)
 }