From 1b15bee2b09df5b2c31edf6d6d594aa53ac9124c Mon Sep 17 00:00:00 2001
From: Aleksey Dmitrevskiy <ad@adguard.com>
Date: Wed, 6 Mar 2019 18:24:14 +0300
Subject: [PATCH] [change] control: add upstreams validation

---
 control.go | 34 ++++++++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 8 deletions(-)

diff --git a/control.go b/control.go
index fb1cf0af..fe340ecd 100644
--- a/control.go
+++ b/control.go
@@ -322,15 +322,21 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	for _, u := range newconfig.Upstreams {
+		if err = validateUpstream(u); err != nil {
+			httpError(w, http.StatusBadRequest, "%s can not be used as upstream cause: %s", u, err)
+			return
+		}
+	}
+
 	config.DNS.UpstreamDNS = defaultDNS
 	if len(newconfig.Upstreams) > 0 {
 		config.DNS.UpstreamDNS = newconfig.Upstreams
 	}
 
-	// bootstrap servers are plain DNS only. We should return http error if there are tls:// https:// or sdns:// hosts in slice
+	// bootstrap servers are plain DNS only.
 	for _, host := range newconfig.BootstrapDNS {
-		err := checkPlainDNS(host)
-		if err != nil {
+		if err := checkPlainDNS(host); err != nil {
 			httpError(w, http.StatusBadRequest, "%s can not be used as bootstrap dns cause: %s", host, err)
 			return
 		}
@@ -345,26 +351,38 @@ func handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) {
 	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
+func validateUpstream(upstream string) error {
+	if strings.HasPrefix(upstream, "tls://") || strings.HasPrefix(upstream, "https://") || strings.HasPrefix(upstream, "sdns://") || strings.HasPrefix(upstream, "tcp://") {
+		return nil
+	}
+
+	if strings.Contains(upstream, "://") {
+		return fmt.Errorf("wrong protocol")
+	}
+
+	return checkPlainDNS(upstream)
+}
+
 // checkPlainDNS checks if host is plain DNS
-func checkPlainDNS(host string) error {
+func checkPlainDNS(upstream string) error {
 	// Check if host is ip without port
-	if net.ParseIP(host) != nil {
+	if net.ParseIP(upstream) != nil {
 		return nil
 	}
 
 	// Check if host is ip with port
-	ip, port, err := net.SplitHostPort(host)
+	ip, port, err := net.SplitHostPort(upstream)
 	if err != nil {
 		return err
 	}
 
 	if net.ParseIP(ip) == nil {
-		return fmt.Errorf("%s is not valid IP", ip)
+		return fmt.Errorf("%s is not a valid IP", ip)
 	}
 
 	_, err = strconv.ParseInt(port, 0, 64)
 	if err != nil {
-		return fmt.Errorf("%s is not valid port: %s", port, err)
+		return fmt.Errorf("%s is not a valid port: %s", port, err)
 	}
 
 	return nil