mirror of
https://github.com/Luzifer/nginx-sso.git
synced 2024-12-20 04:41:17 +00:00
Add default redirect URL for missing go-parameter
This adds a configuration option to set a default redirect URL for when no `go` parameter was passed. This allows for users to have bookmarked the login page and be redirected to the right location instead of seeing a 404 page. Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
a3390d6c75
commit
1cb9199bd9
4 changed files with 32 additions and 15 deletions
|
@ -3,6 +3,7 @@
|
||||||
login:
|
login:
|
||||||
title: "luzifer.io - Login"
|
title: "luzifer.io - Login"
|
||||||
default_method: "simple"
|
default_method: "simple"
|
||||||
|
default_redirect: "https://luzifer.io/"
|
||||||
hide_mfa_field: false
|
hide_mfa_field: false
|
||||||
names:
|
names:
|
||||||
simple: "Username / Password"
|
simple: "Username / Password"
|
||||||
|
|
13
main.go
13
main.go
|
@ -36,10 +36,11 @@ type mainConfig struct {
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
} `yaml:"listen"`
|
} `yaml:"listen"`
|
||||||
Login struct {
|
Login struct {
|
||||||
Title string `yaml:"title"`
|
Title string `yaml:"title"`
|
||||||
DefaultMethod string `yaml:"default_method"`
|
DefaultMethod string `yaml:"default_method"`
|
||||||
HideMFAField bool `yaml:"hide_mfa_field"`
|
DefaultRedirect string `yaml:"default_redirect"`
|
||||||
Names map[string]string `yaml:"names"`
|
HideMFAField bool `yaml:"hide_mfa_field"`
|
||||||
|
Names map[string]string `yaml:"names"`
|
||||||
} `yaml:"login"`
|
} `yaml:"login"`
|
||||||
Plugins struct {
|
Plugins struct {
|
||||||
Directory string `yaml:"directory"`
|
Directory string `yaml:"directory"`
|
||||||
|
@ -181,7 +182,7 @@ func handleAuthRequest(res http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleLoginRequest(res http.ResponseWriter, r *http.Request) {
|
func handleLoginRequest(res http.ResponseWriter, r *http.Request) {
|
||||||
redirURL, err := getRedirectURL(r)
|
redirURL, err := getRedirectURL(r, mainCfg.Login.DefaultRedirect)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(res, "Invalid redirect URL specified", http.StatusBadRequest)
|
http.Error(res, "Invalid redirect URL specified", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
@ -254,7 +255,7 @@ func handleLoginRequest(res http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleLogoutRequest(res http.ResponseWriter, r *http.Request) {
|
func handleLogoutRequest(res http.ResponseWriter, r *http.Request) {
|
||||||
redirURL, err := getRedirectURL(r)
|
redirURL, err := getRedirectURL(r, mainCfg.Login.DefaultRedirect)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(res, "Invalid redirect URL specified", http.StatusBadRequest)
|
http.Error(res, "Invalid redirect URL specified", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
23
redirect.go
23
redirect.go
|
@ -7,30 +7,45 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getRedirectURL(r *http.Request) (string, error) {
|
func getRedirectURL(r *http.Request, fallback string) (string, error) {
|
||||||
var (
|
var (
|
||||||
redirURL = r.URL.Query().Get("go")
|
redirURL string
|
||||||
params = r.URL.Query()
|
params url.Values
|
||||||
)
|
)
|
||||||
|
|
||||||
if redirURL == "" {
|
switch {
|
||||||
|
case r.URL.Query().Get("go") != "":
|
||||||
|
// We have a GET request, use "go" query param
|
||||||
|
redirURL = r.URL.Query().Get("go")
|
||||||
|
params = r.URL.Query()
|
||||||
|
|
||||||
|
case r.FormValue("go") != "":
|
||||||
|
// We have a POST request, use "go" form value
|
||||||
redirURL = r.FormValue("go")
|
redirURL = r.FormValue("go")
|
||||||
params = url.Values{} // No need to read other form fields
|
params = url.Values{} // No need to read other form fields
|
||||||
|
|
||||||
|
default:
|
||||||
|
// No URL specified, use specified fallback URL
|
||||||
|
return fallback, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove the "go" parameter as it is a parameter for nginx-sso
|
||||||
params.Del("go")
|
params.Del("go")
|
||||||
|
|
||||||
|
// Parse given URL to extract attached parameters
|
||||||
pURL, err := url.Parse(redirURL)
|
pURL, err := url.Parse(redirURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "Unable to parse redirect URL")
|
return "", errors.Wrap(err, "Unable to parse redirect URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transfer parameters from URL to params set
|
||||||
for k, vs := range pURL.Query() {
|
for k, vs := range pURL.Query() {
|
||||||
for _, v := range vs {
|
for _, v := range vs {
|
||||||
params.Add(k, v)
|
params.Add(k, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Re-add assembled parameters to URL
|
||||||
pURL.RawQuery = params.Encode()
|
pURL.RawQuery = params.Encode()
|
||||||
|
|
||||||
return pURL.String(), nil
|
return pURL.String(), nil
|
||||||
|
|
|
@ -14,7 +14,7 @@ func TestGetRedirectGet(t *testing.T) {
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, testURL, nil)
|
req, _ := http.NewRequest(http.MethodGet, testURL, nil)
|
||||||
|
|
||||||
rURL, err := getRedirectURL(req)
|
rURL, err := getRedirectURL(req, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getRedirectURL caused an error in GET: %s", err)
|
t.Errorf("getRedirectURL caused an error in GET: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -24,13 +24,13 @@ func TestGetRedirectGet(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRedirectGetEmpty(t *testing.T) {
|
func TestGetRedirectFallback(t *testing.T) {
|
||||||
testURL := "https://example.com/login"
|
testURL := "https://example.com/login"
|
||||||
expectURL := ""
|
expectURL := "https://example.com/default"
|
||||||
|
|
||||||
req, _ := http.NewRequest(http.MethodGet, testURL, nil)
|
req, _ := http.NewRequest(http.MethodGet, testURL, nil)
|
||||||
|
|
||||||
rURL, err := getRedirectURL(req)
|
rURL, err := getRedirectURL(req, expectURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getRedirectURL caused an error in GET: %s", err)
|
t.Errorf("getRedirectURL caused an error in GET: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func TestGetRedirectPost(t *testing.T) {
|
||||||
req, _ := http.NewRequest(http.MethodPost, testURL, nil)
|
req, _ := http.NewRequest(http.MethodPost, testURL, nil)
|
||||||
req.Form = body // Force-set the form values to emulate parsed form
|
req.Form = body // Force-set the form values to emulate parsed form
|
||||||
|
|
||||||
rURL, err := getRedirectURL(req)
|
rURL, err := getRedirectURL(req, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getRedirectURL caused an error in POST: %s", err)
|
t.Errorf("getRedirectURL caused an error in POST: %s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue