Pull request: 2508 ip conversion vol.1

Merge in DNS/adguard-home from 2508-ip-conversion to master

Updates #2508.

Squashed commit of the following:

commit 3f64709fbc73ef74c11b910997be1e9bc337193c
Merge: 5ac7faaaa 0d67aa251
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 13 16:21:34 2021 +0300

    Merge branch 'master' into 2508-ip-conversion

commit 5ac7faaaa9dda570fdb872acad5d13d078f46b64
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 13 12:00:11 2021 +0300

    all: replace conditions with appropriate functions in tests

commit 9e3fa9a115ed23024c57dd5192d5173477ddbf71
Merge: db992a42a bba74859e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 13 10:47:10 2021 +0300

    Merge branch 'master' into 2508-ip-conversion

commit db992a42a2c6f315421e78a6a0492e2bfb3ce89d
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 18:55:53 2021 +0300

    sysutil: fix linux tests

commit f629b15d62349323ce2da05e68dc9cc0b5f6e194
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 18:41:20 2021 +0300

    all: improve code quality

commit 3bf03a75524040738562298bd1de6db536af130f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 17:33:26 2021 +0300

    sysutil: fix linux net.IP conversion

commit 5d5b6994916923636e635588631b63b7e7b74e5f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 12 14:57:26 2021 +0300

    dnsforward: remove redundant net.IP <-> string conversion

commit 0b955d99b7fad40942f21d1dd8734adb99126195
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 11 18:04:25 2021 +0300

    dhcpd: remove net.IP <-> string conversion
This commit is contained in:
Eugene Burkov 2021-01-13 16:56:05 +03:00
parent 0d67aa251d
commit e8c1f5c8d3
39 changed files with 409 additions and 435 deletions

View File

@ -28,27 +28,27 @@ func TestDB(t *testing.T) {
conf := V4ServerConf{ conf := V4ServerConf{
Enabled: true, Enabled: true,
RangeStart: "192.168.10.100", RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: "192.168.10.200", RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: "192.168.10.1", GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: "255.255.255.0", SubnetMask: net.IP{255, 255, 255, 0},
notify: testNotify, notify: testNotify,
} }
s.srv4, err = v4Create(conf) s.srv4, err = v4Create(conf)
assert.True(t, err == nil) assert.Nil(t, err)
s.srv6, err = v6Create(V6ServerConf{}) s.srv6, err = v6Create(V6ServerConf{})
assert.True(t, err == nil) assert.Nil(t, err)
l := Lease{} l := Lease{}
l.IP = net.ParseIP("192.168.10.100").To4() l.IP = net.IP{192, 168, 10, 100}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
exp1 := time.Now().Add(time.Hour) exp1 := time.Now().Add(time.Hour)
l.Expiry = exp1 l.Expiry = exp1
s.srv4.(*v4Server).addLease(&l) s.srv4.(*v4Server).addLease(&l)
l2 := Lease{} l2 := Lease{}
l2.IP = net.ParseIP("192.168.10.101").To4() l2.IP = net.IP{192, 168, 10, 101}
l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb") l2.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:bb")
s.srv4.AddStaticLease(l2) s.srv4.AddStaticLease(l2)
@ -62,7 +62,7 @@ func TestDB(t *testing.T) {
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
assert.Equal(t, "192.168.10.101", ll[0].IP.String()) assert.Equal(t, "192.168.10.101", ll[0].IP.String())
assert.Equal(t, int64(leaseExpireStatic), ll[0].Expiry.Unix()) assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
assert.Equal(t, "192.168.10.100", ll[1].IP.String()) assert.Equal(t, "192.168.10.100", ll[1].IP.String())
@ -75,8 +75,8 @@ func TestIsValidSubnetMask(t *testing.T) {
assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0})) assert.True(t, isValidSubnetMask([]byte{255, 255, 255, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0})) assert.True(t, isValidSubnetMask([]byte{255, 255, 254, 0}))
assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0})) assert.True(t, isValidSubnetMask([]byte{255, 255, 252, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 253, 0})) assert.False(t, isValidSubnetMask([]byte{255, 255, 253, 0}))
assert.True(t, !isValidSubnetMask([]byte{255, 255, 255, 1})) assert.False(t, isValidSubnetMask([]byte{255, 255, 255, 1}))
} }
func TestNormalizeLeases(t *testing.T) { func TestNormalizeLeases(t *testing.T) {
@ -100,7 +100,7 @@ func TestNormalizeLeases(t *testing.T) {
leases := normalizeLeases(staticLeases, dynLeases) leases := normalizeLeases(staticLeases, dynLeases)
assert.True(t, len(leases) == 3) assert.Len(t, leases, 3)
assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4})) assert.True(t, bytes.Equal(leases[0].HWAddr, []byte{1, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4})) assert.True(t, bytes.Equal(leases[0].IP, []byte{0, 2, 3, 4}))
assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4})) assert.True(t, bytes.Equal(leases[1].HWAddr, []byte{2, 2, 3, 4}))
@ -109,22 +109,22 @@ func TestNormalizeLeases(t *testing.T) {
func TestOptions(t *testing.T) { func TestOptions(t *testing.T) {
code, val := parseOptionString(" 12 hex abcdef ") code, val := parseOptionString(" 12 hex abcdef ")
assert.Equal(t, uint8(12), code) assert.EqualValues(t, 12, code)
assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val)) assert.True(t, bytes.Equal([]byte{0xab, 0xcd, 0xef}, val))
code, _ = parseOptionString(" 12 hex abcdef1 ") code, _ = parseOptionString(" 12 hex abcdef1 ")
assert.Equal(t, uint8(0), code) assert.EqualValues(t, 0, code)
code, val = parseOptionString("123 ip 1.2.3.4") code, val = parseOptionString("123 ip 1.2.3.4")
assert.Equal(t, uint8(123), code) assert.EqualValues(t, 123, code)
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String()) assert.Equal(t, "1.2.3.4", net.IP(string(val)).String())
code, _ = parseOptionString("256 ip 1.1.1.1") code, _ = parseOptionString("256 ip 1.1.1.1")
assert.Equal(t, uint8(0), code) assert.EqualValues(t, 0, code)
code, _ = parseOptionString("-1 ip 1.1.1.1") code, _ = parseOptionString("-1 ip 1.1.1.1")
assert.Equal(t, uint8(0), code) assert.EqualValues(t, 0, code)
code, _ = parseOptionString("12 ip 1.1.1.1x") code, _ = parseOptionString("12 ip 1.1.1.1x")
assert.Equal(t, uint8(0), code) assert.EqualValues(t, 0, code)
code, _ = parseOptionString("12 x 1.1.1.1") code, _ = parseOptionString("12 x 1.1.1.1")
assert.Equal(t, uint8(0), code) assert.EqualValues(t, 0, code)
} }

View File

@ -42,10 +42,10 @@ func convertLeases(inputLeases []Lease, includeExpires bool) []map[string]string
} }
type v4ServerConfJSON struct { type v4ServerConfJSON struct {
GatewayIP string `json:"gateway_ip"` GatewayIP net.IP `json:"gateway_ip"`
SubnetMask string `json:"subnet_mask"` SubnetMask net.IP `json:"subnet_mask"`
RangeStart string `json:"range_start"` RangeStart net.IP `json:"range_start"`
RangeEnd string `json:"range_end"` RangeEnd net.IP `json:"range_end"`
LeaseDuration uint32 `json:"lease_duration"` LeaseDuration uint32 `json:"lease_duration"`
} }
@ -61,10 +61,10 @@ func v4ServerConfToJSON(c V4ServerConf) v4ServerConfJSON {
func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf { func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
return V4ServerConf{ return V4ServerConf{
GatewayIP: j.GatewayIP, GatewayIP: j.GatewayIP.To4(),
SubnetMask: j.SubnetMask, SubnetMask: j.SubnetMask.To4(),
RangeStart: j.RangeStart, RangeStart: j.RangeStart.To4(),
RangeEnd: j.RangeEnd, RangeEnd: j.RangeEnd.To4(),
LeaseDuration: j.LeaseDuration, LeaseDuration: j.LeaseDuration,
} }
} }
@ -117,7 +117,7 @@ func (s *Server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
type staticLeaseJSON struct { type staticLeaseJSON struct {
HWAddr string `json:"mac"` HWAddr string `json:"mac"`
IP string `json:"ip"` IP net.IP `json:"ip"`
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
} }
@ -225,10 +225,10 @@ func (s *Server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
type netInterfaceJSON struct { type netInterfaceJSON struct {
Name string `json:"name"` Name string `json:"name"`
GatewayIP string `json:"gateway_ip"` GatewayIP net.IP `json:"gateway_ip"`
HardwareAddr string `json:"hardware_address"` HardwareAddr string `json:"hardware_address"`
Addrs4 []string `json:"ipv4_addresses"` Addrs4 []net.IP `json:"ipv4_addresses"`
Addrs6 []string `json:"ipv6_addresses"` Addrs6 []net.IP `json:"ipv6_addresses"`
Flags string `json:"flags"` Flags string `json:"flags"`
} }
@ -277,9 +277,9 @@ func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
continue continue
} }
if ipnet.IP.To4() != nil { if ipnet.IP.To4() != nil {
jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP.String()) jsonIface.Addrs4 = append(jsonIface.Addrs4, ipnet.IP)
} else { } else {
jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP.String()) jsonIface.Addrs6 = append(jsonIface.Addrs6, ipnet.IP)
} }
} }
if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 { if len(jsonIface.Addrs4)+len(jsonIface.Addrs6) != 0 {
@ -375,50 +375,46 @@ func (s *Server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request
err := json.NewDecoder(r.Body).Decode(&lj) err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
ip := net.ParseIP(lj.IP) if lj.IP == nil {
if ip != nil && ip.To4() == nil { httpError(r, w, http.StatusBadRequest, "invalid IP")
mac, err := net.ParseMAC(lj.HWAddr)
return
}
ip4 := lj.IP.To4()
mac, err := net.ParseMAC(lj.HWAddr)
lease := Lease{
HWAddr: mac,
}
if ip4 == nil {
lease.IP = lj.IP.To16()
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC") httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{ return
IP: ip,
HWAddr: mac,
} }
err = s.srv6.AddStaticLease(lease) err = s.srv6.AddStaticLease(lease)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) httpError(r, w, http.StatusBadRequest, "%s", err)
return
} }
return return
} }
ip, _ = parseIPv4(lj.IP) lease.IP = ip4
if ip == nil { lease.Hostname = lj.Hostname
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, err := net.ParseMAC(lj.HWAddr)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = s.srv4.AddStaticLease(lease) err = s.srv4.AddStaticLease(lease)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) httpError(r, w, http.StatusBadRequest, "%s", err)
return return
} }
} }
@ -428,46 +424,46 @@ func (s *Server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Requ
err := json.NewDecoder(r.Body).Decode(&lj) err := json.NewDecoder(r.Body).Decode(&lj)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
return return
} }
ip := net.ParseIP(lj.IP) if lj.IP == nil {
if ip != nil && ip.To4() == nil { httpError(r, w, http.StatusBadRequest, "invalid IP")
mac, err := net.ParseMAC(lj.HWAddr)
return
}
ip4 := lj.IP.To4()
mac, err := net.ParseMAC(lj.HWAddr)
lease := Lease{
HWAddr: mac,
}
if ip4 == nil {
lease.IP = lj.IP.To16()
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid MAC") httpError(r, w, http.StatusBadRequest, "invalid MAC")
return
}
lease := Lease{ return
IP: ip,
HWAddr: mac,
} }
err = s.srv6.RemoveStaticLease(lease) err = s.srv6.RemoveStaticLease(lease)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) httpError(r, w, http.StatusBadRequest, "%s", err)
return
} }
return return
} }
ip, _ = parseIPv4(lj.IP) lease.IP = ip4
if ip == nil { lease.Hostname = lj.Hostname
httpError(r, w, http.StatusBadRequest, "invalid IP")
return
}
mac, _ := net.ParseMAC(lj.HWAddr)
lease := Lease{
IP: ip,
HWAddr: mac,
Hostname: lj.Hostname,
}
err = s.srv4.RemoveStaticLease(lease) err = s.srv4.RemoveStaticLease(lease)
if err != nil { if err != nil {
httpError(r, w, http.StatusBadRequest, "%s", err) httpError(r, w, http.StatusBadRequest, "%s", err)
return return
} }
} }

