mercedes-byocar-exporter/internal/mercedes/impl.go
2022-11-20 00:47:49 +01:00

216 lines
5.6 KiB
Go

package mercedes
import (
"context"
"encoding/json"
"io"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/Luzifer/mercedes-byocar-exporter/internal/credential"
)
const (
requestTimeout = 10 * time.Second
stateExpiry = 5 * time.Minute
)
type (
APIClient struct {
clientID, clientSecret string
creds credential.Store
validStateToken string
validStateTokenExpiry time.Time
}
)
var _ Client = (*APIClient)(nil)
func New(clientID, clientSecret string, creds credential.Store) *APIClient {
return &APIClient{
clientID: clientID,
clientSecret: clientSecret,
creds: creds,
}
}
func (a *APIClient) GetAuthStartURL(redirectURL string) string {
a.validStateToken = uuid.Must(uuid.NewV4()).String()
a.validStateTokenExpiry = time.Now().Add(stateExpiry)
return a.getOauth2Config(redirectURL).AuthCodeURL(a.validStateToken, oauth2.AccessTypeOffline)
}
func (a *APIClient) StoreTokenFromRequest(redirectURL string, r *http.Request) error {
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
state := r.FormValue("state")
if state != a.validStateToken || a.validStateTokenExpiry.Before(time.Now()) {
return errors.New("invalid or expired state")
}
code := r.FormValue("code")
tok, err := a.getOauth2Config(redirectURL).Exchange(ctx, code, oauth2.AccessTypeOffline)
if err != nil {
return errors.Wrap(err, "exchanging code for token")
}
return errors.Wrap(a.creds.UpdateToken(tok.AccessToken, tok.RefreshToken, tok.Expiry), "updating stored token")
}
func (a APIClient) getOauth2Config(redirectURL string) *oauth2.Config {
return &oauth2.Config{
ClientID: a.clientID,
ClientSecret: a.clientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: oAuthEndpointAuth,
TokenURL: oAuthEndpointToken,
AuthStyle: oauth2.AuthStyleInHeader,
},
RedirectURL: redirectURL,
Scopes: []string{
oAuthScopeOfflineAccess,
oAuthScopeOpenID,
oAuthScopePayAsYouDrive,
oAuthScopeVehicleFuelStatus,
oAuthScopeVehicleLockStatus,
oAuthScopeVehicleStatus,
},
}
}
func (a APIClient) parseGenericAPIResponse(data io.Reader, output any) (err error) {
var tmp genericAPIResponse
if err = json.NewDecoder(data).Decode(&tmp); err != nil {
return errors.Wrap(err, "parsing JSON response")
}
st := reflect.ValueOf(output).Elem()
for i := 0; i < st.NumField(); i++ {
valField := st.Field(i)
typeField := st.Type().Field(i)
name := typeField.Tag.Get("apiField")
if name == "" {
continue
}
value := tmp.Get(name)
if value == nil {
continue
}
value.Timestamp = value.Timestamp * 1000000
switch typeField.Type {
case reflect.TypeOf(TimedBool{}):
fv := TimedBool{}
if fv.v, err = strconv.ParseBool(value.Value); err != nil {
return errors.Wrapf(err, "parsing value for %s", name)
}
fv.t = time.Unix(0, value.Timestamp)
valField.Set(reflect.ValueOf(fv))
case reflect.TypeOf(TimedEnum{}):
fv := TimedEnum{}
if fv.v, err = strconv.ParseInt(value.Value, 10, 64); err != nil {
return errors.Wrapf(err, "parsing value for %s", name)
}
fv.t = time.Unix(0, value.Timestamp)
fv.def = strings.Split(typeField.Tag.Get("values"), ",")
valField.Set(reflect.ValueOf(fv))
case reflect.TypeOf(TimedFloat{}):
fv := TimedFloat{}
if fv.v, err = strconv.ParseFloat(value.Value, 64); err != nil {
return errors.Wrapf(err, "parsing value for %s", name)
}
fv.t = time.Unix(0, value.Timestamp)
valField.Set(reflect.ValueOf(fv))
case reflect.TypeOf(TimedInt{}):
fv := TimedInt{}
if fv.v, err = strconv.ParseInt(value.Value, 10, 64); err != nil {
return errors.Wrapf(err, "parsing value for %s", name)
}
fv.t = time.Unix(0, value.Timestamp)
valField.Set(reflect.ValueOf(fv))
default:
return errors.Errorf("unknown type %v", typeField.Type)
}
}
return nil
}
func (a APIClient) request(path string, output any) error {
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
url := strings.Join([]string{
strings.TrimRight(apiPrefix, "/"),
strings.TrimLeft(path, "/"),
}, "/")
logrus.WithField("url", url).Trace("creating request")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return errors.Wrap(err, "creating request")
}
req.Header.Set("accept", "application/json;charset=utf-8")
at, rt, exp, err := a.creds.GetToken()
if err != nil {
return errors.Wrap(err, "getting credentials")
}
tok := &oauth2.Token{AccessToken: at, RefreshToken: rt, Expiry: exp}
// Renew token if required
if tok.Expiry.Before(time.Now()) {
src := a.getOauth2Config("").TokenSource(ctx, tok)
if tok, err = src.Token(); err != nil {
return errors.Wrap(err, "renewing token")
}
if tok.AccessToken != at || tok.RefreshToken != rt || !tok.Expiry.Equal(exp) {
if err := a.creds.UpdateToken(tok.AccessToken, tok.RefreshToken, tok.Expiry); err != nil {
return errors.Wrap(err, "updating stored token")
}
}
}
resp, err := a.getOauth2Config("").Client(ctx, tok).Do(req)
if err != nil {
return errors.Wrap(err, "executing request")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "http status code %d, error reading body", resp.StatusCode)
}
return errors.Errorf("http status code %d, body %s", resp.StatusCode, body)
}
if output == nil {
return nil
}
return errors.Wrap(a.parseGenericAPIResponse(resp.Body, output), "decoding output")
}