diff --git a/api.go b/api.go index fee8b4c..8c4a86c 100644 --- a/api.go +++ b/api.go @@ -66,12 +66,9 @@ func (s *subcriptionStore) UnsubscribeSocket(id string) { } func compileSocketMessage(msgType string, msg interface{}) socketMessage { - assetVersionsLock.RLock() - defer assetVersionsLock.RUnlock() - versionParts := []string{version} - for _, asset := range assets { - versionParts = append(versionParts, assetVersions[asset]) + for _, asset := range assetVersions.Keys() { + versionParts = append(versionParts, assetVersions.Get(asset)) } hash := sha256.New() diff --git a/assets.go b/assets.go new file mode 100644 index 0000000..2eb65dd --- /dev/null +++ b/assets.go @@ -0,0 +1,79 @@ +package main + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "sync" + + "github.com/pkg/errors" +) + +var assetVersions = newAssetVersionStore() + +type assetVersionStore struct { + store map[string]string + lock sync.RWMutex +} + +func newAssetVersionStore() *assetVersionStore { + return &assetVersionStore{ + store: make(map[string]string), + } +} + +func (a *assetVersionStore) Get(key string) string { + a.lock.RLock() + defer a.lock.RUnlock() + + return a.store[key] +} + +func (a *assetVersionStore) Keys() []string { + a.lock.RLock() + defer a.lock.RUnlock() + + var out []string + + for k := range a.store { + out = append(out, k) + } + + sort.Strings(out) + + return out +} + +func (a *assetVersionStore) UpdateAssetHashes(dir string) error { + a.lock.Lock() + defer a.lock.Unlock() + + return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + // There was a previous error + return err + } + + if info.IsDir() { + // We can't hash directories + return nil + } + + hash := sha256.New() + f, err := os.Open(path) + if err != nil { + return errors.Wrap(err, "open asset file") + } + defer f.Close() + + if _, err = io.Copy(hash, f); err != nil { + return errors.Wrap(err, "read asset file") + } + + a.store[path] = fmt.Sprintf("%x", hash.Sum(nil)) + return nil + }) +} diff --git a/main.go b/main.go index 885b9e7..d273244 100644 --- a/main.go +++ b/main.go @@ -1,19 +1,13 @@ package main import ( - "crypto/sha256" "fmt" - "io" "net/http" "os" - "path" - "strings" - "sync" "time" "github.com/gofrs/uuid" "github.com/gorilla/mux" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/Luzifer/rconfig/v2" @@ -36,10 +30,6 @@ var ( WebHookTimeout time.Duration `flag:"webhook-timeout" default:"15m" description:"When to re-register the webhooks"` }{} - assets = []string{"app.js", "overlay.html"} - assetVersions = map[string]string{} - assetVersionsLock = new(sync.RWMutex) - store *storage webhookSecret = uuid.Must(uuid.NewV4()).String() @@ -70,20 +60,20 @@ func main() { log.WithError(err).Fatal("Unable to load store") } - if err := updateAssetHashes(); err != nil { + if err := assetVersions.UpdateAssetHashes(cfg.AssetDir); err != nil { log.WithError(err).Fatal("Unable to read asset hashes") } - router := mux.NewRouter() + var ( + assetServer = http.FileServer(http.Dir(cfg.AssetDir)) + router = mux.NewRouter() + ) registerAPI(router) - router.HandleFunc( - fmt.Sprintf("/{file:(?:%s)}", strings.Join(assets, "|")), - func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "no-cache") - http.ServeFile(w, r, path.Join(cfg.AssetDir, mux.Vars(r)["file"])) - }, - ) + router.PathPrefix("/public").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-cache") + assetServer.ServeHTTP(w, r) + }) go func() { if err := http.ListenAndServe(cfg.Listen, router); err != nil { @@ -105,7 +95,7 @@ func main() { for { select { case <-timerAssetCheck.C: - if err := updateAssetHashes(); err != nil { + if err := assetVersions.UpdateAssetHashes(cfg.AssetDir); err != nil { log.WithError(err).Error("Unable to update asset hashes") } @@ -127,25 +117,3 @@ func main() { } } } - -func updateAssetHashes() error { - assetVersionsLock.Lock() - defer assetVersionsLock.Unlock() - - for _, asset := range assets { - hash := sha256.New() - f, err := os.Open(asset) - if err != nil { - return errors.Wrap(err, "open asset file") - } - defer f.Close() - - if _, err = io.Copy(hash, f); err != nil { - return errors.Wrap(err, "read asset file") - } - - assetVersions[asset] = fmt.Sprintf("%x", hash.Sum(nil)) - } - - return nil -}