View File

@ -14,15 +14,17 @@ func isTimeout(err error) bool {
return operr.Timeout() return operr.Timeout()
} }
func parseIPv4(text string) (net.IP, error) { func tryTo4(ip net.IP) (ip4 net.IP, err error) {
result := net.ParseIP(text) if ip == nil {
if result == nil { return nil, fmt.Errorf("%v is not an IP address", ip)
return nil, fmt.Errorf("%s is not an IP address", text)
} }
if result.To4() == nil {
return nil, fmt.Errorf("%s is not an IPv4 address", text) ip4 = ip.To4()
if ip4 == nil {
return nil, fmt.Errorf("%v is not an IPv4 address", ip)
} }
return result.To4(), nil
return ip4, nil
} }
// Return TRUE if subnet mask is correct (e.g. 255.255.255.0) // Return TRUE if subnet mask is correct (e.g. 255.255.255.0)

View File

@ -36,13 +36,13 @@ type V4ServerConf struct {
Enabled bool `yaml:"-"` Enabled bool `yaml:"-"`
InterfaceName string `yaml:"-"` InterfaceName string `yaml:"-"`
GatewayIP string `yaml:"gateway_ip"` GatewayIP net.IP `yaml:"gateway_ip"`
SubnetMask string `yaml:"subnet_mask"` SubnetMask net.IP `yaml:"subnet_mask"`
// The first & the last IP address for dynamic leases // The first & the last IP address for dynamic leases
// Bytes [0..2] of the last allowed IP address must match the first IP // Bytes [0..2] of the last allowed IP address must match the first IP
RangeStart string `yaml:"range_start"` RangeStart net.IP `yaml:"range_start"`
RangeEnd string `yaml:"range_end"` RangeEnd net.IP `yaml:"range_end"`
LeaseDuration uint32 `yaml:"lease_duration"` // in seconds LeaseDuration uint32 `yaml:"lease_duration"` // in seconds

View File

@ -589,7 +589,7 @@ func (s *v4Server) Start() error {
s.conf.dnsIPAddrs = dnsIPAddrs s.conf.dnsIPAddrs = dnsIPAddrs
laddr := &net.UDPAddr{ laddr := &net.UDPAddr{
IP: net.ParseIP("0.0.0.0"), IP: net.IP{0, 0, 0, 0},
Port: dhcpv4.ServerPort, Port: dhcpv4.ServerPort,
} }
s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger()) s.srv, err = server4.NewServer(iface.Name, laddr, s.packetHandler, server4.WithDebugLogger())
@ -632,19 +632,18 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
} }
var err error var err error
s.conf.routerIP, err = parseIPv4(s.conf.GatewayIP) s.conf.routerIP, err = tryTo4(s.conf.GatewayIP)
if err != nil { if err != nil {
return s, fmt.Errorf("dhcpv4: %w", err) return s, fmt.Errorf("dhcpv4: %w", err)
} }
subnet, err := parseIPv4(s.conf.SubnetMask) if s.conf.SubnetMask == nil {
if err != nil || !isValidSubnetMask(subnet) { return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask)
return s, fmt.Errorf("dhcpv4: invalid subnet mask: %s", s.conf.SubnetMask)
} }
s.conf.subnetMask = make([]byte, 4) s.conf.subnetMask = make([]byte, 4)
copy(s.conf.subnetMask, subnet) copy(s.conf.subnetMask, s.conf.SubnetMask.To4())
s.conf.ipStart, err = parseIPv4(conf.RangeStart) s.conf.ipStart, err = tryTo4(conf.RangeStart)
if s.conf.ipStart == nil { if s.conf.ipStart == nil {
return s, fmt.Errorf("dhcpv4: %w", err) return s, fmt.Errorf("dhcpv4: %w", err)
} }
@ -652,7 +651,7 @@ func v4Create(conf V4ServerConf) (DHCPServer, error) {
return s, fmt.Errorf("dhcpv4: invalid range start IP") return s, fmt.Errorf("dhcpv4: invalid range start IP")
} }
s.conf.ipEnd, err = parseIPv4(conf.RangeEnd) s.conf.ipEnd, err = tryTo4(conf.RangeEnd)
if s.conf.ipEnd == nil { if s.conf.ipEnd == nil {
return s, fmt.Errorf("dhcpv4: %w", err) return s, fmt.Errorf("dhcpv4: %w", err)
} }

View File

@ -16,119 +16,119 @@ func notify4(flags uint32) {
func TestV4StaticLeaseAddRemove(t *testing.T) { func TestV4StaticLeaseAddRemove(t *testing.T) {
conf := V4ServerConf{ conf := V4ServerConf{
Enabled: true, Enabled: true,
RangeStart: "192.168.10.100", RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: "192.168.10.200", RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: "192.168.10.1", GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: "255.255.255.0", SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4, notify: notify4,
} }
s, err := v4Create(conf) s, err := v4Create(conf)
assert.True(t, err == nil) assert.Nil(t, err)
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls)) assert.Empty(t, ls)
// add static lease // add static lease
l := Lease{} l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4() l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// try to add the same static lease - fail // try to add the same static lease - fail
assert.True(t, s.AddStaticLease(l) != nil) assert.NotNil(t, s.AddStaticLease(l))
// check // check
ls = s.GetLeases(LeasesStatic) ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls)) assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
// try to remove static lease - fail // try to remove static lease - fail
l.IP = net.ParseIP("192.168.10.110").To4() l.IP = net.IP{192, 168, 10, 110}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil) assert.NotNil(t, s.RemoveStaticLease(l))
// remove static lease // remove static lease
l.IP = net.ParseIP("192.168.10.150").To4() l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil) assert.Nil(t, s.RemoveStaticLease(l))
// check // check
ls = s.GetLeases(LeasesStatic) ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls)) assert.Empty(t, ls)
} }
func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) { func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V4ServerConf{ conf := V4ServerConf{
Enabled: true, Enabled: true,
RangeStart: "192.168.10.100", RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: "192.168.10.200", RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: "192.168.10.1", GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: "255.255.255.0", SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4, notify: notify4,
} }
sIface, err := v4Create(conf) sIface, err := v4Create(conf)
s := sIface.(*v4Server) s := sIface.(*v4Server)
assert.True(t, err == nil) assert.Nil(t, err)
// add dynamic lease // add dynamic lease
ld := Lease{} ld := Lease{}
ld.IP = net.ParseIP("192.168.10.150").To4() ld.IP = net.IP{192, 168, 10, 150}
ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa") ld.HWAddr, _ = net.ParseMAC("11:aa:aa:aa:aa:aa")
s.addLease(&ld) s.addLease(&ld)
// add dynamic lease // add dynamic lease
{ {
ld := Lease{} ld := Lease{}
ld.IP = net.ParseIP("192.168.10.151").To4() ld.IP = net.IP{192, 168, 10, 151}
ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") ld.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
s.addLease(&ld) s.addLease(&ld)
} }
// add static lease with the same IP // add static lease with the same IP
l := Lease{} l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4() l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// add static lease with the same MAC // add static lease with the same MAC
l = Lease{} l = Lease{}
l.IP = net.ParseIP("192.168.10.152").To4() l.IP = net.IP{192, 168, 10, 152}
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// check // check
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls)) assert.Len(t, ls, 2)
assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
assert.Equal(t, "192.168.10.152", ls[1].IP.String()) assert.Equal(t, "192.168.10.152", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic) assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix())
} }
func TestV4StaticLeaseGet(t *testing.T) { func TestV4StaticLeaseGet(t *testing.T) {
conf := V4ServerConf{ conf := V4ServerConf{
Enabled: true, Enabled: true,
RangeStart: "192.168.10.100", RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: "192.168.10.200", RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: "192.168.10.1", GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: "255.255.255.0", SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4, notify: notify4,
} }
sIface, err := v4Create(conf) sIface, err := v4Create(conf)
s := sIface.(*v4Server) s := sIface.(*v4Server)
assert.True(t, err == nil) assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()} s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
l := Lease{} l := Lease{}
l.IP = net.ParseIP("192.168.10.150").To4() l.IP = net.IP{192, 168, 10, 150}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// "Discover" // "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
@ -160,12 +160,12 @@ func TestV4StaticLeaseGet(t *testing.T) {
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS() dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs)) assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
// check lease // check lease
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls)) assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.Equal(t, "192.168.10.150", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
} }
@ -173,10 +173,10 @@ func TestV4StaticLeaseGet(t *testing.T) {
func TestV4DynamicLeaseGet(t *testing.T) { func TestV4DynamicLeaseGet(t *testing.T) {
conf := V4ServerConf{ conf := V4ServerConf{
Enabled: true, Enabled: true,
RangeStart: "192.168.10.100", RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: "192.168.10.200", RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: "192.168.10.1", GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: "255.255.255.0", SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4, notify: notify4,
Options: []string{ Options: []string{
"81 hex 303132", "81 hex 303132",
@ -185,8 +185,8 @@ func TestV4DynamicLeaseGet(t *testing.T) {
} }
sIface, err := v4Create(conf) sIface, err := v4Create(conf)
s := sIface.(*v4Server) s := sIface.(*v4Server)
assert.True(t, err == nil) assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("192.168.10.1").To4()} s.conf.dnsIPAddrs = []net.IP{{192, 168, 10, 1}}
// "Discover" // "Discover"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
@ -220,19 +220,19 @@ func TestV4DynamicLeaseGet(t *testing.T) {
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS() dnsAddrs := resp.DNS()
assert.Equal(t, 1, len(dnsAddrs)) assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) assert.Equal(t, "192.168.10.1", dnsAddrs[0].String())
// check lease // check lease
ls := s.GetLeases(LeasesDynamic) ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls)) assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.100", ls[0].IP.String()) assert.Equal(t, "192.168.10.100", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
start := net.ParseIP("192.168.10.100").To4() start := net.IP{192, 168, 10, 100}
stop := net.ParseIP("192.168.10.200").To4() stop := net.IP{192, 168, 10, 200}
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.10.99").To4())) assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 10, 99}))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.100").To4())) assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 100}))
assert.True(t, !ip4InRange(start, stop, net.ParseIP("192.168.11.201").To4())) assert.False(t, ip4InRange(start, stop, net.IP{192, 168, 11, 201}))
assert.True(t, ip4InRange(start, stop, net.ParseIP("192.168.10.100").To4())) assert.True(t, ip4InRange(start, stop, net.IP{192, 168, 10, 100}))
} }

