diff --git a/dnsforward/cache.go b/dnsforward/cache.go new file mode 100644 index 00000000..568f284c --- /dev/null +++ b/dnsforward/cache.go @@ -0,0 +1,225 @@ +package dnsforward + +import ( + "encoding/binary" + "log" + "math" + "strings" + "sync" + "time" + + "github.com/miekg/dns" +) + +type item struct { + m *dns.Msg + when time.Time +} + +type cache struct { + items map[string]item + + sync.RWMutex +} + +func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) { + if request == nil { + return nil, false + } + ok, key := key(request) + if !ok { + log.Printf("Get(): key returned !ok") + return nil, false + } + + c.RLock() + item, ok := c.items[key] + c.RUnlock() + if !ok { + return nil, false + } + // get item's TTL + ttl := findLowestTTL(item.m) + // zero TTL? delete and don't serve it + if ttl == 0 { + c.Lock() + delete(c.items, key) + c.Unlock() + return nil, false + } + // too much time has passed? delete and don't serve it + if time.Since(item.when) >= time.Duration(ttl)*time.Second { + c.Lock() + delete(c.items, key) + c.Unlock() + return nil, false + } + response := item.fromItem(request) + return response, true +} + +func (c *cache) Set(m *dns.Msg) { + if m == nil { + return // no-op + } + if !isRequestCacheable(m) { + return + } + if !isResponseCacheable(m) { + return + } + ok, key := key(m) + if !ok { + return + } + + i := toItem(m) + + c.Lock() + if c.items == nil { + c.items = map[string]item{} + } + c.items[key] = i + c.Unlock() +} + +// check only request fields +func isRequestCacheable(m *dns.Msg) bool { + // truncated messages aren't valid + if m.Truncated { + log.Printf("Refusing to cache truncated message") + return false + } + + // if has wrong number of questions, also don't cache + if len(m.Question) != 1 { + log.Printf("Refusing to cache message with wrong number of questions") + return false + } + + // only OK or NXdomain replies are cached + switch m.Rcode { + case dns.RcodeSuccess: + case dns.RcodeNameError: // that's an NXDomain + case dns.RcodeServerFailure: + return false // quietly refuse, don't log + default: + log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode]) + return false + } + + return true +} + +func isResponseCacheable(m *dns.Msg) bool { + ttl := findLowestTTL(m) + if ttl == 0 { + return false + } + + return true +} + +func findLowestTTL(m *dns.Msg) uint32 { + var ttl uint32 = math.MaxUint32 + found := false + + if m.Answer != nil { + for _, r := range m.Answer { + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if m.Ns != nil { + for _, r := range m.Ns { + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if m.Extra != nil { + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + continue // OPT records use TTL for other purposes + } + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if found == false { + return 0 + } + + return ttl +} + +// key is binary little endian in sequence: +// uint16(qtype) then uint16(qclass) then name +func key(m *dns.Msg) (bool, string) { + if len(m.Question) != 1 { + log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question)) + return false, "" + } + + bb := strings.Builder{} + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, m.Question[0].Qtype) + bb.Write(b) + binary.LittleEndian.PutUint16(b, m.Question[0].Qclass) + bb.Write(b) + name := strings.ToLower(m.Question[0].Name) + bb.WriteString(name) + return true, bb.String() +} + +func toItem(m *dns.Msg) item { + return item{ + m: m, + when: time.Now(), + } +} + +func (i *item) fromItem(request *dns.Msg) *dns.Msg { + response := &dns.Msg{} + response.SetReply(request) + + response.Authoritative = false + response.AuthenticatedData = i.m.AuthenticatedData + response.RecursionAvailable = i.m.RecursionAvailable + response.Rcode = i.m.Rcode + + ttl := findLowestTTL(i.m) + timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds()) + var newttl uint32 + if timeleft > 0 { + newttl = uint32(timeleft) + } + for _, r := range i.m.Answer { + answer := dns.Copy(r) + answer.Header().Ttl = newttl + response.Answer = append(response.Answer, answer) + } + for _, r := range i.m.Ns { + ns := dns.Copy(r) + ns.Header().Ttl = newttl + response.Ns = append(response.Ns, ns) + } + for _, r := range i.m.Extra { + // don't return OPT records as these are hop-by-hop + if r.Header().Rrtype == dns.TypeOPT { + continue + } + extra := dns.Copy(r) + extra.Header().Ttl = newttl + response.Extra = append(response.Extra, extra) + } + return response +} diff --git a/dnsforward/cache_test.go b/dnsforward/cache_test.go new file mode 100644 index 00000000..c9f4577e --- /dev/null +++ b/dnsforward/cache_test.go @@ -0,0 +1,144 @@ +package dnsforward + +import ( + "strings" + "testing" + + "github.com/go-test/deep" + "github.com/miekg/dns" +) + +func RR(rr string) dns.RR { + r, err := dns.NewRR(rr) + if err != nil { + panic(err) + } + return r +} + +// deepEqual is same as deep.Equal, except: +// * ignores Id when comparing +// * question names are not case sensetive +func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string { + temp := *left + temp.Id = right.Id + for i := range left.Question { + left.Question[i].Name = strings.ToLower(left.Question[i].Name) + } + for i := range right.Question { + right.Question[i].Name = strings.ToLower(right.Question[i].Name) + } + return deep.Equal(&temp, right) +} + +func TestCacheSanity(t *testing.T) { + cache := cache{} + request := dns.Msg{} + request.SetQuestion("google.com.", dns.TypeA) + _, ok := cache.Get(&request) + if ok { + t.Fatal("empty cache replied with positive response") + } +} + +type tests struct { + cache []testEntry + cases []testCase +} + +type testEntry struct { + q string + t uint16 + a []dns.RR +} + +type testCase struct { + q string + t uint16 + a []dns.RR + ok bool +} + +func TestCache(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func TestCacheMixedCase(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "gOOgle.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "GOOGLE.COM.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func TestZeroTTL(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func runTests(t *testing.T, tests tests) { + t.Helper() + cache := cache{} + for _, tc := range tests.cache { + reply := dns.Msg{} + reply.SetQuestion(tc.q, tc.t) + reply.Response = true + reply.Answer = tc.a + cache.Set(&reply) + } + for _, tc := range tests.cases { + request := dns.Msg{} + request.SetQuestion(tc.q, tc.t) + val, ok := cache.Get(&request) + if diff := deep.Equal(ok, tc.ok); diff != nil { + t.Error(diff) + } + if tc.a != nil { + if ok == false { + continue + } + reply := dns.Msg{} + reply.SetQuestion(tc.q, tc.t) + reply.Response = true + reply.Answer = tc.a + cache.Set(&reply) + if diff := deepEqualMsg(val, &reply); diff != nil { + t.Error(diff) + } else { + if diff := deep.Equal(val, reply); diff == nil { + t.Error("different message ID were not caught") + } + } + } + } +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go new file mode 100644 index 00000000..546c4eae --- /dev/null +++ b/dnsforward/dnsforward.go @@ -0,0 +1,467 @@ +package dnsforward + +import ( + "fmt" + "log" + "net" + "reflect" + "sync" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +// Server is the main way to start a DNS server +// Example: +// s := dnsforward.Server{} +// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine +// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535 +// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines +// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine +// +// The zero Server is empty and ready for use. +type Server struct { + udpListen *net.UDPConn + + dnsFilter *dnsfilter.Dnsfilter + + cache cache + + sync.RWMutex + ServerConfig +} + +// The zero ServerConfig is empty and ready for use. +type ServerConfig struct { + UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) + BlockedTTL uint32 // if 0, then default is used (3600) + Upstreams []Upstream + Filters []Filter +} + +var defaultValues = ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 53}, + BlockedTTL: 3600, + Upstreams: []Upstream{ + //// dns over HTTPS + // &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"}, + // &dnsOverHTTPS{Address: "https://dns.google.com/experimental"}, + // &dnsOverHTTPS{Address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, + // &dnsOverHTTPS{Address: "https://dns10.quad9.net/dns-query"}, + // &dnsOverHTTPS{Address: "https://doh.powerdns.org"}, + // &dnsOverHTTPS{Address: "https://doh.securedns.eu/dns-query"}, + + //// dns over TLS + // &dnsOverTLS{Address: "tls://8.8.8.8:853"}, + // &dnsOverTLS{Address: "tls://8.8.4.4:853"}, + &dnsOverTLS{Address: "tls://1.1.1.1:853"}, + &dnsOverTLS{Address: "tls://1.0.0.1:853"}, + + //// plainDNS + // &plainDNS{Address: "8.8.8.8:53"}, + // &plainDNS{Address: "8.8.4.4:53"}, + // &plainDNS{Address: "1.1.1.1:53"}, + // &plainDNS{Address: "1.0.0.1:53"}, + }, +} + +type Filter struct { + ID int64 + Rules []string +} + +// +// packet loop +// +func (s *Server) packetLoop() { + log.Printf("Entering packet handle loop") + b := make([]byte, dns.MaxMsgSize) + for { + s.RLock() + conn := s.udpListen + s.RUnlock() + if conn == nil { + log.Printf("udp socket has disappeared, exiting loop") + break + } + n, addr, err := conn.ReadFrom(b) + // documentation says to handle the packet even if err occurs, so do that first + if n > 0 { + // make a copy of all bytes because ReadFrom() will overwrite contents of b on next call + // we need the contents to survive the call because we're handling them in goroutine + p := make([]byte, n) + copy(p, b) + go s.handlePacket(p, addr, conn) // ignore errors + } + if err != nil { + if isConnClosed(err) { + log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop") + break + } + log.Printf("Got error when reading from udp listen: %s", err) + } + } +} + +// +// Control functions +// + +func (s *Server) Start(config *ServerConfig) error { + s.Lock() + defer s.Unlock() + if config != nil { + s.ServerConfig = *config + } + // TODO: handle being called Start() second time after Stop() + if s.udpListen == nil { + log.Printf("Creating UDP socket") + var err error + addr := s.UDPListenAddr + if addr == nil { + addr = defaultValues.UDPListenAddr + } + s.udpListen, err = net.ListenUDP("udp", addr) + if err != nil { + return errorx.Decorate(err, "Couldn't listen to UDP socket") + } + log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) + } + + if s.dnsFilter == nil { + log.Printf("Creating dnsfilter") + s.dnsFilter = dnsfilter.New() + } + + go s.packetLoop() + + return nil +} + +func (s *Server) Stop() error { + s.Lock() + defer s.Unlock() + if s.udpListen != nil { + err := s.udpListen.Close() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + s.udpListen = nil + } + return nil +} + +// +// Server reconfigure +// + +func (s *Server) reconfigureListenAddr(new ServerConfig) error { + oldAddr := s.UDPListenAddr + if oldAddr == nil { + oldAddr = defaultValues.UDPListenAddr + } + newAddr := new.UDPListenAddr + if newAddr == nil { + newAddr = defaultValues.UDPListenAddr + } + if newAddr.Port == 0 { + return errorx.IllegalArgument.New("new port cannot be 0") + } + if reflect.DeepEqual(oldAddr, newAddr) { + // do nothing, the addresses are exactly the same + log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr) + return nil + } + + // rebind, using a strategy: + // * if ports are different, bind new first, then close old + // * if ports are same, close old first, then bind new + var newListen *net.UDPConn + var err error + if oldAddr.Port != newAddr.Port { + log.Printf("Rebinding -- ports are different so bind first then close") + newListen, err = net.ListenUDP("udp", newAddr) + if err != nil { + return errorx.Decorate(err, "Couldn't bind to %v", newAddr) + } + if s.udpListen != nil { + err := s.udpListen.Close() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + } + } else { + log.Printf("Rebinding -- ports are same so close first then bind") + if s.udpListen != nil { + err := s.udpListen.Close() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + } + newListen, err = net.ListenUDP("udp", newAddr) + if err != nil { + return errorx.Decorate(err, "Couldn't bind to %v", newAddr) + } + } + s.Lock() + s.udpListen = newListen + s.UDPListenAddr = new.UDPListenAddr + s.Unlock() + log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) + + go s.packetLoop() // the old one has quit, use new one + + return nil +} + +func (s *Server) reconfigureBlockedTTL(new ServerConfig) { + newVal := new.BlockedTTL + if newVal == 0 { + newVal = defaultValues.BlockedTTL + } + oldVal := s.BlockedTTL + if oldVal == 0 { + oldVal = defaultValues.BlockedTTL + } + if newVal != oldVal { + s.BlockedTTL = new.BlockedTTL + } +} + +func (s *Server) reconfigureUpstreams(new ServerConfig) { + newVal := new.Upstreams + if len(newVal) == 0 { + newVal = defaultValues.Upstreams + } + oldVal := s.Upstreams + if len(oldVal) == 0 { + oldVal = defaultValues.Upstreams + } + if reflect.DeepEqual(newVal, oldVal) { + // they're exactly the same, do nothing + return + } + s.Upstreams = new.Upstreams +} + +func (s *Server) reconfigureFilters(new ServerConfig) { + newFilters := new.Filters + if len(newFilters) == 0 { + newFilters = defaultValues.Filters + } + oldFilters := s.Filters + if len(oldFilters) == 0 { + oldFilters = defaultValues.Filters + } + if reflect.DeepEqual(newFilters, oldFilters) { + // they're exactly the same, do nothing + return + } + + dnsFilter := dnsfilter.New() + for _, f := range newFilters { + for _, rule := range f.Rules { + err := dnsFilter.AddRule(rule, f.ID) + if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { + continue + } + if err != nil { + log.Printf("Cannot add rule %s: %s", rule, err) + // Just ignore invalid rules + continue + } + } + } + + s.Lock() + oldDnsFilter := s.dnsFilter + s.dnsFilter = dnsFilter + s.Unlock() + + oldDnsFilter.Destroy() +} + +func (s *Server) Reconfigure(new ServerConfig) error { + s.reconfigureBlockedTTL(new) + s.reconfigureUpstreams(new) + s.reconfigureFilters(new) + + err := s.reconfigureListenAddr(new) + if err != nil { + return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr) + } + return nil +} + +// +// packet handling functions +// + +func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { + // log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p) + msg := dns.Msg{} + err := msg.Unpack(p) + if err != nil { + log.Printf("failed to unpack DNS packet: %s", err) + return + } + + // + // DNS packet byte format is valid + // + // any errors below here require a response to client + // log.Printf("Unpacked: %v", msg.String()) + if len(msg.Question) != 1 { + log.Printf("Got invalid number of questions: %v", len(msg.Question)) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + } + + { + val, ok := s.cache.Get(&msg) + if ok && val != nil { + err = s.respond(val, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + return + } + } + host := msg.Question[0].Name + res, err := s.dnsFilter.CheckHost(host) + if err != nil { + log.Printf("dnsfilter failed to check host '%s': %s", host, err) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + } else if res.IsFiltered { + log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) + err := s.respondWithNXDomain(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + } + + // TODO: replace with single-socket implementation + // TODO: replace 8.8.8.8:53 with configurable upstreams + // TODO: support DoH, DoT and TCP + upstream := s.chooseUpstream() + reply, err := upstream.Exchange(&msg) + if err != nil { + log.Printf("talking to upstream failed for host '%s': %s", host, err) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet with server failure: %s", err) + return + } + return + } + if reply == nil { + log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String()) + err := s.respondWithServerFailure(&msg, addr, conn) + if err != nil { + log.Printf("Couldn't respond to UDP packet with should not happen: %s", err) + return + } + return + } + + err = s.respond(reply, addr, conn) + if err != nil { + if isConnClosed(err) { + // ignore this error, the connection was closed and that's ok + return + } + log.Printf("Couldn't respond to UDP packet: %s", err) + return + } + + s.cache.Set(reply) +} + +// +// packet sending functions +// + +func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + // log.Printf("Replying to %s with %s", addr, resp) + resp.Compress = true + bytes, err := resp.Pack() + if err != nil { + return errorx.Decorate(err, "Couldn't convert message into wire format") + } + n, err := conn.WriteTo(bytes, addr) + if n == 0 && isConnClosed(err) { + return err + } + if n != len(bytes) { + return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes)) + } + if err != nil { + return errorx.Decorate(err, "WriteTo() returned error") + } + return nil +} + +func (s *Server) respondWithServerFailure(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeServerFailure) + return s.respond(&resp, addr, conn) +} + +func (s *Server) respondWithNXDomain(request *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNameError) + resp.Ns = s.genSOA(request) + return s.respond(&resp, addr, conn) +} + +func (s *Server) genSOA(request *dns.Msg) []dns.RR { + zone := "" + if len(request.Question) > 0 { + zone = request.Question[0].Name + } + + soa := dns.SOA{ + // values copied from verisign's nonexistent .com domain + // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers + Refresh: 1800, + Retry: 900, + Expire: 604800, + Minttl: 86400, + // copied from AdGuard DNS + Ns: "fake-for-negative-caching.adguard.com.", + Serial: 100500, + // rest is request-specific + Hdr: dns.RR_Header{ + Name: zone, + Rrtype: dns.TypeSOA, + Ttl: s.BlockedTTL, + Class: dns.ClassINET, + }, + Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." + } + if soa.Hdr.Ttl == 0 { + soa.Hdr.Ttl = defaultValues.BlockedTTL + } + if len(zone) > 0 && zone[0] != '.' { + soa.Mbox += zone + } + return []dns.RR{&soa} +} diff --git a/dnsforward/helpers.go b/dnsforward/helpers.go new file mode 100644 index 00000000..339023a0 --- /dev/null +++ b/dnsforward/helpers.go @@ -0,0 +1,43 @@ +package dnsforward + +import ( + "fmt" + "net" + "os" + "path" + "runtime" + "strings" +) + +func isConnClosed(err error) bool { + if err == nil { + return false + } + nerr, ok := err.(*net.OpError) + if !ok { + return false + } + + if strings.Contains(nerr.Err.Error(), "use of closed network connection") { + return true + } + + return false +} + +// --------------------- +// debug logging helpers +// --------------------- +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Fprint(os.Stderr, buf.String()) +} diff --git a/dnsforward/standalone/.gitignore b/dnsforward/standalone/.gitignore new file mode 100644 index 00000000..5f81988c --- /dev/null +++ b/dnsforward/standalone/.gitignore @@ -0,0 +1 @@ +/standalone \ No newline at end of file diff --git a/dnsforward/standalone/standalone.go b/dnsforward/standalone/standalone.go new file mode 100644 index 00000000..ae3e6d13 --- /dev/null +++ b/dnsforward/standalone/standalone.go @@ -0,0 +1,51 @@ +package main + +import ( + "log" + "net" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsforward" +) + +// +// main function +// +func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + go func() { + for range time.Tick(time.Second) { + log.Printf("goroutines = %d", runtime.NumGoroutine()) + } + }() + s := dnsforward.Server{} + err := s.Start(nil) + if err != nil { + panic(err) + } + time.Sleep(time.Second) + err = s.Stop() + if err != nil { + panic(err) + } + err = s.Start(&dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) + if err != nil { + panic(err) + } + err = s.Reconfigure(dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53, IP: net.ParseIP("0.0.0.0")}}) + if err != nil { + panic(err) + } + log.Printf("Now serving DNS") + signal_channel := make(chan os.Signal) + signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM) + <-signal_channel +} diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go new file mode 100644 index 00000000..3e0a61c0 --- /dev/null +++ b/dnsforward/upstream.go @@ -0,0 +1,187 @@ +package dnsforward + +import ( + "bytes" + "crypto/tls" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +const defaultTimeout = time.Second * 10 + +type Upstream interface { + Exchange(m *dns.Msg) (*dns.Msg, error) +} + +// +// plain DNS +// +type plainDNS struct { + Address string +} + +var defaultUDPClient = dns.Client{ + Timeout: defaultTimeout, + UDPSize: dns.MaxMsgSize, +} + +var defaultTCPClient = dns.Client{ + Net: "tcp", + UDPSize: dns.MaxMsgSize, + Timeout: defaultTimeout, +} + +func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { + reply, _, err := defaultUDPClient.Exchange(m, p.Address) + if err != nil && reply != nil && reply.Truncated { + log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) + reply, _, err = defaultTCPClient.Exchange(m, p.Address) + } + return reply, err +} + +// +// DNS-over-TLS +// +type dnsOverTLS struct { + Address string + pool *TLSPool + + sync.RWMutex // protects pool +} + +var defaultTLSClient = dns.Client{ + Net: "tcp-tls", + Timeout: defaultTimeout, + UDPSize: dns.MaxMsgSize, + TLSConfig: &tls.Config{}, +} + +func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { + var pool *TLSPool + p.RLock() + pool = p.pool + p.RUnlock() + if pool == nil { + p.Lock() + // lazy initialize it + p.pool = &TLSPool{Address: p.Address} + p.Unlock() + } + + p.RLock() + poolConn, err := p.pool.Get() + p.RUnlock() + if err != nil { + return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address) + } + c := dns.Conn{Conn: poolConn} + err = c.WriteMsg(m) + if err != nil { + poolConn.Close() + return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address) + } + + reply, err := c.ReadMsg() + if err != nil { + poolConn.Close() + return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address) + } + p.RLock() + p.pool.Put(poolConn) + p.RUnlock() + return reply, nil +} + +// +// DNS-over-https +// +type dnsOverHTTPS struct { + Address string +} + +var defaultHTTPSTransport = http.Transport{} + +var defaultHTTPSClient = http.Client{ + Transport: &defaultHTTPSTransport, + Timeout: defaultTimeout, +} + +func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { + buf, err := m.Pack() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't pack request msg") + } + bb := bytes.NewBuffer(buf) + resp, err := http.Post(p.Address, "application/dns-message", bb) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.Address) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.Address) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.Address) + } + if len(body) == 0 { + return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.Address) + } + response := dns.Msg{} + err = response.Unpack(body) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.Address, string(body)) + } + return &response, nil +} + +func (s *Server) chooseUpstream() Upstream { + upstreams := s.Upstreams + if upstreams == nil { + upstreams = defaultValues.Upstreams + } + if len(upstreams) == 0 { + panic("SHOULD NOT HAPPEN: no default upstreams specified") + } + if len(upstreams) == 1 { + return upstreams[0] + } + n := rand.Intn(len(upstreams)) + upstream := upstreams[n] + return upstream +} + +func GetUpstream(address string) (Upstream, error) { + if strings.Contains(address, "://") { + url, err := url.Parse(address) + if err != nil { + return nil, errorx.Decorate(err, "Failed to parse %s", address) + } + switch url.Scheme { + case "dns": + return &plainDNS{Address: address}, nil + case "tls": + return &dnsOverTLS{Address: address}, nil + case "https": + return &dnsOverHTTPS{Address: address}, nil + default: + return &plainDNS{Address: address}, nil + } + } + + // we don't have scheme in the url, so it's just a plain DNS host:port + return &plainDNS{Address: address}, nil +} diff --git a/dnsforward/upstream_pool.go b/dnsforward/upstream_pool.go new file mode 100644 index 00000000..9756d54f --- /dev/null +++ b/dnsforward/upstream_pool.go @@ -0,0 +1,98 @@ +package dnsforward + +import ( + "crypto/tls" + "fmt" + "net" + "net/url" + "sync" + + "github.com/joomcode/errorx" +) + +// upstream TLS pool. +// +// Example: +// pool := TLSPool{Address: "tls://1.1.1.1:853"} +// netConn, err := pool.Get() +// if err != nil {panic(err)} +// c := dns.Conn{Conn: netConn} +// q := dns.Msg{} +// q.SetQuestion("google.com.", dns.TypeA) +// log.Println(q) +// err = c.WriteMsg(&q) +// if err != nil {panic(err)} +// r, err := c.ReadMsg() +// if err != nil {panic(err)} +// log.Println(r) +// pool.Put(c.Conn) +type TLSPool struct { + Address string + parsedAddress *url.URL + parsedAddressMutex sync.RWMutex + + conns []net.Conn + sync.Mutex // protects conns +} + +func (n *TLSPool) getHost() (string, error) { + n.parsedAddressMutex.RLock() + if n.parsedAddress != nil { + n.parsedAddressMutex.RUnlock() + return n.parsedAddress.Host, nil + } + n.parsedAddressMutex.RUnlock() + + n.parsedAddressMutex.Lock() + defer n.parsedAddressMutex.Unlock() + url, err := url.Parse(n.Address) + if err != nil { + return "", errorx.Decorate(err, "Failed to parse %s", n.Address) + } + if url.Scheme != "tls" { + return "", fmt.Errorf("TLSPool only supports TLS") + } + n.parsedAddress = url + return n.parsedAddress.Host, nil +} + +func (n *TLSPool) Get() (net.Conn, error) { + host, err := n.getHost() + if err != nil { + return nil, err + } + + // get the connection from the slice inside the lock + var c net.Conn + n.Lock() + num := len(n.conns) + if num > 0 { + last := num - 1 + c = n.conns[last] + n.conns = n.conns[:last] + } + n.Unlock() + + // if we got connection from the slice, return it + if c != nil { + // log.Printf("Returning existing connection to %s", host) + return c, nil + } + + // we'll need a new connection, dial now + // log.Printf("Dialing to %s", host) + conn, err := tls.Dial("tcp", host, nil) + if err != nil { + return nil, errorx.Decorate(err, "Failed to connect to %s", host) + } + return conn, nil +} + +func (n *TLSPool) Put(c net.Conn) { + if c == nil { + return + } + n.Lock() + n.conns = append(n.conns, c) + n.Unlock() +}