From a1b18776678ee894cd0b558cf8683845056a2dfa Mon Sep 17 00:00:00 2001
From: Simon Zolin <s.zolin@adguard.com>
Date: Thu, 18 Apr 2019 14:31:13 +0300
Subject: [PATCH] + parental, safesearch: use our own DNS resolver instead of
 system default

---
 dns.go                 |  1 +
 dnsfilter/dnsfilter.go | 58 +++++++++++++++++++++++++++++++++++++++---
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/dns.go b/dns.go
index 9abbc80e..b135babf 100644
--- a/dns.go
+++ b/dns.go
@@ -50,6 +50,7 @@ func generateServerConfig() dnsforward.ServerConfig {
 		FilteringConfig: config.DNS.FilteringConfig,
 		Filters:         filters,
 	}
+	newconfig.ResolverAddress = fmt.Sprintf("%s:%d", config.DNS.BindHost, config.DNS.Port)
 
 	if config.TLS.Enabled {
 		newconfig.TLSConfig = config.TLS.TLSConfig
diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go
index 83818f1b..c487d7d2 100644
--- a/dnsfilter/dnsfilter.go
+++ b/dnsfilter/dnsfilter.go
@@ -3,6 +3,7 @@ package dnsfilter
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"crypto/sha256"
 	"encoding/json"
 	"errors"
@@ -16,6 +17,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/AdguardTeam/dnsproxy/upstream"
 	"github.com/AdguardTeam/golibs/log"
 	"github.com/bluele/gcache"
 	"golang.org/x/net/publicsuffix"
@@ -45,10 +47,11 @@ const enableDelayedCompilation = true // flag for debugging, must be true in pro
 
 // Config allows you to configure DNS filtering with New() or just change variables directly.
 type Config struct {
-	ParentalSensitivity int  `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
-	ParentalEnabled     bool `yaml:"parental_enabled"`
-	SafeSearchEnabled   bool `yaml:"safesearch_enabled"`
-	SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
+	ParentalSensitivity int    `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
+	ParentalEnabled     bool   `yaml:"parental_enabled"`
+	SafeSearchEnabled   bool   `yaml:"safesearch_enabled"`
+	SafeBrowsingEnabled bool   `yaml:"safebrowsing_enabled"`
+	ResolverAddress     string // DNS server address
 }
 
 type privateConfig struct {
@@ -159,6 +162,8 @@ var (
 	safeSearchCache   gcache.Cache
 )
 
+var resolverAddr string // DNS server address
+
 // Result holds state of hostname check
 type Result struct {
 	IsFiltered bool   `json:",omitempty"` // True if the host name is filtered
@@ -971,6 +976,47 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
 // lifecycle helper functions
 //
 
+// Connect to a remote server resolving hostname using our own DNS server
+func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+	log.Tracef("network:%v  addr:%v", network, addr)
+
+	host, port, err := net.SplitHostPort(addr)
+	if err != nil {
+		return nil, err
+	}
+
+	dialer := &net.Dialer{
+		Timeout: time.Minute * 5,
+	}
+
+	if net.ParseIP(host) != nil {
+		con, err := dialer.DialContext(ctx, network, addr)
+		return con, err
+	}
+
+	r := upstream.NewResolver(resolverAddr, 30*time.Second)
+	addrs, e := r.LookupIPAddr(ctx, host)
+	log.Tracef("LookupIPAddr: %s: %v", host, addrs)
+	if e != nil {
+		return nil, e
+	}
+
+	var firstErr error
+	firstErr = nil
+	for _, a := range addrs {
+		addr = fmt.Sprintf("%s:%s", a.String(), port)
+		con, err := dialer.DialContext(ctx, network, addr)
+		if err != nil {
+			if firstErr == nil {
+				firstErr = err
+			}
+			continue
+		}
+		return con, err
+	}
+	return nil, firstErr
+}
+
 // New creates properly initialized DNS Filter that is ready to be used
 func New(c *Config) *Dnsfilter {
 	d := new(Dnsfilter)
@@ -990,6 +1036,10 @@ func New(c *Config) *Dnsfilter {
 		TLSHandshakeTimeout:   10 * time.Second,
 		ExpectContinueTimeout: 1 * time.Second,
 	}
+	if len(c.ResolverAddress) != 0 {
+		resolverAddr = c.ResolverAddress
+		d.transport.DialContext = customDialContext
+	}
 	d.client = http.Client{
 		Transport: d.transport,
 		Timeout:   defaultHTTPTimeout,