View File

@ -21,40 +21,40 @@ func TestV6StaticLeaseAddRemove(t *testing.T) {
notify: notify6, notify: notify6,
} }
s, err := v6Create(conf) s, err := v6Create(conf)
assert.True(t, err == nil) assert.Nil(t, err)
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls)) assert.Empty(t, ls)
// add static lease // add static lease
l := Lease{} l := Lease{}
l.IP = net.ParseIP("2001::1") l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// try to add static lease - fail // try to add static lease - fail
assert.True(t, s.AddStaticLease(l) != nil) assert.NotNil(t, s.AddStaticLease(l))
// check // check
ls = s.GetLeases(LeasesStatic) ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls)) assert.Len(t, ls, 1)
assert.Equal(t, "2001::1", ls[0].IP.String()) assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
// try to remove static lease - fail // try to remove static lease - fail
l.IP = net.ParseIP("2001::2") l.IP = net.ParseIP("2001::2")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) != nil) assert.NotNil(t, s.RemoveStaticLease(l))
// remove static lease // remove static lease
l.IP = net.ParseIP("2001::1") l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.RemoveStaticLease(l) == nil) assert.Nil(t, s.RemoveStaticLease(l))
// check // check
ls = s.GetLeases(LeasesStatic) ls = s.GetLeases(LeasesStatic)
assert.Equal(t, 0, len(ls)) assert.Empty(t, ls)
} }
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
@ -65,7 +65,7 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
} }
sIface, err := v6Create(conf) sIface, err := v6Create(conf)
s := sIface.(*v6Server) s := sIface.(*v6Server)
assert.True(t, err == nil) assert.Nil(t, err)
// add dynamic lease // add dynamic lease
ld := Lease{} ld := Lease{}
@ -85,25 +85,25 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
l := Lease{} l := Lease{}
l.IP = net.ParseIP("2001::1") l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("33:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// add static lease with the same MAC // add static lease with the same MAC
l = Lease{} l = Lease{}
l.IP = net.ParseIP("2001::3") l.IP = net.ParseIP("2001::3")
l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("22:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// check // check
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 2, len(ls)) assert.Len(t, ls, 2)
assert.Equal(t, "2001::1", ls[0].IP.String()) assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, ls[0].Expiry.Unix() == leaseExpireStatic) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
assert.Equal(t, "2001::3", ls[1].IP.String()) assert.Equal(t, "2001::3", ls[1].IP.String())
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.True(t, ls[1].Expiry.Unix() == leaseExpireStatic) assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix())
} }
func TestV6GetLease(t *testing.T) { func TestV6GetLease(t *testing.T) {
@ -114,7 +114,7 @@ func TestV6GetLease(t *testing.T) {
} }
sIface, err := v6Create(conf) sIface, err := v6Create(conf)
s := sIface.(*v6Server) s := sIface.(*v6Server)
assert.True(t, err == nil) assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")} s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{ s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT, Type: dhcpv6.DUID_LLT,
@ -125,7 +125,7 @@ func TestV6GetLease(t *testing.T) {
l := Lease{} l := Lease{}
l.IP = net.ParseIP("2001::1") l.IP = net.ParseIP("2001::1")
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
assert.True(t, s.AddStaticLease(l) == nil) assert.Nil(t, s.AddStaticLease(l))
// "Solicit" // "Solicit"
mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa")
@ -156,12 +156,12 @@ func TestV6GetLease(t *testing.T) {
assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), oiaAddr.ValidLifetime.Seconds())
dnsAddrs := resp.Options.DNS() dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs)) assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "2000::1", dnsAddrs[0].String()) assert.Equal(t, "2000::1", dnsAddrs[0].String())
// check lease // check lease
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Equal(t, 1, len(ls)) assert.Len(t, ls, 1)
assert.Equal(t, "2001::1", ls[0].IP.String()) assert.Equal(t, "2001::1", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
} }
@ -174,7 +174,7 @@ func TestV6GetDynamicLease(t *testing.T) {
} }
sIface, err := v6Create(conf) sIface, err := v6Create(conf)
s := sIface.(*v6Server) s := sIface.(*v6Server)
assert.True(t, err == nil) assert.Nil(t, err)
s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")} s.conf.dnsIPAddrs = []net.IP{net.ParseIP("2000::1")}
s.sid = dhcpv6.Duid{ s.sid = dhcpv6.Duid{
Type: dhcpv6.DUID_LLT, Type: dhcpv6.DUID_LLT,
@ -209,17 +209,17 @@ func TestV6GetDynamicLease(t *testing.T) {
assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String()) assert.Equal(t, "2001::2", oiaAddr.IPv6Addr.String())
dnsAddrs := resp.Options.DNS() dnsAddrs := resp.Options.DNS()
assert.Equal(t, 1, len(dnsAddrs)) assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "2000::1", dnsAddrs[0].String()) assert.Equal(t, "2000::1", dnsAddrs[0].String())
// check lease // check lease
ls := s.GetLeases(LeasesDynamic) ls := s.GetLeases(LeasesDynamic)
assert.Equal(t, 1, len(ls)) assert.Len(t, ls, 1)
assert.Equal(t, "2001::2", ls[0].IP.String()) assert.Equal(t, "2001::2", ls[0].IP.String())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1"))) assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::1")))
assert.True(t, !ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2"))) assert.False(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2002::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2"))) assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::2")))
assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3"))) assert.True(t, ip6InRange(net.ParseIP("2001::2"), net.ParseIP("2001::3")))
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"net" "net"
"strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/AdGuardHome/internal/testutil"
@ -135,7 +134,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) { if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text) assert.Equal(t, "0.0.0.0 block.com", res.Rules[0].Text)
assert.Len(t, res.Rules[0].IP, 0) assert.Empty(t, res.Rules[0].IP)
} }
// IPv6 // IPv6
@ -147,7 +146,7 @@ func TestEtcHostsMatching(t *testing.T) {
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
if assert.Len(t, res.Rules, 1) { if assert.Len(t, res.Rules, 1) {
assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text) assert.Equal(t, "::1 ipv6.com", res.Rules[0].Text)
assert.Len(t, res.Rules[0].IP, 0) assert.Empty(t, res.Rules[0].IP)
} }
// 2 IPv4 (return only the first one) // 2 IPv4 (return only the first one)
@ -180,7 +179,7 @@ func TestSafeBrowsing(t *testing.T) {
defer d.Close() defer d.Close()
d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru")
assert.True(t, strings.Contains(logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru")) assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for wmconvirus.narod.ru")
d.checkMatch(t, "test.wmconvirus.narod.ru") d.checkMatch(t, "test.wmconvirus.narod.ru")
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
@ -268,7 +267,7 @@ func TestSafeSearchCacheYandex(t *testing.T) {
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Len(t, res.Rules, 0) assert.Empty(t, res.Rules)
d = NewForTest(&Config{SafeSearchEnabled: true}, nil) d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Close() defer d.Close()
@ -298,7 +297,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
res, err := d.CheckHost(domain, dns.TypeA, &setts) res, err := d.CheckHost(domain, dns.TypeA, &setts)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Len(t, res.Rules, 0) assert.Empty(t, res.Rules)
d = NewForTest(&Config{SafeSearchEnabled: true}, nil) d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Close() defer d.Close()
@ -346,7 +345,7 @@ func TestParentalControl(t *testing.T) {
d := NewForTest(&Config{ParentalEnabled: true}, nil) d := NewForTest(&Config{ParentalEnabled: true}, nil)
defer d.Close() defer d.Close()
d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com")
assert.True(t, strings.Contains(logOutput.String(), "Parental lookup for pornhub.com")) assert.Contains(t, logOutput.String(), "Parental lookup for pornhub.com")
d.checkMatch(t, "www.pornhub.com") d.checkMatch(t, "www.pornhub.com")
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
@ -468,18 +467,20 @@ func TestWhitelist(t *testing.T) {
// matched by white filter // matched by white filter
res, err := d.CheckHost("host1", dns.TypeA, &setts) res, err := d.CheckHost("host1", dns.TypeA, &setts)
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, !res.IsFiltered && res.Reason == NotFilteredAllowList) assert.False(t, res.IsFiltered)
assert.Equal(t, res.Reason, NotFilteredAllowList)
if assert.Len(t, res.Rules, 1) { if assert.Len(t, res.Rules, 1) {
assert.True(t, res.Rules[0].Text == "||host1^") assert.Equal(t, "||host1^", res.Rules[0].Text)
} }
// not matched by white filter, but matched by block filter // not matched by white filter, but matched by block filter
res, err = d.CheckHost("host2", dns.TypeA, &setts) res, err = d.CheckHost("host2", dns.TypeA, &setts)
assert.True(t, err == nil) assert.Nil(t, err)
assert.True(t, res.IsFiltered && res.Reason == FilteredBlockList) assert.True(t, res.IsFiltered)
assert.Equal(t, res.Reason, FilteredBlockList)
if assert.Len(t, res.Rules, 1) { if assert.Len(t, res.Rules, 1) {
assert.True(t, res.Rules[0].Text == "||host2^") assert.Equal(t, "||host2^", res.Rules[0].Text)
} }
} }
@ -529,7 +530,7 @@ func TestClientSettings(t *testing.T) {
// not blocked // not blocked
r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts)
assert.True(t, !r.IsFiltered) assert.False(t, r.IsFiltered)
// override client settings: // override client settings:
applyClientSettings(&setts) applyClientSettings(&setts)
@ -554,7 +555,8 @@ func TestClientSettings(t *testing.T) {
// blocked by additional rules // blocked by additional rules
r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts)
assert.True(t, r.IsFiltered && r.Reason == FilteredBlockedService) assert.True(t, r.IsFiltered)
assert.Equal(t, r.Reason, FilteredBlockedService)
} }
// BENCHMARKS // BENCHMARKS

View File

