* dnsforward: get per-client settings only once

+ dnsforward: add 'ProtectionEnabled = false' test
This commit is contained in:
Simon Zolin 2020-01-09 19:52:06 +03:00
parent b3ddae7f85
commit 0ef8e5cdae
2 changed files with 29 additions and 11 deletions

View File

@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
// A better approach is for proxy.Stop() to wait until all its workers exit, // A better approach is for proxy.Stop() to wait until all its workers exit,
// but this would require the Upstream interface to have Close() function // but this would require the Upstream interface to have Close() function
// (to prevent from hanging while waiting for unresponsive DNS server to respond). // (to prevent from hanging while waiting for unresponsive DNS server to respond).
res, err := s.filterDNSRequest(d)
var setts *dnsfilter.RequestFilteringSettings
var err error
res := &dnsfilter.Result{}
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
if protectionEnabled {
setts = s.getClientRequestFilteringSettings(d)
res, err = s.filterDNSRequest(d, setts)
}
s.RUnlock() s.RUnlock()
if err != nil { if err != nil {
return err return err
@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
d.Res.Answer = answer d.Res.Answer = answer
} }
} else if res.Reason != dnsfilter.NotFilteredWhiteList { } else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
origResp2 := d.Res origResp2 := d.Res
res, err = s.filterDNSResponse(d) res, err = s.filterDNSResponse(d, setts)
if err != nil { if err != nil {
return err return err
} }
@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
} }
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
return &dnsfilter.Result{}, nil
}
setts := s.getClientRequestFilteringSettings(d)
req := d.Req req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".") host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. // If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
// If this is a match, we set a new response in d.Res and return. // If this is a match, we set a new response in d.Res and return.
func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) { func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
for _, a := range d.Res.Answer { for _, a := range d.Res.Answer {
host := "" host := ""
@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro
s.RUnlock() s.RUnlock()
continue continue
} }
setts := s.getClientRequestFilteringSettings(d)
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts) res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
s.RUnlock() s.RUnlock()

View File

@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{
"example.org.": {{127, 0, 0, 255}}, "example.org.": {{127, 0, 0, 255}},
} }
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
// but protection is disabled - response is NOT blocked
req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil)
assert.True(t, reply.Rcode == dns.RcodeSuccess)
}
func TestBlockCNAME(t *testing.T) { func TestBlockCNAME(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}