Add CORS header middleware

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2022-12-27 18:28:44 +01:00
parent 4b9090edd8
commit 956122ef95
Signed by: luzifer
GPG key ID: D91C3E91E4CAD6F5
3 changed files with 47 additions and 1 deletions

View file

@ -13,7 +13,8 @@
```console
# restis --help
Usage of restis:
Usage of ./restis:
--disable-cors Disable setting CORS headers for all requests
--listen string Port/IP to listen on (default ":3000")
--log-level string Log level (debug, info, warn, error, fatal) (default "info")
--redis-conn-string string Connection string for redis (default "redis://localhost:6379/0")

30
cors.go Normal file
View file

@ -0,0 +1,30 @@
package main
import (
"net/http"
"strings"
)
func corsMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Allow the client to send us credentials
w.Header().Set("Access-Control-Allow-Credentials", "true")
// We don't care about headers at all, allow sending all
w.Header().Set("Access-Control-Allow-Headers", "*")
// List all accepted methods no matter whether they are accepted by the specified endpoint
w.Header().Set("Access-Control-Allow-Methods", strings.Join([]string{
http.MethodDelete,
http.MethodGet,
http.MethodPut,
}, ", "))
// Public API: Let everyone in
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
h.ServeHTTP(w, r)
})
}

15
main.go
View file

@ -16,6 +16,7 @@ import (
var (
cfg = struct {
DisableCORS bool `flag:"disable-cors" default:"false" description:"Disable setting CORS headers for all requests"`
Listen string `flag:"listen" default:":3000" description:"Port/IP to listen on"`
LogLevel string `flag:"log-level" default:"info" description:"Log level (debug, info, warn, error, fatal)"`
RedisConnString string `flag:"redis-conn-string" default:"redis://localhost:6379/0" description:"Connection string for redis"`
@ -62,6 +63,20 @@ func main() {
router = mux.NewRouter()
)
if !cfg.DisableCORS {
router.Use(corsMiddleware)
}
router.MethodNotAllowedHandler = corsMiddleware(http.HandlerFunc(func(res http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
// Most likely JS client asking for CORS headers
res.WriteHeader(http.StatusNoContent)
return
}
res.WriteHeader(http.StatusMethodNotAllowed)
}))
router.Methods(http.MethodDelete).HandlerFunc(handlerDelete(client))
router.Methods(http.MethodGet).HandlerFunc(handlerGet(client))
router.Methods(http.MethodPut).HandlerFunc(handlerPut(client))