@ -171,7 +171,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
res, err := f.CheckHostRules(host, dtyp, setts) res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "", res.CanonName) assert.Empty(t, res.CanonName)
if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) { if dnsrr := res.DNSRewriteResult; assert.NotNil(t, dnsrr) {
assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode) assert.Equal(t, dns.RcodeSuccess, dnsrr.RCode)
@ -197,7 +197,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
res, err := f.CheckHostRules(host, dtyp, setts) res, err := f.CheckHostRules(host, dtyp, setts)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "", res.CanonName) assert.Empty(t, res.CanonName)
assert.Len(t, res.Rules, 0) assert.Empty(t, res.Rules)
}) })
} }

View File

@ -27,14 +27,14 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.Equal(t, 2, len(r.IPList)) assert.Len(t, r.IPList, 2)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
assert.True(t, r.IPList[1].Equal(net.ParseIP("1.2.3.5"))) assert.True(t, r.IPList[1].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host.com", dns.TypeAAAA) r = d.processRewrites("www.host.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4"))) assert.True(t, r.IPList[0].Equal(net.ParseIP("1:2:3::4")))
// wildcard // wildcard
@ -45,11 +45,11 @@ func TestRewrites(t *testing.T) {
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("host.com", dns.TypeA) r = d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.5"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 5}))
r = d.processRewrites("www.host2.com", dns.TypeA) r = d.processRewrites("www.host2.com", dns.TypeA)
assert.Equal(t, NotFilteredNotFound, r.Reason) assert.Equal(t, NotFilteredNotFound, r.Reason)
@ -62,8 +62,8 @@ func TestRewrites(t *testing.T) {
d.prepareRewrites() d.prepareRewrites()
r = d.processRewrites("a.host.com", dns.TypeA) r = d.processRewrites("a.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.True(t, len(r.IPList) == 1) assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// wildcard + CNAME // wildcard + CNAME
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
@ -74,7 +74,7 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("www.host.com", dns.TypeA) r = d.processRewrites("www.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs // 2 CNAMEs
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
@ -86,8 +86,8 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("b.host.com", dns.TypeA) r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "host.com", r.CanonName) assert.Equal(t, "host.com", r.CanonName)
assert.True(t, len(r.IPList) == 1) assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
// 2 CNAMEs + wildcard // 2 CNAMEs + wildcard
d.Rewrites = []RewriteEntry{ d.Rewrites = []RewriteEntry{
@ -99,8 +99,8 @@ func TestRewrites(t *testing.T) {
r = d.processRewrites("b.host.com", dns.TypeA) r = d.processRewrites("b.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, "x.somehost.com", r.CanonName) assert.Equal(t, "x.somehost.com", r.CanonName)
assert.True(t, len(r.IPList) == 1) assert.Len(t, r.IPList, 1)
assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) assert.True(t, r.IPList[0].Equal(net.IP{1, 2, 3, 4}))
} }
func TestRewritesLevels(t *testing.T) { func TestRewritesLevels(t *testing.T) {
@ -116,19 +116,19 @@ func TestRewritesLevels(t *testing.T) {
// match exact // match exact
r := d.processRewrites("host.com", dns.TypeA) r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.1.1.1", r.IPList[0].String()) assert.Equal(t, "1.1.1.1", r.IPList[0].String())
// match L2 // match L2
r = d.processRewrites("sub.host.com", dns.TypeA) r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String()) assert.Equal(t, "2.2.2.2", r.IPList[0].String())
// match L3 // match L3
r = d.processRewrites("my.sub.host.com", dns.TypeA) r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "3.3.3.3", r.IPList[0].String()) assert.Equal(t, "3.3.3.3", r.IPList[0].String())
} }
@ -144,7 +144,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
// match sub-domain // match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA) r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String()) assert.Equal(t, "2.2.2.2", r.IPList[0].String())
// match sub-domain, but handle exception // match sub-domain, but handle exception
@ -164,7 +164,7 @@ func TestRewritesExceptionWC(t *testing.T) {
// match sub-domain // match sub-domain
r := d.processRewrites("my.host.com", dns.TypeA) r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String()) assert.Equal(t, "2.2.2.2", r.IPList[0].String())
// match sub-domain, but handle exception // match sub-domain, but handle exception
@ -187,7 +187,7 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain // match domain
r := d.processRewrites("host.com", dns.TypeA) r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.2.3.4", r.IPList[0].String()) assert.Equal(t, "1.2.3.4", r.IPList[0].String())
// match exception // match exception
@ -201,7 +201,7 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain // match domain
r = d.processRewrites("host2.com", dns.TypeAAAA) r = d.processRewrites("host2.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 1, len(r.IPList)) assert.Len(t, r.IPList, 1)
assert.Equal(t, "::1", r.IPList[0].String()) assert.Equal(t, "::1", r.IPList[0].String())
// match exception // match exception
@ -211,5 +211,5 @@ func TestRewritesExceptionIP(t *testing.T) {
// match domain // match domain
r = d.processRewrites("host3.com", dns.TypeAAAA) r = d.processRewrites("host3.com", dns.TypeAAAA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Equal(t, 0, len(r.IPList)) assert.Empty(t, r.IPList)
} }

View File

@ -37,8 +37,8 @@ func (d *DNSFilter) initSecurityServices() error {
opts := upstream.Options{ opts := upstream.Options{
Timeout: dnsTimeout, Timeout: dnsTimeout,
ServerIPAddrs: []net.IP{ ServerIPAddrs: []net.IP{
net.ParseIP("94.140.14.15"), {94, 140, 14, 15},
net.ParseIP("94.140.15.16"), {94, 140, 15, 16},
net.ParseIP("2a10:50c0::bad1:ff"), net.ParseIP("2a10:50c0::bad1:ff"),
net.ParseIP("2a10:50c0::bad2:ff"), net.ParseIP("2a10:50c0::bad2:ff"),
}, },

View File

@ -14,7 +14,7 @@ import (
func TestSafeBrowsingHash(t *testing.T) { func TestSafeBrowsingHash(t *testing.T) {
// test hostnameToHashes() // test hostnameToHashes()
hashes := hostnameToHashes("1.2.3.sub.host.com") hashes := hostnameToHashes("1.2.3.sub.host.com")
assert.Equal(t, 3, len(hashes)) assert.Len(t, hashes, 3)
_, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))] _, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))]
assert.True(t, ok) assert.True(t, ok)
_, ok = hashes[sha256.Sum256([]byte("sub.host.com"))] _, ok = hashes[sha256.Sum256([]byte("sub.host.com"))]
@ -31,9 +31,9 @@ func TestSafeBrowsingHash(t *testing.T) {
q := c.getQuestion() q := c.getQuestion()
assert.True(t, strings.Contains(q, "7a1b.")) assert.Contains(t, q, "7a1b.")
assert.True(t, strings.Contains(q, "af5a.")) assert.Contains(t, q, "af5a.")
assert.True(t, strings.Contains(q, "eb11.")) assert.Contains(t, q, "eb11.")
assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com.")) assert.True(t, strings.HasSuffix(q, "sb.dns.adguard.com."))
} }
@ -81,7 +81,7 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com" c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com")) hash = sha256.Sum256([]byte("nonexisting.com"))
c.hashToHost[hash] = "nonexisting.com" c.hashToHost[hash] = "nonexisting.com"
assert.Equal(t, 0, c.getCached()) assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com")) hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash] _, ok := c.hashToHost[hash]
@ -103,7 +103,7 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com" c.hashToHost[hash] = "sub.host.com"
c.cache.Set(hash[0:2], make([]byte, 32)) c.cache.Set(hash[0:2], make([]byte, 32))
assert.Equal(t, 0, c.getCached()) assert.Empty(t, c.getCached())
} }
// testErrUpstream implements upstream.Upstream interface for replacing real // testErrUpstream implements upstream.Upstream interface for replacing real

View File

