From 413228e6ecaf9019c8eac7034f1404b5c9c497fd Mon Sep 17 00:00:00 2001
From: Eugene Bujak <hmage@hmage.net>
Date: Wed, 10 Oct 2018 20:13:03 +0300
Subject: [PATCH] API backend -- implement ability to turn toggle all
 protection in one go, helpful to temporarily disable all kinds of filtering

---
 config.go    |   6 +-
 control.go   | 272 +++++++++++----------------------------------------
 openapi.yaml |  28 ++++++
 3 files changed, 91 insertions(+), 215 deletions(-)

diff --git a/config.go b/config.go
index 6130dff6..2a19f7b5 100644
--- a/config.go
+++ b/config.go
@@ -35,6 +35,7 @@ type coreDNSConfig struct {
 	coreFile            string
 	FilterFile          string   `yaml:"-"`
 	Port                int      `yaml:"port"`
+	ProtectionEnabled   bool     `yaml:"protection_enabled"`
 	FilteringEnabled    bool     `yaml:"filtering_enabled"`
 	SafeBrowsingEnabled bool     `yaml:"safebrowsing_enabled"`
 	SafeSearchEnabled   bool     `yaml:"safesearch_enabled"`
@@ -69,6 +70,7 @@ var config = configuration{
 		binaryFile:          "coredns",       // only filename, no path
 		coreFile:            "Corefile",      // only filename, no path
 		FilterFile:          "dnsfilter.txt", // only filename, no path
+		ProtectionEnabled:   true,
 		FilteringEnabled:    true,
 		SafeBrowsingEnabled: false,
 		BlockedResponseTTL:  60, // in seconds
@@ -165,13 +167,13 @@ func writeAllConfigs() error {
 }
 
 const coreDNSConfigTemplate = `. {
-    dnsfilter {{if .FilteringEnabled}}{{.FilterFile}}{{end}} {
+    {{if .ProtectionEnabled}}dnsfilter {{if .FilteringEnabled}}{{.FilterFile}}{{end}} {
         {{if .SafeBrowsingEnabled}}safebrowsing{{end}}
         {{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}}
         {{if .SafeSearchEnabled}}safesearch{{end}}
         {{if .QueryLogEnabled}}querylog{{end}}
         blocked_ttl {{.BlockedResponseTTL}}
-    }
+    }{{end}}
     {{.Pprof}}
     hosts {
         fallthrough
diff --git a/control.go b/control.go
index db305d69..e2e631b9 100644
--- a/control.go
+++ b/control.go
@@ -75,6 +75,26 @@ func writeAllConfigsAndReloadCoreDNS() error {
 	return nil
 }
 
+func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) {
+	err := writeAllConfigsAndReloadCoreDNS()
+	if err != nil {
+		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, http.StatusInternalServerError)
+		return
+	}
+	returnOK(w, r)
+}
+
+func returnOK(w http.ResponseWriter, r *http.Request) {
+	_, err := fmt.Fprintf(w, "OK\n")
+	if err != nil {
+		errortext := fmt.Sprintf("Couldn't write body: %s", err)
+		log.Println(errortext)
+		http.Error(w, errortext, http.StatusInternalServerError)
+	}
+}
+
 func isRunning() bool {
 	if coreDNSCommand != nil && coreDNSCommand.Process != nil {
 		pid := coreDNSCommand.Process.Pid
@@ -197,12 +217,13 @@ func handleRestart(w http.ResponseWriter, r *http.Request) {
 
 func handleStatus(w http.ResponseWriter, r *http.Request) {
 	data := map[string]interface{}{
-		"running":          isRunning(),
-		"version":          VersionString,
-		"dns_address":      config.BindHost,
-		"dns_port":         config.CoreDNS.Port,
-		"querylog_enabled": config.CoreDNS.QueryLogEnabled,
-		"upstream_dns":     config.CoreDNS.UpstreamDNS,
+		"dns_address":        config.BindHost,
+		"dns_port":           config.CoreDNS.Port,
+		"protection_enabled": config.CoreDNS.ProtectionEnabled,
+		"querylog_enabled":   config.CoreDNS.QueryLogEnabled,
+		"running":            isRunning(),
+		"upstream_dns":       config.CoreDNS.UpstreamDNS,
+		"version":            VersionString,
 	}
 
 	json, err := json.Marshal(data)
@@ -222,6 +243,16 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func handleProtectionEnable(w http.ResponseWriter, r *http.Request) {
+	config.CoreDNS.ProtectionEnabled = true
+	httpUpdateConfigReloadDNSReturnOK(w, r)
+}
+
+func handleProtectionDisable(w http.ResponseWriter, r *http.Request) {
+	config.CoreDNS.ProtectionEnabled = false
+	httpUpdateConfigReloadDNSReturnOK(w, r)
+}
+
 // -----
 // stats
 // -----
@@ -330,37 +361,12 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
 
 func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.QueryLogEnabled = true
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.QueryLogEnabled = false
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleStatsReset(w http.ResponseWriter, r *http.Request) {
@@ -662,38 +668,12 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
 
 func handleFilteringEnable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.FilteringEnabled = true
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleFilteringDisable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.FilteringEnabled = false
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleFilteringStatus(w http.ResponseWriter, r *http.Request) {
@@ -832,13 +812,6 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 	config.Filters = newFilters
-	err = writeAllConfigs()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
 	err = writeFilterFile()
 	if err != nil {
 		errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
@@ -846,14 +819,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, errortext, http.StatusInternalServerError)
 		return
 	}
-	tellCoreDNSToReload()
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
@@ -890,14 +856,6 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	err = writeAllConfigs()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-
 	// kick off refresh of rules from new URLs
 	refreshFiltersIfNeccessary()
 	err = writeFilterFile()
@@ -907,14 +865,7 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, errortext, http.StatusInternalServerError)
 		return
 	}
-	tellCoreDNSToReload()
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
@@ -951,13 +902,6 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	err = writeAllConfigs()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
 	err = writeFilterFile()
 	if err != nil {
 		errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
@@ -965,15 +909,7 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, errortext, http.StatusInternalServerError)
 		return
 	}
-	tellCoreDNSToReload()
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
-	// TODO: regenerate coredns config and tell coredns to reload it if it's running
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
@@ -986,13 +922,6 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
 	}
 
 	config.UserRules = strings.Split(string(body), "\n")
-	err = writeAllConfigs()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
 	err = writeFilterFile()
 	if err != nil {
 		errortext := fmt.Sprintf("Couldn't write filter file: %s", err)
@@ -1000,14 +929,7 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, errortext, http.StatusInternalServerError)
 		return
 	}
-	tellCoreDNSToReload()
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) {
@@ -1184,38 +1106,12 @@ func writeFilterFile() error {
 
 func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.SafeBrowsingEnabled = true
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.SafeBrowsingEnabled = false
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
@@ -1285,38 +1181,12 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) {
 	}
 	config.CoreDNS.ParentalSensitivity = i
 	config.CoreDNS.ParentalEnabled = true
-	err = writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleParentalDisable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.ParentalEnabled = false
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
@@ -1350,38 +1220,12 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) {
 
 func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.SafeSearchEnabled = true
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
 	config.CoreDNS.SafeSearchEnabled = false
-	err := writeAllConfigsAndReloadCoreDNS()
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write config file: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-		return
-	}
-	_, err = fmt.Fprintf(w, "OK\n")
-	if err != nil {
-		errortext := fmt.Sprintf("Couldn't write body: %s", err)
-		log.Println(errortext)
-		http.Error(w, errortext, http.StatusInternalServerError)
-	}
-
+	httpUpdateConfigReloadDNSReturnOK(w, r)
 }
 
 func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
@@ -1411,25 +1255,27 @@ func registerControlHandlers() {
 	http.HandleFunc("/control/stop", optionalAuth(ensurePOST(handleStop)))
 	http.HandleFunc("/control/restart", optionalAuth(ensurePOST(handleRestart)))
 	http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus)))
-	http.HandleFunc("/control/stats", optionalAuth(ensureGET(handleStats)))
-	http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(handleStatsHistory)))
-	http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(handleStatsTop)))
-	http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(handleStatsReset)))
+	http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable)))
+	http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable)))
 	http.HandleFunc("/control/querylog", optionalAuth(ensureGET(handleQueryLog)))
 	http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable)))
 	http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable)))
 	http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS)))
 	http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS)))
+	http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(handleStatsTop)))
+	http.HandleFunc("/control/stats", optionalAuth(ensureGET(handleStats)))
+	http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(handleStatsHistory)))
+	http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(handleStatsReset)))
 	http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON))
 	http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable)))
 	http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable)))
-	http.HandleFunc("/control/filtering/status", optionalAuth(ensureGET(handleFilteringStatus)))
 	http.HandleFunc("/control/filtering/add_url", optionalAuth(ensurePUT(handleFilteringAddURL)))
 	http.HandleFunc("/control/filtering/remove_url", optionalAuth(ensureDELETE(handleFilteringRemoveURL)))
 	http.HandleFunc("/control/filtering/enable_url", optionalAuth(ensurePOST(handleFilteringEnableURL)))
 	http.HandleFunc("/control/filtering/disable_url", optionalAuth(ensurePOST(handleFilteringDisableURL)))
-	http.HandleFunc("/control/filtering/set_rules", optionalAuth(ensurePUT(handleFilteringSetRules)))
 	http.HandleFunc("/control/filtering/refresh", optionalAuth(ensurePOST(handleFilteringRefresh)))
+	http.HandleFunc("/control/filtering/status", optionalAuth(ensureGET(handleFilteringStatus)))
+	http.HandleFunc("/control/filtering/set_rules", optionalAuth(ensurePUT(handleFilteringSetRules)))
 	http.HandleFunc("/control/safebrowsing/enable", optionalAuth(ensurePOST(handleSafeBrowsingEnable)))
 	http.HandleFunc("/control/safebrowsing/disable", optionalAuth(ensurePOST(handleSafeBrowsingDisable)))
 	http.HandleFunc("/control/safebrowsing/status", optionalAuth(ensureGET(handleSafeBrowsingStatus)))
diff --git a/openapi.yaml b/openapi.yaml
index e39fc090..e20d59d3 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -65,12 +65,31 @@ paths:
                         application/json:
                             dns_address: 127.0.0.1
                             dns_port: 53
+                            protection_enabled: true
                             querylog_enabled: true
                             running: true
                             upstream_dns:
                               - 1.1.1.1
                               - 1.0.0.1
                             version: "v0.1"
+    /enable_protection:
+        post:
+            tags:
+                -global
+            operationId: enableProtection
+            summary: "Enable protection (turns on dnsfilter module in coredns)"
+            responses:
+                200:
+                    description: OK
+    /disable_protection:
+        post:
+            tags:
+                -global
+            operationId: disableProtection
+            summary: "Disable protection (turns off filtering, sb, parental, safesearch temporarily by disabling dnsfilter module in coredns)"
+            responses:
+                200:
+                    description: OK
     /querylog:
         get:
             tags:
@@ -316,6 +335,15 @@ paths:
                                 - 123
                                 - 123
                                 - 123
+    /stats_reset:
+        post:
+            tags:
+                -global
+            operationId: statsReset
+            summary: "Reset all statistics to zeroes"
+            responses:
+                200:
+                    description: OK
     /filtering/enable:
         post:
             tags: