From cdd55139fae9eedecf28cbce2d066e0187f2e4bd Mon Sep 17 00:00:00 2001
From: Andrey Meshkov <am@adguard.com>
Date: Mon, 23 Dec 2019 19:31:27 +0300
Subject: [PATCH] *(dnsforward): cache upstream instances
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

✅ Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1296
---
 dnsforward/dnsforward.go | 12 +++------
 home/clients.go          | 57 ++++++++++++++++++++++++++++++++++++++--
 home/dns.go              | 11 +++-----
 3 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go
index 0af717be..a2cf5b65 100644
--- a/dnsforward/dnsforward.go
+++ b/dnsforward/dnsforward.go
@@ -113,7 +113,7 @@ type FilteringConfig struct {
 	FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
 
 	// This callback function returns the list of upstream servers for a client specified by IP address
-	GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"`
+	GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"`
 
 	ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
 
@@ -465,13 +465,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
 		if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
 			clientIP := ipFromAddr(d.Addr)
 			upstreams := s.conf.GetUpstreamsByClient(clientIP)
-			for _, us := range upstreams {
-				u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second})
-				if err != nil {
-					log.Error("upstream.AddressToUpstream: %s: %s", us, err)
-					continue
-				}
-				d.Upstreams = append(d.Upstreams, u)
+			if len(upstreams) > 0 {
+				log.Debug("Using custom upstreams for %s", clientIP)
+				d.Upstreams = upstreams
 			}
 		}
 
diff --git a/home/clients.go b/home/clients.go
index 468cdfe0..8769ebc3 100644
--- a/home/clients.go
+++ b/home/clients.go
@@ -14,6 +14,7 @@ import (
 
 	"github.com/AdguardTeam/AdGuardHome/dhcpd"
 	"github.com/AdguardTeam/AdGuardHome/dnsforward"
+	"github.com/AdguardTeam/dnsproxy/upstream"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/AdguardTeam/golibs/utils"
 )
@@ -62,8 +63,14 @@ type clientsContainer struct {
 	list    map[string]*Client     // name -> client
 	idIndex map[string]*Client     // IP -> client
 	ipHost  map[string]*ClientHost // IP -> Hostname
-	lock    sync.Mutex
 
+	// cache for Upstream instances that are used in the case
+	// when custom DNS servers are configured for a client
+	upstreamsCache map[string][]upstream.Upstream // name -> []Upstream
+
+	lock sync.Mutex
+
+	// dhcpServer is used for looking up clients IP addresses by MAC addresses
 	dhcpServer *dhcpd.Server
 
 	testing bool // if TRUE, this object is used for internal tests
@@ -78,6 +85,7 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
 	clients.list = make(map[string]*Client)
 	clients.idIndex = make(map[string]*Client)
 	clients.ipHost = make(map[string]*ClientHost)
+	clients.upstreamsCache = make(map[string][]upstream.Upstream)
 	clients.dhcpServer = dhcpServer
 	clients.addFromConfig(objects)
 
@@ -191,6 +199,45 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) {
 	return clients.findByIP(ip)
 }
 
+// FindUpstreams looks for upstreams configured for the client
+// If no client found for this IP, or if no custom upstreams are configured,
+// this method returns nil
+func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream {
+	clients.lock.Lock()
+	defer clients.lock.Unlock()
+
+	c, ok := clients.findByIP(ip)
+	if !ok {
+		return nil
+	}
+
+	if len(c.Upstreams) == 0 {
+		return nil
+	}
+
+	upstreams, ok := clients.upstreamsCache[c.Name]
+	if ok {
+		return upstreams
+	}
+
+	for _, us := range c.Upstreams {
+		u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout})
+		if err != nil {
+			log.Error("upstream.AddressToUpstream: %s: %s", us, err)
+			continue
+		}
+		upstreams = append(upstreams, u)
+	}
+
+	if len(upstreams) == 0 {
+		clients.upstreamsCache[c.Name] = nil
+	} else {
+		clients.upstreamsCache[c.Name] = upstreams
+	}
+
+	return upstreams
+}
+
 // Find searches for a client by IP (and does not lock anything)
 func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
 	ipAddr := net.ParseIP(ip)
@@ -355,6 +402,9 @@ func (clients *clientsContainer) Del(name string) bool {
 	// update Name index
 	delete(clients.list, name)
 
+	// update upstreams cache
+	delete(clients.upstreamsCache, name)
+
 	// update ID index
 	for _, id := range c.IDs {
 		delete(clients.idIndex, id)
@@ -418,10 +468,13 @@ func (clients *clientsContainer) Update(name string, c Client) error {
 
 	// update Name index
 	if old.Name != c.Name {
-		delete(clients.list, old.Name)
 		clients.list[c.Name] = old
 	}
 
+	// update upstreams cache
+	delete(clients.upstreamsCache, name)
+	delete(clients.upstreamsCache, old.Name)
+
 	*old = c
 	return nil
 }
diff --git a/home/dns.go b/home/dns.go
index 1a0666fb..f760ff77 100644
--- a/home/dns.go
+++ b/home/dns.go
@@ -11,6 +11,7 @@ import (
 	"github.com/AdguardTeam/AdGuardHome/querylog"
 	"github.com/AdguardTeam/AdGuardHome/stats"
 	"github.com/AdguardTeam/dnsproxy/proxy"
+	"github.com/AdguardTeam/dnsproxy/upstream"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/joomcode/errorx"
 )
@@ -178,18 +179,12 @@ func generateServerConfig() dnsforward.ServerConfig {
 	return newconfig
 }
 
-func getUpstreamsByClient(clientAddr string) []string {
-	c, ok := Context.clients.Find(clientAddr)
-	if !ok {
-		return []string{}
-	}
-	log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr)
-	return c.Upstreams
+func getUpstreamsByClient(clientAddr string) []upstream.Upstream {
+	return Context.clients.FindUpstreams(clientAddr)
 }
 
 // If a client has his own settings, apply them
 func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
-
 	ApplyBlockedServices(setts, config.DNS.BlockedServices)
 
 	if len(clientAddr) == 0 {