@ -8,28 +8,28 @@ import (
func TestIsBlockedIPAllowed(t *testing.T) { func TestIsBlockedIPAllowed(t *testing.T) {
a := &accessCtx{} a := &accessCtx{}
assert.True(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil) == nil) assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
assert.True(t, disallowed) assert.True(t, disallowed)
assert.Equal(t, "", disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
assert.True(t, disallowed) assert.True(t, disallowed)
assert.Equal(t, "", disallowedRule) assert.Empty(t, disallowedRule)
} }
func TestIsBlockedIPDisallowed(t *testing.T) { func TestIsBlockedIPDisallowed(t *testing.T) {
a := &accessCtx{} a := &accessCtx{}
assert.True(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil) == nil) assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1")
assert.True(t, disallowed) assert.True(t, disallowed)
@ -37,7 +37,7 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2")
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1")
assert.True(t, disallowed) assert.True(t, disallowed)
@ -45,7 +45,7 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1")
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Equal(t, "", disallowedRule) assert.Empty(t, disallowedRule)
} }
func TestIsBlockedIPBlockedDomain(t *testing.T) { func TestIsBlockedIPBlockedDomain(t *testing.T) {
@ -60,13 +60,13 @@ func TestIsBlockedIPBlockedDomain(t *testing.T) {
// match by "host2.com" // match by "host2.com"
assert.True(t, a.IsBlockedDomain("host1")) assert.True(t, a.IsBlockedDomain("host1"))
assert.True(t, a.IsBlockedDomain("host2")) assert.True(t, a.IsBlockedDomain("host2"))
assert.True(t, !a.IsBlockedDomain("host3")) assert.False(t, a.IsBlockedDomain("host3"))
// match by wildcard "*.host.com" // match by wildcard "*.host.com"
assert.True(t, !a.IsBlockedDomain("host.com")) assert.False(t, a.IsBlockedDomain("host.com"))
assert.True(t, a.IsBlockedDomain("asdf.host.com")) assert.True(t, a.IsBlockedDomain("asdf.host.com"))
assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com")) assert.True(t, a.IsBlockedDomain("qwer.asdf.host.com"))
assert.True(t, !a.IsBlockedDomain("asdf.zhost.com")) assert.False(t, a.IsBlockedDomain("asdf.zhost.com"))
// match by wildcard "||host3.com^" // match by wildcard "||host3.com^"
assert.True(t, a.IsBlockedDomain("host3.com")) assert.True(t, a.IsBlockedDomain("host3.com"))

View File

@ -29,17 +29,16 @@ type FilteringConfig struct {
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration // GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client // based on the client IP address. Returns nil if there are no custom upstreams for the client
// TODO(e.burkov): replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration // Protection configuration
// -- // --
ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features
BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests
BlockingIPv4 string `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request BlockingIPv4 net.IP `yaml:"blocking_ipv4"` // IP address to be returned for a blocked A request
BlockingIPv6 string `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request BlockingIPv6 net.IP `yaml:"blocking_ipv6"` // IP address to be returned for a blocked AAAA request
BlockingIPAddrv4 net.IP `yaml:"-"`
BlockingIPAddrv6 net.IP `yaml:"-"`
BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600)
// IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing // IP (or domain name) which is used to respond to DNS requests blocked by parental control or safe-browsing

View File

@ -182,7 +182,7 @@ func processInternalHosts(ctx *dnsContext) int {
return resultDone return resultDone
} }
log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip.String()) log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip)
resp := s.makeResponse(req) resp := s.makeResponse(req)
@ -278,7 +278,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) int {
return resultDone return resultDone
} }
// Pass request to upstream servers; process the response // processUpstream passes request to upstream servers and handles the response.
func processUpstream(ctx *dnsContext) int { func processUpstream(ctx *dnsContext) int {
s := ctx.srv s := ctx.srv
d := ctx.proxyCtx d := ctx.proxyCtx
@ -287,7 +287,7 @@ func processUpstream(ctx *dnsContext) int {
} }
if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil {
clientIP := ipFromAddr(d.Addr) clientIP := IPStringFromAddr(d.Addr)
upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP) upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP)
if upstreamsConf != nil { if upstreamsConf != nil {
log.Debug("Using custom upstreams for %s", clientIP) log.Debug("Using custom upstreams for %s", clientIP)

View File

@ -178,9 +178,7 @@ func (s *Server) Prepare(config *ServerConfig) error {
if config != nil { if config != nil {
s.conf = *config s.conf = *config
if s.conf.BlockingMode == "custom_ip" { if s.conf.BlockingMode == "custom_ip" {
s.conf.BlockingIPAddrv4 = net.ParseIP(s.conf.BlockingIPv4) if s.conf.BlockingIPv4 == nil || s.conf.BlockingIPv6 == nil {
s.conf.BlockingIPAddrv6 = net.ParseIP(s.conf.BlockingIPv6)
if s.conf.BlockingIPAddrv4 == nil || s.conf.BlockingIPAddrv6 == nil {
return fmt.Errorf("dns: invalid custom blocking IP address specified") return fmt.Errorf("dns: invalid custom blocking IP address specified")
} }
} }

View File

@ -286,7 +286,7 @@ func TestBlockedRequest(t *testing.T) {
t.Fatalf("Couldn't talk to server %s: %s", addr, err) t.Fatalf("Couldn't talk to server %s: %s", addr, err)
} }
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0"))) assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0}))
err = s.Stop() err = s.Stop()
if err != nil { if err != nil {
@ -300,7 +300,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
uc := &proxy.UpstreamConfig{} uc := &proxy.UpstreamConfig{}
u := &testUpstream{} u := &testUpstream{}
u.ipv4 = map[string][]net.IP{} u.ipv4 = map[string][]net.IP{}
u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")} u.ipv4["host."] = []net.IP{{192, 168, 0, 1}}
uc.Upstreams = append(uc.Upstreams, u) uc.Upstreams = append(uc.Upstreams, u)
return uc return uc
} }
@ -425,7 +425,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.ProtectionEnabled = false s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm) err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil) assert.Nil(t, err)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters: // 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
@ -440,16 +440,16 @@ func TestBlockCNAME(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
err := s.startWithUpstream(testUpstm) err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil) assert.Nil(t, err)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters: // 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
// response is blocked // response is blocked
req := createTestMessage("badhost.") req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err, nil) assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0"))) assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0}))
// 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters
// but 'whitelist.example.org' is in a whitelist: // but 'whitelist.example.org' is in a whitelist:
@ -465,7 +465,7 @@ func TestBlockCNAME(t *testing.T) {
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.ParseIP("0.0.0.0"))) assert.True(t, reply.Answer[0].(*dns.A).A.Equal(net.IP{0, 0, 0, 0}))
_ = s.Stop() _ = s.Stop()
} }
@ -548,13 +548,13 @@ func TestBlockedCustomIP(t *testing.T) {
conf.TCPListenAddr = &net.TCPAddr{Port: 0} conf.TCPListenAddr = &net.TCPAddr{Port: 0}
conf.ProtectionEnabled = true conf.ProtectionEnabled = true
conf.BlockingMode = "custom_ip" conf.BlockingMode = "custom_ip"
conf.BlockingIPv4 = "bad IP" conf.BlockingIPv4 = nil
conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
err := s.Prepare(&conf) err := s.Prepare(&conf)
assert.True(t, err != nil) // invalid BlockingIPv4 assert.NotNil(t, err) // invalid BlockingIPv4
conf.BlockingIPv4 = "0.0.0.1" conf.BlockingIPv4 = net.IP{0, 0, 0, 1}
conf.BlockingIPv6 = "::1" conf.BlockingIPv6 = net.ParseIP("::1")
err = s.Prepare(&conf) err = s.Prepare(&conf)
assert.Nil(t, err) assert.Nil(t, err)
err = s.Start() err = s.Start()
@ -565,7 +565,7 @@ func TestBlockedCustomIP(t *testing.T) {
req := createTestMessageWithType("null.example.org.", dns.TypeA) req := createTestMessageWithType("null.example.org.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(reply.Answer)) assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "0.0.0.1", a.A.String()) assert.Equal(t, "0.0.0.1", a.A.String())
@ -573,7 +573,7 @@ func TestBlockedCustomIP(t *testing.T) {
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(reply.Answer)) assert.Len(t, reply.Answer, 1)
a6, ok := reply.Answer[0].(*dns.AAAA) a6, ok := reply.Answer[0].(*dns.AAAA)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "::1", a6.AAAA.String()) assert.Equal(t, "::1", a6.AAAA.String())
@ -710,7 +710,7 @@ func TestRewrite(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA) req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(reply.Answer)) assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "1.2.3.4", a.A.String()) assert.Equal(t, "1.2.3.4", a.A.String())
@ -718,12 +718,12 @@ func TestRewrite(t *testing.T) {
req = createTestMessageWithType("test.com.", dns.TypeAAAA) req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, len(reply.Answer)) assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA) req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, len(reply.Answer)) assert.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String()) assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String())
@ -731,7 +731,7 @@ func TestRewrite(t *testing.T) {
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored
assert.Equal(t, 2, len(reply.Answer)) assert.Len(t, reply.Answer, 2)
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
@ -765,7 +765,7 @@ func createTestServer(t *testing.T) *Server {
s.conf.ConfigModified = func() {} s.conf.ConfigModified = func() {}
err := s.Prepare(nil) err := s.Prepare(nil)
assert.True(t, err == nil) assert.Nil(t, err)
return s return s
} }
@ -1011,16 +1011,14 @@ func TestValidateUpstreamsSet(t *testing.T) {
assert.NotNil(t, err, "there is an invalid upstream in set, but it pass through validation") assert.NotNil(t, err, "there is an invalid upstream in set, but it pass through validation")
} }
func TestIpFromAddr(t *testing.T) { func TestIPStringFromAddr(t *testing.T) {
addr := net.UDPAddr{} addr := net.UDPAddr{}
addr.IP = net.ParseIP("1:2:3::4") addr.IP = net.ParseIP("1:2:3::4")
addr.Port = 12345 addr.Port = 12345
addr.Zone = "eth0" addr.Zone = "eth0"
a := ipFromAddr(&addr) assert.Equal(t, IPStringFromAddr(&addr), net.ParseIP("1:2:3::4").String())
assert.True(t, a == "1:2:3::4")
a = ipFromAddr(nil) assert.Empty(t, IPStringFromAddr(nil))
assert.True(t, a == "")
} }
func TestMatchDNSName(t *testing.T) { func TestMatchDNSName(t *testing.T) {
@ -1030,9 +1028,9 @@ func TestMatchDNSName(t *testing.T) {
assert.True(t, matchDNSName(dnsNames, "a.host2")) assert.True(t, matchDNSName(dnsNames, "a.host2"))
assert.True(t, matchDNSName(dnsNames, "b.a.host2")) assert.True(t, matchDNSName(dnsNames, "b.a.host2"))
assert.True(t, matchDNSName(dnsNames, "1.2.3.4")) assert.True(t, matchDNSName(dnsNames, "1.2.3.4"))
assert.True(t, !matchDNSName(dnsNames, "host2")) assert.False(t, matchDNSName(dnsNames, "host2"))
assert.True(t, !matchDNSName(dnsNames, "")) assert.False(t, matchDNSName(dnsNames, ""))
assert.True(t, !matchDNSName(dnsNames, "*.host2")) assert.False(t, matchDNSName(dnsNames, "*.host2"))
} }
type testDHCP struct { type testDHCP struct {
@ -1040,7 +1038,7 @@ type testDHCP struct {
func (d *testDHCP) Leases(flags int) []dhcpd.Lease { func (d *testDHCP) Leases(flags int) []dhcpd.Lease {
l := dhcpd.Lease{} l := dhcpd.Lease{}
l.IP = net.ParseIP("127.0.0.1").To4() l.IP = net.IP{127, 0, 0, 1}
l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa")
l.Hostname = "localhost" l.Hostname = "localhost"
return []dhcpd.Lease{l} return []dhcpd.Lease{l}
@ -1058,7 +1056,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil) err := s.Prepare(nil)
assert.True(t, err == nil) assert.Nil(t, err)
assert.Nil(t, s.Start()) assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1067,7 +1065,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
resp, err := dns.Exchange(req, addr.String()) resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(resp.Answer)) assert.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr := resp.Answer[0].(*dns.PTR) ptr := resp.Answer[0].(*dns.PTR)
@ -1100,7 +1098,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil) err := s.Prepare(nil)
assert.True(t, err == nil) assert.Nil(t, err)
assert.Nil(t, s.Start()) assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1109,7 +1107,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
resp, err := dns.Exchange(req, addr.String()) resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, len(resp.Answer)) assert.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr := resp.Answer[0].(*dns.PTR) ptr := resp.Answer[0].(*dns.PTR)

View File

@ -12,7 +12,7 @@ import (
) )
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := ipFromAddr(d.Addr) ip := IPStringFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip) disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed { if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip) log.Tracef("Client IP %s is blocked by settings", ip)
@ -36,7 +36,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
setts := s.dnsFilter.GetConfig() setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
clientAddr := ipFromAddr(d.Addr) clientAddr := IPStringFromAddr(d.Addr)
s.conf.FilterHandler(clientAddr, &setts) s.conf.FilterHandler(clientAddr, &setts)
} }
return &setts return &setts

View File

