From f4e07b507ab399277a7942f15fb917d3f11d7b8e Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Mon, 17 Sep 2018 23:21:36 +0200 Subject: [PATCH] Add support for time.Time flags Signed-off-by: Knut Ahlers --- config.go | 110 +++++++++++++++++++++++++++++++++++++++++++++------ time_test.go | 51 ++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 13 deletions(-) create mode 100644 time_test.go diff --git a/config.go b/config.go index eedde00..7df4554 100644 --- a/config.go +++ b/config.go @@ -16,10 +16,25 @@ import ( validator "gopkg.in/validator.v2" ) +type afterFunc func() error + var ( autoEnv bool fs *pflag.FlagSet variableDefaults map[string]string + + timeParserFormats = []string{ + // Default constants + time.RFC3339Nano, time.RFC3339, + time.RFC1123Z, time.RFC1123, + time.RFC822Z, time.RFC822, + time.RFC850, time.RubyDate, time.UnixDate, time.ANSIC, + "2006-01-02 15:04:05.999999999 -0700 MST", + // More uncommon time formats + "2006-01-02 15:04:05", "2006-01-02 15:04:05Z07:00", // Simplified ISO time format + "01/02/2006 15:04:05", "01/02/2006 15:04:05Z07:00", // US time format + "02.01.2006 15:04:05", "02.01.2006 15:04:05Z07:00", // DE time format + } ) func init() { @@ -61,6 +76,11 @@ func Args() []string { return fs.Args() } +// AddTimeParserFormats adds custom formats to parse time.Time fields +func AddTimeParserFormats(f ...string) { + timeParserFormats = append(timeParserFormats, f...) +} + // AutoEnv enables or disables automated env variable guessing. If no `env` struct // tag was set and AutoEnv is enabled the env variable name is derived from the // name of the field: `MyFieldName` will get `MY_FIELD_NAME` @@ -97,22 +117,37 @@ func parse(in interface{}, args []string) error { } fs = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError) - if err := execTags(in, fs); err != nil { + afterFuncs, err := execTags(in, fs) + if err != nil { return err } - return fs.Parse(args) + if err := fs.Parse(args); err != nil { + return err + } + + if afterFuncs != nil { + for _, f := range afterFuncs { + if err := f(); err != nil { + return err + } + } + } + + return nil } -func execTags(in interface{}, fs *pflag.FlagSet) error { +func execTags(in interface{}, fs *pflag.FlagSet) ([]afterFunc, error) { if reflect.TypeOf(in).Kind() != reflect.Ptr { - return errors.New("Calling parser with non-pointer") + return nil, errors.New("Calling parser with non-pointer") } if reflect.ValueOf(in).Elem().Kind() != reflect.Struct { - return errors.New("Calling parser with pointer to non-struct") + return nil, errors.New("Calling parser with pointer to non-struct") } + afterFuncs := []afterFunc{} + st := reflect.ValueOf(in).Elem() for i := 0; i < st.NumField(); i++ { valField := st.Field(i) @@ -134,7 +169,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { if value == "" { v = time.Duration(0) } else { - return err + return nil, err } } @@ -148,6 +183,53 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { valField.Set(reflect.ValueOf(v)) } continue + + case reflect.TypeOf(time.Time{}): + var sVar string + + if typeField.Tag.Get("flag") != "" { + if len(parts) == 1 { + fs.StringVar(&sVar, parts[0], value, typeField.Tag.Get("description")) + } else { + fs.StringVarP(&sVar, parts[0], parts[1], value, typeField.Tag.Get("description")) + } + } else { + sVar = value + } + + afterFuncs = append(afterFuncs, func(valField reflect.Value, sVar *string) func() error { + return func() error { + if *sVar == "" { + // No time, no problem + return nil + } + + // Check whether we could have a timestamp + if ts, err := strconv.ParseInt(*sVar, 10, 64); err == nil { + t := time.Unix(ts, 0) + valField.Set(reflect.ValueOf(t)) + return nil + } + + // We haven't so lets walk through possible time formats + matched := false + for _, tf := range timeParserFormats { + if t, err := time.Parse(tf, *sVar); err == nil { + matched = true + valField.Set(reflect.ValueOf(t)) + return nil + } + } + + if !matched { + return fmt.Errorf("Value %q did not match expected time formats", *sVar) + } + + return nil + } + }(valField, &sVar)) + + continue } switch typeField.Type.Kind() { @@ -180,7 +262,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { if value == "" { vt = 0 } else { - return err + return nil, err } } if typeField.Tag.Get("flag") != "" { @@ -195,7 +277,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { if value == "" { vt = 0 } else { - return err + return nil, err } } if typeField.Tag.Get("flag") != "" { @@ -210,7 +292,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { if value == "" { vt = 0.0 } else { - return err + return nil, err } } if typeField.Tag.Get("flag") != "" { @@ -220,9 +302,11 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { } case reflect.Struct: - if err := execTags(valField.Addr().Interface(), fs); err != nil { - return err + afs, err := execTags(valField.Addr().Interface(), fs) + if err != nil { + return nil, err } + afterFuncs = append(afterFuncs, afs...) case reflect.Slice: switch typeField.Type.Elem().Kind() { @@ -231,7 +315,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { for _, v := range strings.Split(value, ",") { it, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64) if err != nil { - return err + return nil, err } def = append(def, int(it)) } @@ -258,7 +342,7 @@ func execTags(in interface{}, fs *pflag.FlagSet) error { } } - return nil + return afterFuncs, nil } func registerFlagFloat(t reflect.Kind, fs *pflag.FlagSet, field interface{}, parts []string, vt float64, desc string) { diff --git a/time_test.go b/time_test.go new file mode 100644 index 0000000..2508a23 --- /dev/null +++ b/time_test.go @@ -0,0 +1,51 @@ +package rconfig + +import ( + "fmt" + "testing" + "time" +) + +func TestParseTime(t *testing.T) { + type ts struct { + Test time.Time `flag:"time"` + TestS time.Time `flag:"other-time,o"` + TestDef time.Time `default:"2006-01-02T15:04:05.999999999Z"` + TestDE time.Time `default:"18.09.2018 20:25:31"` + } + + var ( + err error + args []string + cfg ts + ) + + for _, tf := range timeParserFormats { + expect := time.Now().Format(tf) + + cfg = ts{} + args = []string{ + fmt.Sprintf("--time=%s", expect), + "-o", expect, + } + + if err = parse(&cfg, args); err != nil { + t.Fatalf("Time format %q did not parse: %s", tf, err) + } + + for name, ti := range map[string]time.Time{ + "Long flag": cfg.Test, + "Short flag": cfg.TestS, + "Default flag": cfg.TestDef, + "DE flag": cfg.TestDE, + } { + if ti.IsZero() { + t.Errorf("%s did parse to zero with format %q", name, tf) + } + } + + if e := cfg.Test.Format(tf); e != expect { + t.Errorf("Parsed time %q did not match expectation %q", e, expect) + } + } +}