From 2dd6ea5161e3b39500ddf2daf3b7520856ae66e0 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Thu, 25 Apr 2019 14:57:03 +0300 Subject: [PATCH] + /control/update handler --- app.go | 18 ++- control.go | 45 +----- control_update.go | 341 +++++++++++++++++++++++++++++++++++++++++ control_update_test.go | 39 +++++ 4 files changed, 397 insertions(+), 46 deletions(-) create mode 100644 control_update.go create mode 100644 control_update_test.go diff --git a/app.go b/app.go index 200a3bd3..13c71a3a 100644 --- a/app.go +++ b/app.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "crypto/tls" "fmt" "io" @@ -30,6 +31,7 @@ var httpsServer struct { server *http.Server cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey sync.Mutex // protects config.TLS + shutdown bool // if TRUE, don't restart the server } var pidFileName string // PID file name. Empty if no PID file was created. @@ -171,7 +173,7 @@ func run(args options) { go httpServerLoop() // this loop is used as an ability to change listening host and/or port - for { + for !httpsServer.shutdown { printHTTPAddresses("http") // we need to have new instance, because after Shutdown() the Server is not usable @@ -186,10 +188,13 @@ func run(args options) { } // We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop } + + // wait indefinitely for other go-routines to complete their job + select {} } func httpServerLoop() { - for { + for !httpsServer.shutdown { httpsServer.cond.L.Lock() // this mechanism doesn't let us through until all conditions are met for config.TLS.Enabled == false || @@ -367,6 +372,15 @@ func cleanup() { } } +// Stop HTTP server, possibly waiting for all active connections to be closed +func stopHTTPServer() { + httpsServer.shutdown = true + if httpsServer.server != nil { + httpsServer.server.Shutdown(context.TODO()) + } + httpServer.Shutdown(context.TODO()) +} + // This function is called before application exits func cleanupAlways() { if len(pidFileName) != 0 { diff --git a/control.go b/control.go index 036e44a8..4646b221 100644 --- a/control.go +++ b/control.go @@ -557,50 +557,6 @@ func checkDNS(input string, bootstrap []string) error { return nil } -func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - - now := time.Now() - controlLock.Lock() - cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 - data := versionCheckJSON - controlLock.Unlock() - - if cached { - // return cached copy - w.Header().Set("Content-Type", "application/json") - w.Write(data) - return - } - - resp, err := client.Get(versionCheckURL) - if err != nil { - httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) - return - } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - - // read the body entirely - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(body) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } - - controlLock.Lock() - versionCheckLastTime = now - versionCheckJSON = body - controlLock.Unlock() -} - // --------- // filtering // --------- @@ -1014,6 +970,7 @@ func registerControlHandlers() { http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory)))) http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset)))) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) + http.HandleFunc("/control/update", postInstall(optionalAuth(ensurePOST(handleUpdate)))) http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable)))) http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable)))) http.HandleFunc("/control/filtering/add_url", postInstall(optionalAuth(ensurePOST(handleFilteringAddURL)))) diff --git a/control_update.go b/control_update.go new file mode 100644 index 00000000..0e89f7c7 --- /dev/null +++ b/control_update.go @@ -0,0 +1,341 @@ +package main + +import ( + "archive/zip" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + "github.com/AdguardTeam/golibs/log" +) + +// Get the latest available version from the Internet +func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { + log.Tracef("%s %v", r.Method, r.URL) + + now := time.Now() + controlLock.Lock() + cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 + data := versionCheckJSON + controlLock.Unlock() + + if cached { + // return cached copy + w.Header().Set("Content-Type", "application/json") + w.Write(data) + return + } + + resp, err := client.Get(versionCheckURL) + if err != nil { + httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) + return + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + // read the body entirely + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(body) + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) + } + + controlLock.Lock() + versionCheckLastTime = now + versionCheckJSON = body + controlLock.Unlock() +} + +// Copy file on disk +func copyFile(src, dst string) error { + d, e := ioutil.ReadFile(src) + if e != nil { + return e + } + e = ioutil.WriteFile(dst, d, 0644) + if e != nil { + return e + } + return nil +} + +type updateInfo struct { + pkgURL string // URL for the new package + pkgName string // Full path to package file + newVer string // New version string + updateDir string // Full path to the directory containing unpacked files from the new package + backupDir string // Full path to backup directory + configName string // Full path to the current configuration file + updateConfigName string // Full path to the configuration file to check by the new binary + curBinName string // Full path to the current executable file + bkpBinName string // Full path to the current executable file in backup directory + newBinName string // Full path to the new executable file +} + +// Fill in updateInfo object +func getUpdateInfo(jsonData []byte) (*updateInfo, error) { + var u updateInfo + + workDir := config.ourWorkingDir + + versionJSON := make(map[string]interface{}) + err := json.Unmarshal(jsonData, &versionJSON) + if err != nil { + return nil, fmt.Errorf("JSON parse: %s", err) + } + + u.pkgURL = versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)].(string) + u.newVer = versionJSON["version"].(string) + if len(u.pkgURL) == 0 || len(u.newVer) == 0 { + return nil, fmt.Errorf("Invalid JSON") + } + + if u.newVer == VersionString { + return nil, fmt.Errorf("No need to update") + } + + _, pkgFileName := filepath.Split(u.pkgURL) + if len(pkgFileName) == 0 { + return nil, fmt.Errorf("Invalid JSON") + } + u.pkgName = filepath.Join(workDir, pkgFileName) + + u.updateDir = filepath.Join(workDir, fmt.Sprintf("update-%s", u.newVer)) + u.backupDir = filepath.Join(workDir, fmt.Sprintf("backup-%s", VersionString)) + u.configName = config.getConfigFilename() + u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml") + if strings.HasSuffix(pkgFileName, ".zip") { + u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml") + } + + binName := "AdGuardHome" + if runtime.GOOS == "windows" { + binName = "AdGuardHome.exe" + } + u.curBinName = filepath.Join(workDir, binName) + u.bkpBinName = filepath.Join(u.backupDir, binName) + u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName) + if strings.HasSuffix(pkgFileName, ".zip") { + u.newBinName = filepath.Join(u.updateDir, binName) + } + + return &u, nil +} + +// Unpack all files from .zip file to the specified directory +func zipFileUnpack(zipfile, outdir string) error { + r, err := zip.OpenReader(zipfile) + if err != nil { + return fmt.Errorf("zip.OpenReader(): %s", err) + } + defer r.Close() + + for _, zf := range r.File { + zr, err := zf.Open() + if err != nil { + return fmt.Errorf("zip file Open(): %s", err) + } + fi := zf.FileInfo() + fn := filepath.Join(outdir, fi.Name()) + + if fi.IsDir() { + err = os.Mkdir(fn, fi.Mode()) + if err != nil { + return fmt.Errorf("zip file Read(): %s", err) + } + continue + } + + f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + zr.Close() + return fmt.Errorf("os.OpenFile(): %s", err) + } + _, err = io.Copy(f, zr) + if err != nil { + zr.Close() + return fmt.Errorf("io.Copy(): %s", err) + } + zr.Close() + } + return nil +} + +// Unpack all files from .tar.gz file to the specified directory +func targzFileUnpack(tarfile, outdir string) error { + cmd := exec.Command("tar", "zxf", tarfile, "-C", outdir) + log.Tracef("Unpacking: %v", cmd.Args) + _, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return fmt.Errorf("exec.Command() failed: %s", err) + } + return nil +} + +// Perform an update procedure +func doUpdate(u *updateInfo) error { + log.Info("Updating from %s to %s. URL:%s Package:%s", + VersionString, u.newVer, u.pkgURL, u.pkgName) + + resp, err := client.Get(u.pkgURL) + if err != nil { + return fmt.Errorf("HTTP request failed: %s", err) + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + log.Tracef("Reading HTTP body") + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("ioutil.ReadAll() failed: %s", err) + } + + log.Tracef("Saving package to file") + err = ioutil.WriteFile(u.pkgName, body, 0644) + if err != nil { + return fmt.Errorf("ioutil.WriteFile() failed: %s", err) + } + + log.Tracef("Unpacking the package") + _ = os.Mkdir(u.updateDir, 0755) + _, file := filepath.Split(u.pkgName) + if strings.HasSuffix(file, ".zip") { + err = zipFileUnpack(u.pkgName, u.updateDir) + if err != nil { + return fmt.Errorf("zipFileUnpack() failed: %s", err) + } + } else if strings.HasSuffix(file, ".tar.gz") { + err = targzFileUnpack(u.pkgName, u.updateDir) + if err != nil { + return fmt.Errorf("zipFileUnpack() failed: %s", err) + } + } else { + return fmt.Errorf("Unknown package extension") + } + + log.Tracef("Checking configuration") + err = copyFile(u.configName, u.updateConfigName) + if err != nil { + return fmt.Errorf("copyFile() failed: %s", err) + } + cmd := exec.Command(u.newBinName, "--check-config") + err = cmd.Run() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) + } + + log.Tracef("Backing up the current configuration") + _ = os.Mkdir(u.backupDir, 0755) + err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) + if err != nil { + return fmt.Errorf("copyFile() failed: %s", err) + } + + log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) + err = os.Rename(u.curBinName, u.bkpBinName) + if err != nil { + return err + } + if runtime.GOOS == "windows" { + // rename fails with "File in use" error + err = copyFile(u.newBinName, u.curBinName) + } else { + err = os.Rename(u.newBinName, u.curBinName) + } + if err != nil { + return err + } + log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName) + + _ = os.Remove(u.pkgName) + // _ = os.RemoveAll(u.updateDir) + return nil +} + +// Complete an update procedure +func finishUpdate(u *updateInfo) { + log.Info("Stopping all tasks") + cleanup() + stopHTTPServer() + cleanupAlways() + + if runtime.GOOS == "windows" { + + if config.runningAsService { + // Note: + // we can't restart the service via "kardianos/service" package - it kills the process first + // we can't start a new instance - Windows doesn't allow it + cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome") + err := cmd.Start() + if err != nil { + log.Fatalf("exec.Command() failed: %s", err) + } + os.Exit(0) + } + + cmd := exec.Command(u.curBinName, os.Args[1:]...) + log.Info("Restarting: %v", cmd.Args) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + log.Fatalf("exec.Command() failed: %s", err) + } + os.Exit(0) + + } else { + + log.Info("Restarting: %v", os.Args) + err := syscall.Exec(u.curBinName, os.Args, os.Environ()) + if err != nil { + log.Fatalf("syscall.Exec() failed: %s", err) + } + // Unreachable code + } +} + +// Perform an update procedure to the latest available version +func handleUpdate(w http.ResponseWriter, r *http.Request) { + log.Tracef("%s %v", r.Method, r.URL) + + if len(versionCheckJSON) == 0 { + httpError(w, http.StatusBadRequest, "/update request isn't allowed now") + return + } + + u, err := getUpdateInfo(versionCheckJSON) + if err != nil { + httpError(w, http.StatusInternalServerError, "%s", err) + return + } + + err = doUpdate(u) + if err != nil { + httpError(w, http.StatusInternalServerError, "%s", err) + return + } + + returnOK(w) + + time.Sleep(time.Second) // wait (hopefully) until response is sent (not sure whether it's really necessary) + go finishUpdate(u) +} diff --git a/control_update_test.go b/control_update_test.go new file mode 100644 index 00000000..346b98f3 --- /dev/null +++ b/control_update_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "os" + "testing" +) + +func testDoUpdate(t *testing.T) { + config.DNS.Port = 0 + u := updateInfo{ + pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz", + pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz", + newVer: "v0.95", + updateDir: "./update-v0.95", + backupDir: "./backup-v0.94", + configName: "./AdGuardHome.yaml", + updateConfigName: "./update-v0.95/AdGuardHome/AdGuardHome.yaml", + curBinName: "./AdGuardHome", + bkpBinName: "./backup-v0.94/AdGuardHome", + newBinName: "./update-v0.95/AdGuardHome/AdGuardHome", + } + e := doUpdate(&u) + if e != nil { + t.Fatalf("FAILED: %s", e) + } + os.RemoveAll(u.backupDir) + os.RemoveAll(u.updateDir) +} + +func testZipFileUnpack(t *testing.T) { + fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip" + outdir := "./test-unpack" + _ = os.Mkdir(outdir, 0755) + e := zipFileUnpack(fn, outdir) + if e != nil { + t.Fatalf("FAILED: %s", e) + } + os.RemoveAll(outdir) +}