@ -28,8 +28,8 @@ type dnsConfig struct {
ProtectionEnabled *bool `json:"protection_enabled"` ProtectionEnabled *bool `json:"protection_enabled"`
RateLimit *uint32 `json:"ratelimit"` RateLimit *uint32 `json:"ratelimit"`
BlockingMode *string `json:"blocking_mode"` BlockingMode *string `json:"blocking_mode"`
BlockingIPv4 *string `json:"blocking_ipv4"` BlockingIPv4 net.IP `json:"blocking_ipv4"`
BlockingIPv6 *string `json:"blocking_ipv6"` BlockingIPv6 net.IP `json:"blocking_ipv6"`
EDNSCSEnabled *bool `json:"edns_cs_enabled"` EDNSCSEnabled *bool `json:"edns_cs_enabled"`
DNSSECEnabled *bool `json:"dnssec_enabled"` DNSSECEnabled *bool `json:"dnssec_enabled"`
DisableIPv6 *bool `json:"disable_ipv6"` DisableIPv6 *bool `json:"disable_ipv6"`
@ -68,8 +68,8 @@ func (s *Server) getDNSConfig() dnsConfig {
Bootstraps: &bootstraps, Bootstraps: &bootstraps,
ProtectionEnabled: &protectionEnabled, ProtectionEnabled: &protectionEnabled,
BlockingMode: &blockingMode, BlockingMode: &blockingMode,
BlockingIPv4: &BlockingIPv4, BlockingIPv4: BlockingIPv4,
BlockingIPv6: &BlockingIPv6, BlockingIPv6: BlockingIPv6,
RateLimit: &Ratelimit, RateLimit: &Ratelimit,
EDNSCSEnabled: &EnableEDNSClientSubnet, EDNSCSEnabled: &EnableEDNSClientSubnet,
DNSSECEnabled: &EnableDNSSEC, DNSSECEnabled: &EnableDNSSEC,
@ -100,17 +100,11 @@ func (req *dnsConfig) checkBlockingMode() bool {
bm := *req.BlockingMode bm := *req.BlockingMode
if bm == "custom_ip" { if bm == "custom_ip" {
if req.BlockingIPv4 == nil || req.BlockingIPv6 == nil { if req.BlockingIPv4.To4() == nil {
return false return false
} }
ip4 := net.ParseIP(*req.BlockingIPv4) return req.BlockingIPv6 != nil
if ip4 == nil || ip4.To4() == nil {
return false
}
ip6 := net.ParseIP(*req.BlockingIPv6)
return ip6 != nil
} }
for _, valid := range []string{ for _, valid := range []string{
@ -247,10 +241,8 @@ func (s *Server) setConfig(dc dnsConfig) (restart bool) {
if dc.BlockingMode != nil { if dc.BlockingMode != nil {
s.conf.BlockingMode = *dc.BlockingMode s.conf.BlockingMode = *dc.BlockingMode
if *dc.BlockingMode == "custom_ip" { if *dc.BlockingMode == "custom_ip" {
s.conf.BlockingIPv4 = *dc.BlockingIPv4 s.conf.BlockingIPv4 = dc.BlockingIPv4.To4()
s.conf.BlockingIPAddrv4 = net.ParseIP(*dc.BlockingIPv4) s.conf.BlockingIPv6 = dc.BlockingIPv6.To16()
s.conf.BlockingIPv6 = *dc.BlockingIPv6
s.conf.BlockingIPAddrv6 = net.ParseIP(*dc.BlockingIPv6)
} }
} }

View File

@ -60,9 +60,9 @@ func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Resu
switch m.Question[0].Qtype { switch m.Question[0].Qtype {
case dns.TypeA: case dns.TypeA:
return s.genARecord(m, s.conf.BlockingIPAddrv4) return s.genARecord(m, s.conf.BlockingIPv4)
case dns.TypeAAAA: case dns.TypeAAAA:
return s.genAAAARecord(m, s.conf.BlockingIPAddrv6) return s.genAAAARecord(m, s.conf.BlockingIPv6)
} }
} else if s.conf.BlockingMode == "nxdomain" { } else if s.conf.BlockingMode == "nxdomain" {
// means that we should return NXDOMAIN for any blocked request // means that we should return NXDOMAIN for any blocked request

View File

@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
OrigAnswer: ctx.origResp, OrigAnswer: ctx.origResp,
Result: ctx.result, Result: ctx.result,
Elapsed: elapsed, Elapsed: elapsed,
ClientIP: getIP(d.Addr), ClientIP: ipFromAddr(d.Addr),
} }
switch d.Proto { switch d.Proto {

View File

@ -8,38 +8,8 @@ import (
"github.com/AdguardTeam/golibs/utils" "github.com/AdguardTeam/golibs/utils"
) )
// GetIPString is a helper function that extracts IP address from net.Addr // ipFromAddr gets IP address from addr.
func GetIPString(addr net.Addr) string { func ipFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Get IP address from net.Addr object
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func ipFromAddr(a net.Addr) string {
switch addr := a.(type) {
case *net.UDPAddr:
return addr.IP.String()
case *net.TCPAddr:
return addr.IP.String()
}
return ""
}
// Get IP address from net.Addr
func getIP(addr net.Addr) net.IP {
switch addr := addr.(type) { switch addr := addr.(type) {
case *net.UDPAddr: case *net.UDPAddr:
return addr.IP return addr.IP
@ -49,6 +19,23 @@ func getIP(addr net.Addr) net.IP {
return nil return nil
} }
// IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipstr string) {
if ip := ipFromAddr(addr); ip != nil {
return ip.String()
}
return ""
}
func stringArrayDup(a []string) []string {
a2 := make([]string, len(a))
copy(a2, a)
return a2
}
// Find value in a sorted array // Find value in a sorted array
func findSorted(ar []string, val string) int { func findSorted(ar []string, val string) int {
i := sort.SearchStrings(ar, val) i := sort.SearchStrings(ar, val)

View File

@ -70,7 +70,7 @@ func TestAuth(t *testing.T) {
a.Close() a.Close()
u := a.UserFind("name", "password") u := a.UserFind("name", "password")
assert.True(t, len(u.Name) != 0) assert.NotEmpty(t, u.Name)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
@ -125,9 +125,9 @@ func TestAuthHTTP(t *testing.T) {
r.URL = &url.URL{Path: "/"} r.URL = &url.URL{Path: "/"}
handlerCalled = false handlerCalled = false
handler2(&w, &r) handler2(&w, &r)
assert.True(t, w.statusCode == http.StatusFound) assert.Equal(t, http.StatusFound, w.statusCode)
assert.True(t, w.hdr.Get("Location") != "") assert.NotEmpty(t, w.hdr.Get("Location"))
assert.True(t, !handlerCalled) assert.False(t, handlerCalled)
// go to login page // go to login page
loginURL := w.hdr.Get("Location") loginURL := w.hdr.Get("Location")
@ -139,7 +139,7 @@ func TestAuthHTTP(t *testing.T) {
// perform login // perform login
cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}) cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"})
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, cookie != "") assert.NotEmpty(t, cookie)
// get / // get /
handler2 = optionalAuth(handler) handler2 = optionalAuth(handler)
@ -168,8 +168,8 @@ func TestAuthHTTP(t *testing.T) {
r.URL = &url.URL{Path: loginURL} r.URL = &url.URL{Path: loginURL}
handlerCalled = false handlerCalled = false
handler2(&w, &r) handler2(&w, &r)
assert.True(t, w.hdr.Get("Location") != "") assert.NotEmpty(t, w.hdr.Get("Location"))
assert.True(t, !handlerCalled) assert.False(t, handlerCalled)
r.Header.Del("Cookie") r.Header.Del("Cookie")
// get login page with an invalid cookie // get login page with an invalid cookie

View File

@ -37,15 +37,18 @@ func TestClients(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
c, b = clients.Find("1.1.1.1") c, b = clients.Find("1.1.1.1")
assert.True(t, b && c.Name == "client1") assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, b = clients.Find("1:2:3::4") c, b = clients.Find("1:2:3::4")
assert.True(t, b && c.Name == "client1") assert.True(t, b)
assert.Equal(t, c.Name, "client1")
c, b = clients.Find("2.2.2.2") c, b = clients.Find("2.2.2.2")
assert.True(t, b && c.Name == "client2") assert.True(t, b)
assert.Equal(t, c.Name, "client2")
assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile)) assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
}) })
@ -109,7 +112,7 @@ func TestClients(t *testing.T) {
err := clients.Update("client1", c) err := clients.Update("client1", c)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
c = Client{ c = Client{
@ -123,8 +126,8 @@ func TestClients(t *testing.T) {
c, b := clients.Find("1.1.1.2") c, b := clients.Find("1.1.1.2")
assert.True(t, b) assert.True(t, b)
assert.True(t, c.Name == "client1-renamed") assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.IDs[0] == "1.1.1.2") assert.Equal(t, "1.1.1.2", c.IDs[0])
assert.True(t, c.UseOwnSettings) assert.True(t, c.UseOwnSettings)
assert.Nil(t, clients.list["client1"]) assert.Nil(t, clients.list["client1"])
}) })
@ -172,12 +175,12 @@ func TestClientsWhois(t *testing.T) {
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client // set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois) clients.SetWhoisInfo("1.1.1.255", whois)
assert.True(t, clients.ipHost["1.1.1.255"].WhoisInfo[0][1] == "orgname-val") assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1])
// set whois info on existing auto-client // set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois) clients.SetWhoisInfo("1.1.1.1", whois)
assert.True(t, clients.ipHost["1.1.1.1"].WhoisInfo[0][1] == "orgname-val") assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1])
// Check that we cannot set whois info on a manually-added client // Check that we cannot set whois info on a manually-added client
c = Client{ c = Client{
@ -186,7 +189,7 @@ func TestClientsWhois(t *testing.T) {
} }
_, _ = clients.Add(c) _, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois) clients.SetWhoisInfo("1.1.1.2", whois)
assert.True(t, clients.ipHost["1.1.1.2"] == nil) assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1") _ = clients.Del("client1")
} }
@ -272,6 +275,6 @@ func TestClientsCustomUpstream(t *testing.T) {
config = clients.FindUpstreams("1.1.1.1") config = clients.FindUpstreams("1.1.1.1")
assert.NotNil(t, config) assert.NotNil(t, config)
assert.Equal(t, 1, len(config.Upstreams)) assert.Len(t, config.Upstreams, 1)
assert.Equal(t, 1, len(config.DomainReservedUpstreams)) assert.Len(t, config.DomainReservedUpstreams, 1)
} }

View File

