* dnsforward: get per-client settings only once
+ dnsforward: add 'ProtectionEnabled = false' test
This commit is contained in:
parent
b3ddae7f85
commit
0ef8e5cdae
@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||||||
// A better approach is for proxy.Stop() to wait until all its workers exit,
|
// A better approach is for proxy.Stop() to wait until all its workers exit,
|
||||||
// but this would require the Upstream interface to have Close() function
|
// but this would require the Upstream interface to have Close() function
|
||||||
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
||||||
res, err := s.filterDNSRequest(d)
|
|
||||||
|
var setts *dnsfilter.RequestFilteringSettings
|
||||||
|
var err error
|
||||||
|
res := &dnsfilter.Result{}
|
||||||
|
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||||
|
if protectionEnabled {
|
||||||
|
setts = s.getClientRequestFilteringSettings(d)
|
||||||
|
res, err = s.filterDNSRequest(d, setts)
|
||||||
|
}
|
||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
|||||||
d.Res.Answer = answer
|
d.Res.Answer = answer
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList {
|
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
|
||||||
origResp2 := d.Res
|
origResp2 := d.Res
|
||||||
res, err = s.filterDNSResponse(d)
|
res, err = s.filterDNSResponse(d, setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||||
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
|
|
||||||
return &dnsfilter.Result{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
setts := s.getClientRequestFilteringSettings(d)
|
|
||||||
req := d.Req
|
req := d.Req
|
||||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
|
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
|
||||||
@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
|
|||||||
|
|
||||||
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
||||||
// If this is a match, we set a new response in d.Res and return.
|
// If this is a match, we set a new response in d.Res and return.
|
||||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) {
|
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
||||||
for _, a := range d.Res.Answer {
|
for _, a := range d.Res.Answer {
|
||||||
host := ""
|
host := ""
|
||||||
|
|
||||||
@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro
|
|||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
setts := s.getClientRequestFilteringSettings(d)
|
|
||||||
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
|
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
|
||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
|
|
||||||
|
@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{
|
|||||||
"example.org.": {{127, 0, 0, 255}},
|
"example.org.": {{127, 0, 0, 255}},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||||
|
s := createTestServer(t)
|
||||||
|
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
||||||
|
s.conf.ProtectionEnabled = false
|
||||||
|
err := s.startWithUpstream(testUpstm)
|
||||||
|
assert.True(t, err == nil)
|
||||||
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
|
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
|
||||||
|
// but protection is disabled - response is NOT blocked
|
||||||
|
req := createTestMessage("badhost.")
|
||||||
|
reply, err := dns.Exchange(req, addr.String())
|
||||||
|
assert.True(t, err == nil)
|
||||||
|
assert.True(t, reply.Rcode == dns.RcodeSuccess)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBlockCNAME(t *testing.T) {
|
func TestBlockCNAME(t *testing.T) {
|
||||||
s := createTestServer(t)
|
s := createTestServer(t)
|
||||||
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
|
||||||
|
Loading…
Reference in New Issue
Block a user