diff --git a/README.md b/README.md index d588859..f3cc385 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ ```console # restis --help -Usage of restis: +Usage of ./restis: + --disable-cors Disable setting CORS headers for all requests --listen string Port/IP to listen on (default ":3000") --log-level string Log level (debug, info, warn, error, fatal) (default "info") --redis-conn-string string Connection string for redis (default "redis://localhost:6379/0") diff --git a/cors.go b/cors.go new file mode 100644 index 0000000..b7bccd6 --- /dev/null +++ b/cors.go @@ -0,0 +1,30 @@ +package main + +import ( + "net/http" + "strings" +) + +func corsMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Allow the client to send us credentials + w.Header().Set("Access-Control-Allow-Credentials", "true") + + // We don't care about headers at all, allow sending all + w.Header().Set("Access-Control-Allow-Headers", "*") + + // List all accepted methods no matter whether they are accepted by the specified endpoint + w.Header().Set("Access-Control-Allow-Methods", strings.Join([]string{ + http.MethodDelete, + http.MethodGet, + http.MethodPut, + }, ", ")) + + // Public API: Let everyone in + if origin := r.Header.Get("Origin"); origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + h.ServeHTTP(w, r) + }) +} diff --git a/main.go b/main.go index c59351f..b936bd0 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( var ( cfg = struct { + DisableCORS bool `flag:"disable-cors" default:"false" description:"Disable setting CORS headers for all requests"` Listen string `flag:"listen" default:":3000" description:"Port/IP to listen on"` LogLevel string `flag:"log-level" default:"info" description:"Log level (debug, info, warn, error, fatal)"` RedisConnString string `flag:"redis-conn-string" default:"redis://localhost:6379/0" description:"Connection string for redis"` @@ -62,6 +63,20 @@ func main() { router = mux.NewRouter() ) + if !cfg.DisableCORS { + router.Use(corsMiddleware) + } + + router.MethodNotAllowedHandler = corsMiddleware(http.HandlerFunc(func(res http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + // Most likely JS client asking for CORS headers + res.WriteHeader(http.StatusNoContent) + return + } + + res.WriteHeader(http.StatusMethodNotAllowed) + })) + router.Methods(http.MethodDelete).HandlerFunc(handlerDelete(client)) router.Methods(http.MethodGet).HandlerFunc(handlerGet(client)) router.Methods(http.MethodPut).HandlerFunc(handlerPut(client))