@ -98,7 +98,7 @@ func isRunning() bool {
} }
func onDNSRequest(d *proxy.DNSContext) { func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.GetIPString(d.Addr) ip := dnsforward.IPStringFromAddr(d.Addr)
if ip == "" { if ip == "" {
// This would be quite weird if we get here // This would be quite weird if we get here
return return

View File

@ -50,16 +50,17 @@ func TestFilters(t *testing.T) {
// download // download
ok, err := Context.filters.update(&f) ok, err := Context.filters.update(&f)
assert.Equal(t, nil, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, 3, f.RulesCount) assert.Equal(t, 3, f.RulesCount)
// refresh // refresh
ok, err = Context.filters.update(&f) ok, err = Context.filters.update(&f)
assert.True(t, !ok && err == nil) assert.False(t, ok)
assert.Nil(t, err)
err = Context.filters.load(&f) err = Context.filters.load(&f)
assert.True(t, err == nil) assert.Nil(t, err)
f.unload() f.unload()
_ = os.Remove(f.Path()) _ = os.Remove(f.Path())

View File

@ -119,7 +119,7 @@ func TestHome(t *testing.T) {
fn := filepath.Join(dir, "AdGuardHome.yaml") fn := filepath.Join(dir, "AdGuardHome.yaml")
// Prepare the test config // Prepare the test config
assert.True(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644) == nil) assert.Nil(t, ioutil.WriteFile(fn, []byte(yamlConf), 0o644))
fn, _ = filepath.Abs(fn) fn, _ = filepath.Abs(fn)
config = configuration{} // the global variable is dirty because of the previous tests run config = configuration{} // the global variable is dirty because of the previous tests run
@ -138,11 +138,11 @@ func TestHome(t *testing.T) {
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
assert.Truef(t, err == nil, "%s", err) assert.Nilf(t, err, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
resp, err = h.Get("http://127.0.0.1:3000/control/status") resp, err = h.Get("http://127.0.0.1:3000/control/status")
assert.Truef(t, err == nil, "%s", err) assert.Nilf(t, err, "%s", err)
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
// test DNS over UDP // test DNS over UDP
@ -159,16 +159,16 @@ func TestHome(t *testing.T) {
req.RecursionDesired = true req.RecursionDesired = true
req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}} req.Question = []dns.Question{{Name: "static.adguard.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}
buf, err := req.Pack() buf, err := req.Pack()
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf) requestURL := "http://127.0.0.1:3000/dns-query?dns=" + base64.RawURLEncoding.EncodeToString(buf)
resp, err = http.DefaultClient.Get(requestURL) resp, err = http.DefaultClient.Get(requestURL)
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
assert.True(t, resp.StatusCode == http.StatusOK) assert.Equal(t, http.StatusOK, resp.StatusCode)
response := dns.Msg{} response := dns.Msg{}
err = response.Unpack(body) err = response.Unpack(body)
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
addrs = nil addrs = nil
proxyutil.AppendIPAddrs(&addrs, response.Answer) proxyutil.AppendIPAddrs(&addrs, response.Answer)
haveIP = len(addrs) != 0 haveIP = len(addrs) != 0

View File

@ -23,7 +23,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) { if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@ -51,7 +51,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) { if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@ -89,7 +89,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) { if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)
@ -116,7 +116,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
_, err = plist.Unmarshal(w.Body.Bytes(), &mc) _, err = plist.Unmarshal(w.Body.Bytes(), &mc)
assert.Nil(t, err) assert.Nil(t, err)
if assert.Equal(t, 1, len(mc.PayloadContent)) { if assert.Len(t, mc.PayloadContent, 1) {
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name)
assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName)
assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName)

View File

@ -12,10 +12,10 @@ func TestResolveRDNS(t *testing.T) {
conf := &dnsforward.ServerConfig{} conf := &dnsforward.ServerConfig{}
conf.UpstreamDNS = []string{"8.8.8.8"} conf.UpstreamDNS = []string{"8.8.8.8"}
err := dns.Prepare(conf) err := dns.Prepare(conf)
assert.True(t, err == nil, "%s", err) assert.Nil(t, err)
clients := &clientsContainer{} clients := &clientsContainer{}
rdns := InitRDNS(dns, clients) rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1") r := rdns.resolve("1.1.1.1")
assert.True(t, r == "one.one.one.one", "%s", r) assert.Equal(t, "one.one.one.one", r, r)
} }

View File

@ -84,7 +84,7 @@ func TestDecodeLogEntry(t *testing.T) {
decodeLogEntry(got, data) decodeLogEntry(got, data)
s := logOutput.String() s := logOutput.String()
assert.Equal(t, "", s) assert.Empty(t, s)
// Correct for time zones. // Correct for time zones.
got.Time = got.Time.UTC() got.Time = got.Time.UTC()
@ -172,7 +172,7 @@ func TestDecodeLogEntry(t *testing.T) {
s := logOutput.String() s := logOutput.String()
if tc.want == "" { if tc.want == "" {
assert.Equal(t, "", s) assert.Empty(t, s)
} else { } else {
assert.True(t, strings.HasSuffix(s, tc.want), assert.True(t, strings.HasSuffix(s, tc.want),
"got %q", s) "got %q", s)

View File

@ -56,7 +56,7 @@ func TestQueryLog(t *testing.T) {
// get all entries // get all entries
params := newSearchParams() params := newSearchParams()
entries, _ := l.search(params) entries, _ := l.search(params)
assert.Equal(t, 4, len(entries)) assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
@ -70,7 +70,7 @@ func TestQueryLog(t *testing.T) {
value: "TEST.example.org", value: "TEST.example.org",
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 1, len(entries)) assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
// search by domain (not strict) // search by domain (not strict)
@ -81,7 +81,7 @@ func TestQueryLog(t *testing.T) {
value: "example.ORG", value: "example.ORG",
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 3, len(entries)) assert.Len(t, entries, 3)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2")
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1") assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1")
@ -94,7 +94,7 @@ func TestQueryLog(t *testing.T) {
value: "2.2.2.2", value: "2.2.2.2",
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 1, len(entries)) assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2")
// search by client IP (part of) // search by client IP (part of)
@ -105,7 +105,7 @@ func TestQueryLog(t *testing.T) {
value: "2.2.2", value: "2.2.2",
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 4, len(entries)) assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4")
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3")
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2")
@ -138,7 +138,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 0 params.offset = 0
params.limit = 10 params.limit = 10
entries, _ := l.search(params) entries, _ := l.search(params)
assert.Equal(t, 10, len(entries)) assert.Len(t, entries, 10)
assert.Equal(t, entries[0].QHost, "first.example.org") assert.Equal(t, entries[0].QHost, "first.example.org")
assert.Equal(t, entries[9].QHost, "first.example.org") assert.Equal(t, entries[9].QHost, "first.example.org")
@ -146,7 +146,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 10 params.offset = 10
params.limit = 10 params.limit = 10
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 10, len(entries)) assert.Len(t, entries, 10)
assert.Equal(t, entries[0].QHost, "second.example.org") assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[9].QHost, "second.example.org") assert.Equal(t, entries[9].QHost, "second.example.org")
@ -154,7 +154,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 15 params.offset = 15
params.limit = 10 params.limit = 10
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 5, len(entries)) assert.Len(t, entries, 5)
assert.Equal(t, entries[0].QHost, "second.example.org") assert.Equal(t, entries[0].QHost, "second.example.org")
assert.Equal(t, entries[4].QHost, "second.example.org") assert.Equal(t, entries[4].QHost, "second.example.org")
@ -162,7 +162,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = 20 params.offset = 20
params.limit = 10 params.limit = 10
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 0, len(entries)) assert.Empty(t, entries)
} }
func TestQueryLogMaxFileScanEntries(t *testing.T) { func TestQueryLogMaxFileScanEntries(t *testing.T) {
@ -186,11 +186,11 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
params := newSearchParams() params := newSearchParams()
params.maxFileScanEntries = 5 // do not scan more than 5 records params.maxFileScanEntries = 5 // do not scan more than 5 records
entries, _ := l.search(params) entries, _ := l.search(params)
assert.Equal(t, 5, len(entries)) assert.Len(t, entries, 5)
params.maxFileScanEntries = 0 // disable the limit params.maxFileScanEntries = 0 // disable the limit
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Equal(t, 10, len(entries)) assert.Len(t, entries, 10)
} }
func TestQueryLogFileDisabled(t *testing.T) { func TestQueryLogFileDisabled(t *testing.T) {
@ -211,7 +211,7 @@ func TestQueryLogFileDisabled(t *testing.T) {
params := newSearchParams() params := newSearchParams()
ll, _ := l.search(params) ll, _ := l.search(params)
assert.Equal(t, 2, len(ll)) assert.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost) assert.Equal(t, "example2.org", ll[1].QHost)
} }
@ -262,7 +262,7 @@ func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string)
msg := new(dns.Msg) msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer)) assert.Nil(t, msg.Unpack(entry.Answer))
assert.Equal(t, 1, len(msg.Answer)) assert.Len(t, msg.Answer, 1)
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]) ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0])
assert.NotNil(t, ip) assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String()) assert.Equal(t, answer, ip.String())

View File

@ -28,12 +28,12 @@ func TestQLogFileEmpty(t *testing.T) {
// seek to the start // seek to the start
pos, err := q.SeekStart() pos, err := q.SeekStart()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(0), pos) assert.EqualValues(t, 0, pos)
// try reading anyway // try reading anyway
line, err := q.ReadNext() line, err := q.ReadNext()
assert.Equal(t, io.EOF, err) assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line) assert.Empty(t, line)
} }
func TestQLogFileLarge(t *testing.T) { func TestQLogFileLarge(t *testing.T) {
@ -53,14 +53,14 @@ func TestQLogFileLarge(t *testing.T) {
// seek to the start // seek to the start
pos, err := q.SeekStart() pos, err := q.SeekStart()
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos) assert.NotEqualValues(t, 0, pos)
read := 0 read := 0
var line string var line string
for err == nil { for err == nil {
line, err = q.ReadNext() line, err = q.ReadNext()
if err == nil { if err == nil {
assert.True(t, len(line) > 0) assert.NotZero(t, len(line))
read++ read++
} }
} }
@ -109,10 +109,10 @@ func TestQLogFileSeekLargeFile(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
// ALMOST the record we need // ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1 timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, uint64(0), timestamp) assert.NotEqualValues(t, 0, timestamp)
_, depth, err := q.SeekTS(timestamp) _, depth, err := q.SeekTS(timestamp)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3)) assert.LessOrEqual(t, depth, int(math.Log2(float64(count))+3))
} }
func TestQLogFileSeekSmallFile(t *testing.T) { func TestQLogFileSeekSmallFile(t *testing.T) {
@ -155,22 +155,22 @@ func TestQLogFileSeekSmallFile(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
// ALMOST the record we need // ALMOST the record we need
timestamp := readQLogTimestamp(line) - 1 timestamp := readQLogTimestamp(line) - 1
assert.NotEqual(t, uint64(0), timestamp) assert.NotEqualValues(t, 0, timestamp)
_, depth, err := q.SeekTS(timestamp) _, depth, err := q.SeekTS(timestamp)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.True(t, depth <= int(math.Log2(float64(count))+3)) assert.LessOrEqual(t, depth, int(math.Log2(float64(count))+3))
} }
func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) { func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) {
line, err := getQLogFileLine(q, lineNumber) line, err := getQLogFileLine(q, lineNumber)
assert.Nil(t, err) assert.Nil(t, err)
ts := readQLogTimestamp(line) ts := readQLogTimestamp(line)
assert.NotEqual(t, uint64(0), ts) assert.NotEqualValues(t, 0, ts)
// try seeking to that line now // try seeking to that line now
pos, _, err := q.SeekTS(ts) pos, _, err := q.SeekTS(ts)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, int64(0), pos) assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext() testLine, err := q.ReadNext()
assert.Nil(t, err) assert.Nil(t, err)
@ -207,27 +207,27 @@ func TestQLogFile(t *testing.T) {
// seek to the start // seek to the start
pos, err := q.SeekStart() pos, err := q.SeekStart()
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, pos > 0) assert.Greater(t, pos, int64(0))
// read first line // read first line
line, err := q.ReadNext() line, err := q.ReadNext()
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, strings.Contains(line, "0.0.0.2"), line) assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line) assert.True(t, strings.HasSuffix(line, "}"), line)
// read second line // read second line
line, err = q.ReadNext() line, err = q.ReadNext()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(0), q.position) assert.EqualValues(t, 0, q.position)
assert.True(t, strings.Contains(line, "0.0.0.1"), line) assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line) assert.True(t, strings.HasSuffix(line, "}"), line)
// try reading again (there's nothing to read anymore) // try reading again (there's nothing to read anymore)
line, err = q.ReadNext() line, err = q.ReadNext()
assert.Equal(t, io.EOF, err) assert.Equal(t, io.EOF, err)
assert.Equal(t, "", line) assert.Empty(t, line)
} }
// prepareTestFile - prepares a test query log file with the specified number of lines // prepareTestFile - prepares a test query log file with the specified number of lines

