From ad799299c0071481b4df0fff896bcc233819a32a Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Wed, 12 Oct 2016 23:08:24 +0200 Subject: [PATCH] Add github-binary update helper --- github/updater.go | 213 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 github/updater.go diff --git a/github/updater.go b/github/updater.go new file mode 100644 index 0000000..2a322b9 --- /dev/null +++ b/github/updater.go @@ -0,0 +1,213 @@ +package github + +import ( + "bufio" + "bytes" + "context" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "runtime" + "strings" + "text/template" + "time" + + update "github.com/inconshreveable/go-update" +) + +const ( + defaultTimeout = 60 * time.Second + defaultNamingScheme = `{{.ProductName}}_{{.GOOS}}_{{.GOARCH}}{{.EXT}}` +) + +var ( + errReleaseNotFound = errors.New("Release not found") +) + +// Updater is the core struct of the update library holding all configurations +type Updater struct { + repo string + myVersion string + + HTTPClient *http.Client + RequestTimeout time.Duration + Context context.Context + Filename string + + releaseCache string +} + +// NewUpdater initializes a new Updater and tries to guess the Filename +func NewUpdater(repo, myVersion string) (*Updater, error) { + var err error + u := &Updater{ + repo: repo, + myVersion: myVersion, + + HTTPClient: http.DefaultClient, + RequestTimeout: defaultTimeout, + Context: context.Background(), + } + + u.Filename, err = u.compileFilename() + + return u, err +} + +// HasUpdate checks which tag was used in the latest version and compares it to the current version. If it differs the function will return true. No comparison is done to determine whether the found version is higher than the current one. +func (u *Updater) HasUpdate(forceRefresh bool) (bool, error) { + if forceRefresh { + u.releaseCache = "" + } + + latest, err := u.getLatestRelease() + switch err { + case nil: + return u.myVersion != latest, nil + case errReleaseNotFound: + return false, nil + default: + return false, err + } +} + +// Apply downloads the new binary from Github, fetches the SHA256 sum from the SHA256SUMS file and applies the update to the currently running binary +func (u *Updater) Apply() error { + updateAvailable, err := u.HasUpdate(false) + if err != nil { + return err + } + if !updateAvailable { + return nil + } + + checksum, err := u.getSHA256(u.Filename) + if err != nil { + return err + } + + newRelease, err := u.getFile(u.Filename) + if err != nil { + return err + } + defer newRelease.Close() + + return update.Apply(newRelease, update.Options{ + Checksum: checksum, + }) +} + +func (u Updater) getSHA256(filename string) ([]byte, error) { + shaFile, err := u.getFile("SHA256SUMS") + if err != nil { + return nil, err + } + defer shaFile.Close() + + scanner := bufio.NewScanner(shaFile) + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, u.Filename) { + continue + } + + return hex.DecodeString(line[0:64]) + } + + return nil, fmt.Errorf("No SHA256 found for file %q", u.Filename) +} + +func (u Updater) getFile(filename string) (io.ReadCloser, error) { + release, err := u.getLatestRelease() + if err != nil { + return nil, err + } + + requestURL := fmt.Sprintf("https://github.com/%s/releases/download/%s/%s", u.repo, release, filename) + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, err + } + + ctx, _ := context.WithTimeout(u.Context, u.RequestTimeout) + + res, err := u.HTTPClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + if res.StatusCode != 200 { + return nil, fmt.Errorf("File not found: %q", requestURL) + } + + return res.Body, nil +} + +func (u *Updater) getLatestRelease() (string, error) { + if u.releaseCache != "" { + return u.releaseCache, nil + } + + result := struct { + TagName string `json:"tag_name"` + }{} + + requestURL := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", u.repo) + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return "", err + } + + ctx, cancel := context.WithTimeout(u.Context, u.RequestTimeout) + defer cancel() + + res, err := u.HTTPClient.Do(req.WithContext(ctx)) + if err != nil { + return "", err + } + defer res.Body.Close() + + if err = json.NewDecoder(res.Body).Decode(&result); err != nil { + return "", err + } + + if res.StatusCode != 200 || result.TagName == "" { + return "", errReleaseNotFound + } + + u.releaseCache = result.TagName + + return result.TagName, nil +} + +func (u Updater) compileFilename() (string, error) { + repoName := strings.Split(u.repo, "/") + if len(repoName) != 2 { + return "", errors.New("Repository name not in format /") + } + + tpl, err := template.New("filename").Parse(defaultNamingScheme) + if err != nil { + return "", err + } + + var ext string + if runtime.GOOS == "windows" { + ext = ".exe" + } + + buf := bytes.NewBuffer([]byte{}) + if err = tpl.Execute(buf, map[string]interface{}{ + "GOOS": runtime.GOOS, + "GOARCH": runtime.GOARCH, + "EXT": ext, + "ProductName": repoName[1], + }); err != nil { + return "", err + } + + return buf.String(), nil +}