diff --git a/config.yaml b/config.yaml index b147a81..48fade1 100644 --- a/config.yaml +++ b/config.yaml @@ -3,6 +3,7 @@ login: title: "luzifer.io - Login" default_method: "simple" + default_redirect: "https://luzifer.io/" hide_mfa_field: false names: simple: "Username / Password" diff --git a/main.go b/main.go index d84cd5b..2359666 100644 --- a/main.go +++ b/main.go @@ -36,10 +36,11 @@ type mainConfig struct { Port int `yaml:"port"` } `yaml:"listen"` Login struct { - Title string `yaml:"title"` - DefaultMethod string `yaml:"default_method"` - HideMFAField bool `yaml:"hide_mfa_field"` - Names map[string]string `yaml:"names"` + Title string `yaml:"title"` + DefaultMethod string `yaml:"default_method"` + DefaultRedirect string `yaml:"default_redirect"` + HideMFAField bool `yaml:"hide_mfa_field"` + Names map[string]string `yaml:"names"` } `yaml:"login"` Plugins struct { Directory string `yaml:"directory"` @@ -181,7 +182,7 @@ func handleAuthRequest(res http.ResponseWriter, r *http.Request) { } func handleLoginRequest(res http.ResponseWriter, r *http.Request) { - redirURL, err := getRedirectURL(r) + redirURL, err := getRedirectURL(r, mainCfg.Login.DefaultRedirect) if err != nil { http.Error(res, "Invalid redirect URL specified", http.StatusBadRequest) } @@ -254,7 +255,7 @@ func handleLoginRequest(res http.ResponseWriter, r *http.Request) { } func handleLogoutRequest(res http.ResponseWriter, r *http.Request) { - redirURL, err := getRedirectURL(r) + redirURL, err := getRedirectURL(r, mainCfg.Login.DefaultRedirect) if err != nil { http.Error(res, "Invalid redirect URL specified", http.StatusBadRequest) } diff --git a/redirect.go b/redirect.go index 9d51225..98dfbd4 100644 --- a/redirect.go +++ b/redirect.go @@ -7,30 +7,45 @@ import ( "github.com/pkg/errors" ) -func getRedirectURL(r *http.Request) (string, error) { +func getRedirectURL(r *http.Request, fallback string) (string, error) { var ( - redirURL = r.URL.Query().Get("go") - params = r.URL.Query() + redirURL string + params url.Values ) - if redirURL == "" { + switch { + case r.URL.Query().Get("go") != "": + // We have a GET request, use "go" query param + redirURL = r.URL.Query().Get("go") + params = r.URL.Query() + + case r.FormValue("go") != "": + // We have a POST request, use "go" form value redirURL = r.FormValue("go") params = url.Values{} // No need to read other form fields + + default: + // No URL specified, use specified fallback URL + return fallback, nil } + // Remove the "go" parameter as it is a parameter for nginx-sso params.Del("go") + // Parse given URL to extract attached parameters pURL, err := url.Parse(redirURL) if err != nil { return "", errors.Wrap(err, "Unable to parse redirect URL") } + // Transfer parameters from URL to params set for k, vs := range pURL.Query() { for _, v := range vs { params.Add(k, v) } } + // Re-add assembled parameters to URL pURL.RawQuery = params.Encode() return pURL.String(), nil diff --git a/redirect_test.go b/redirect_test.go index bd433b3..5d193e2 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -14,7 +14,7 @@ func TestGetRedirectGet(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, testURL, nil) - rURL, err := getRedirectURL(req) + rURL, err := getRedirectURL(req, "") if err != nil { t.Errorf("getRedirectURL caused an error in GET: %s", err) } @@ -24,13 +24,13 @@ func TestGetRedirectGet(t *testing.T) { } } -func TestGetRedirectGetEmpty(t *testing.T) { +func TestGetRedirectFallback(t *testing.T) { testURL := "https://example.com/login" - expectURL := "" + expectURL := "https://example.com/default" req, _ := http.NewRequest(http.MethodGet, testURL, nil) - rURL, err := getRedirectURL(req) + rURL, err := getRedirectURL(req, expectURL) if err != nil { t.Errorf("getRedirectURL caused an error in GET: %s", err) } @@ -51,7 +51,7 @@ func TestGetRedirectPost(t *testing.T) { req, _ := http.NewRequest(http.MethodPost, testURL, nil) req.Form = body // Force-set the form values to emulate parsed form - rURL, err := getRedirectURL(req) + rURL, err := getRedirectURL(req, "") if err != nil { t.Errorf("getRedirectURL caused an error in POST: %s", err) }