View File

@ -21,7 +21,7 @@ func TestQLogReaderEmpty(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
line, err := r.ReadNext() line, err := r.ReadNext()
assert.Equal(t, "", line) assert.Empty(t, line)
assert.Equal(t, io.EOF, err) assert.Equal(t, io.EOF, err)
} }
@ -241,7 +241,7 @@ func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) {
line, err := getQLogReaderLine(r, lineNumber) line, err := getQLogReaderLine(r, lineNumber)
assert.Nil(t, err) assert.Nil(t, err)
ts := readQLogTimestamp(line) ts := readQLogTimestamp(line)
assert.NotEqual(t, uint64(0), ts) assert.NotEqualValues(t, 0, ts)
// try seeking to that line now // try seeking to that line now
err = r.SeekTS(ts) err = r.SeekTS(ts)

View File

@ -39,13 +39,13 @@ func TestStats(t *testing.T) {
e := Entry{} e := Entry{}
e.Domain = "domain" e.Domain = "domain"
e.Client = net.ParseIP("127.0.0.1") e.Client = net.IP{127, 0, 0, 1}
e.Result = RFiltered e.Result = RFiltered
e.Time = 123456 e.Time = 123456
s.Update(e) s.Update(e)
e.Domain = "domain" e.Domain = "domain"
e.Client = net.ParseIP("127.0.0.1") e.Client = net.IP{127, 0, 0, 1}
e.Result = RNotFiltered e.Result = RNotFiltered
e.Time = 123456 e.Time = 123456
s.Update(e) s.Update(e)
@ -64,23 +64,23 @@ func TestStats(t *testing.T) {
assert.True(t, UIntArrayEquals(d["replaced_parental"].([]uint64), a)) assert.True(t, UIntArrayEquals(d["replaced_parental"].([]uint64), a))
m := d["top_queried_domains"].([]map[string]uint64) m := d["top_queried_domains"].([]map[string]uint64)
assert.True(t, m[0]["domain"] == 1) assert.EqualValues(t, 1, m[0]["domain"])
m = d["top_blocked_domains"].([]map[string]uint64) m = d["top_blocked_domains"].([]map[string]uint64)
assert.True(t, m[0]["domain"] == 1) assert.EqualValues(t, 1, m[0]["domain"])
m = d["top_clients"].([]map[string]uint64) m = d["top_clients"].([]map[string]uint64)
assert.True(t, m[0]["127.0.0.1"] == 2) assert.EqualValues(t, 2, m[0]["127.0.0.1"])
assert.True(t, d["num_dns_queries"].(uint64) == 2) assert.EqualValues(t, 2, d["num_dns_queries"].(uint64))
assert.True(t, d["num_blocked_filtering"].(uint64) == 1) assert.EqualValues(t, 1, d["num_blocked_filtering"].(uint64))
assert.True(t, d["num_replaced_safebrowsing"].(uint64) == 0) assert.EqualValues(t, 0, d["num_replaced_safebrowsing"].(uint64))
assert.True(t, d["num_replaced_safesearch"].(uint64) == 0) assert.EqualValues(t, 0, d["num_replaced_safesearch"].(uint64))
assert.True(t, d["num_replaced_parental"].(uint64) == 0) assert.EqualValues(t, 0, d["num_replaced_parental"].(uint64))
assert.True(t, d["avg_processing_time"].(float64) == 0.123456) assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64))
topClients := s.GetTopClientsIP(2) topClients := s.GetTopClientsIP(2)
assert.True(t, topClients[0] == "127.0.0.1") assert.Equal(t, "127.0.0.1", topClients[0])
s.clear() s.clear()
s.Close() s.Close()
@ -111,7 +111,7 @@ func TestLargeNumbers(t *testing.T) {
} }
for i := 0; i != n; i++ { for i := 0; i != n; i++ {
e.Domain = fmt.Sprintf("domain%d", i) e.Domain = fmt.Sprintf("domain%d", i)
e.Client = net.ParseIP("127.0.0.1") e.Client = net.IP{127, 0, 0, 1}
e.Client[2] = byte((i & 0xff00) >> 8) e.Client[2] = byte((i & 0xff00) >> 8)
e.Client[3] = byte(i & 0xff) e.Client[3] = byte(i & 0xff)
e.Result = RNotFiltered e.Result = RNotFiltered
@ -121,7 +121,7 @@ func TestLargeNumbers(t *testing.T) {
} }
d := s.getData() d := s.getData()
assert.True(t, d["num_dns_queries"].(uint64) == uint64(int(hour)*n)) assert.EqualValues(t, int(hour)*n, d["num_dns_queries"])
s.Close() s.Close()
os.Remove(conf.Filename) os.Remove(conf.Filename)
@ -152,6 +152,6 @@ func aggregateDataPerDay(firstID uint32) int {
func TestAggregateDataPerTimeUnit(t *testing.T) { func TestAggregateDataPerTimeUnit(t *testing.T) {
for i := 0; i != 25; i++ { for i := 0; i != 25; i++ {
alen := aggregateDataPerDay(uint32(i)) alen := aggregateDataPerDay(uint32(i))
assert.True(t, alen == 30, "i=%d", i) assert.Equalf(t, 30, alen, "i=%d", i)
} }
} }

View File

@ -19,12 +19,12 @@ func IfaceSetStaticIP(ifaceName string) (err error) {
} }
// GatewayIP returns IP address of interface's gateway. // GatewayIP returns IP address of interface's gateway.
func GatewayIP(ifaceName string) string { func GatewayIP(ifaceName string) net.IP {
cmd := exec.Command("ip", "route", "show", "dev", ifaceName) cmd := exec.Command("ip", "route", "show", "dev", ifaceName)
log.Tracef("executing %s %v", cmd.Path, cmd.Args) log.Tracef("executing %s %v", cmd.Path, cmd.Args)
d, err := cmd.Output() d, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 { if err != nil || cmd.ProcessState.ExitCode() != 0 {
return "" return nil
} }
fields := strings.Fields(string(d)) fields := strings.Fields(string(d))
@ -32,13 +32,8 @@ func GatewayIP(ifaceName string) string {
// "default" at first field and default gateway IP address at third // "default" at first field and default gateway IP address at third
// field. // field.
if len(fields) < 3 || fields[0] != "default" { if len(fields) < 3 || fields[0] != "default" {
return "" return nil
} }
ip := net.ParseIP(fields[2]) return net.ParseIP(fields[2])
if ip == nil {
return ""
}
return fields[2]
} }

View File

@ -129,7 +129,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
return err return err
} }
gatewayIP := GatewayIP(ifaceName) gatewayIP := GatewayIP(ifaceName)
add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4.String()) add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4)
body, err := ioutil.ReadFile("/etc/dhcpcd.conf") body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil { if err != nil {
@ -147,14 +147,14 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
// updateStaticIPdhcpcdConf sets static IP address for the interface by writing // updateStaticIPdhcpcdConf sets static IP address for the interface by writing
// into dhcpd.conf. // into dhcpd.conf.
func updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string { func updateStaticIPdhcpcdConf(ifaceName, ip string, gatewayIP, dnsIP net.IP) string {
var body []byte var body []byte
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n", add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
ifaceName, ip) ifaceName, ip)
body = append(body, []byte(add)...) body = append(body, []byte(add)...)
if len(gatewayIP) != 0 { if gatewayIP != nil {
add = fmt.Sprintf("static routers=%s\n", add = fmt.Sprintf("static routers=%s\n",
gatewayIP) gatewayIP)
body = append(body, []byte(add)...) body = append(body, []byte(add)...)

View File

@ -4,6 +4,7 @@ package sysutil
import ( import (
"bytes" "bytes"
"net"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -96,7 +97,7 @@ func TestSetStaticIPdhcpcdConf(t *testing.T) {
`static routers=192.168.0.1` + nl + `static routers=192.168.0.1` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl `static domain_name_servers=192.168.0.2` + nl + nl
s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2") s := updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", net.IP{192, 168, 0, 1}, net.IP{192, 168, 0, 2})
assert.Equal(t, dhcpcdConf, s) assert.Equal(t, dhcpcdConf, s)
// without gateway // without gateway
@ -104,6 +105,6 @@ func TestSetStaticIPdhcpcdConf(t *testing.T) {
`static ip_address=192.168.0.2/24` + nl + `static ip_address=192.168.0.2/24` + nl +
`static domain_name_servers=192.168.0.2` + nl + nl `static domain_name_servers=192.168.0.2` + nl + nl
s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2") s = updateStaticIPdhcpcdConf("wlan0", "192.168.0.2/24", nil, net.IP{192, 168, 0, 2})
assert.Equal(t, dhcpcdConf, s) assert.Equal(t, dhcpcdConf, s)
} }

View File

@ -42,7 +42,7 @@ func TestAutoHostsResolution(t *testing.T) {
// Existing host // Existing host
ips := ah.Process("localhost", dns.TypeA) ips := ah.Process("localhost", dns.TypeA)
assert.NotNil(t, ips) assert.NotNil(t, ips)
assert.Equal(t, 1, len(ips)) assert.Len(t, ips, 1)
assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0]) assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0])
// Unknown host // Unknown host
@ -107,7 +107,7 @@ func TestAutoHostsFSNotify(t *testing.T) {
// Check if we are notified about changes // Check if we are notified about changes
ips = ah.Process("newhost", dns.TypeA) ips = ah.Process("newhost", dns.TypeA)
assert.NotNil(t, ips) assert.NotNil(t, ips)
assert.Equal(t, 1, len(ips)) assert.Len(t, ips, 1)
assert.Equal(t, "127.0.0.2", ips[0].String()) assert.Equal(t, "127.0.0.2", ips[0].String())
} }

View File

@ -8,7 +8,8 @@ import (
func TestSplitNext(t *testing.T) { func TestSplitNext(t *testing.T) {
s := " a,b , c " s := " a,b , c "
assert.True(t, SplitNext(&s, ',') == "a") assert.Equal(t, "a", SplitNext(&s, ','))
assert.True(t, SplitNext(&s, ',') == "b") assert.Equal(t, "b", SplitNext(&s, ','))
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0) assert.Equal(t, "c", SplitNext(&s, ','))
assert.Empty(t, s)
} }