package asyncassertion import ( "errors" "fmt" "reflect" "time" "github.com/onsi/gomega/types" ) type AsyncAssertionType uint const ( AsyncAssertionTypeEventually AsyncAssertionType = iota AsyncAssertionTypeConsistently ) type AsyncAssertion struct { asyncType AsyncAssertionType actualInput interface{} timeoutInterval time.Duration pollingInterval time.Duration fail types.GomegaFailHandler offset int } func New(asyncType AsyncAssertionType, actualInput interface{}, fail types.GomegaFailHandler, timeoutInterval time.Duration, pollingInterval time.Duration, offset int) *AsyncAssertion { actualType := reflect.TypeOf(actualInput) if actualType.Kind() == reflect.Func { if actualType.NumIn() != 0 || actualType.NumOut() == 0 { panic("Expected a function with no arguments and one or more return values.") } } return &AsyncAssertion{ asyncType: asyncType, actualInput: actualInput, fail: fail, timeoutInterval: timeoutInterval, pollingInterval: pollingInterval, offset: offset, } } func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool { return assertion.match(matcher, true, optionalDescription...) } func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool { return assertion.match(matcher, false, optionalDescription...) } func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string { switch len(optionalDescription) { case 0: return "" default: return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n" } } func (assertion *AsyncAssertion) actualInputIsAFunction() bool { actualType := reflect.TypeOf(assertion.actualInput) return actualType.Kind() == reflect.Func && actualType.NumIn() == 0 && actualType.NumOut() > 0 } func (assertion *AsyncAssertion) pollActual() (interface{}, error) { if assertion.actualInputIsAFunction() { values := reflect.ValueOf(assertion.actualInput).Call([]reflect.Value{}) extras := []interface{}{} for _, value := range values[1:] { extras = append(extras, value.Interface()) } success, message := vetExtras(extras) if !success { return nil, errors.New(message) } return values[0].Interface(), nil } return assertion.actualInput, nil } type oracleMatcher interface { MatchMayChangeInTheFuture(actual interface{}) bool } func (assertion *AsyncAssertion) matcherMayChange(matcher types.GomegaMatcher, value interface{}) bool { if assertion.actualInputIsAFunction() { return true } oracleMatcher, ok := matcher.(oracleMatcher) if !ok { return true } return oracleMatcher.MatchMayChangeInTheFuture(value) } func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool { timer := time.Now() timeout := time.After(assertion.timeoutInterval) description := assertion.buildDescription(optionalDescription...) var matches bool var err error mayChange := true value, err := assertion.pollActual() if err == nil { mayChange = assertion.matcherMayChange(matcher, value) matches, err = matcher.Match(value) } fail := func(preamble string) { errMsg := "" message := "" if err != nil { errMsg = "Error: " + err.Error() } else { if desiredMatch { message = matcher.FailureMessage(value) } else { message = matcher.NegatedFailureMessage(value) } } assertion.fail(fmt.Sprintf("%s after %.3fs.\n%s%s%s", preamble, time.Since(timer).Seconds(), description, message, errMsg), 3+assertion.offset) } if assertion.asyncType == AsyncAssertionTypeEventually { for { if err == nil && matches == desiredMatch { return true } if !mayChange { fail("No future change is possible. Bailing out early") return false } select { case <-time.After(assertion.pollingInterval): value, err = assertion.pollActual() if err == nil { mayChange = assertion.matcherMayChange(matcher, value) matches, err = matcher.Match(value) } case <-timeout: fail("Timed out") return false } } } else if assertion.asyncType == AsyncAssertionTypeConsistently { for { if !(err == nil && matches == desiredMatch) { fail("Failed") return false } if !mayChange { return true } select { case <-time.After(assertion.pollingInterval): value, err = assertion.pollActual() if err == nil { mayChange = assertion.matcherMayChange(matcher, value) matches, err = matcher.Match(value) } case <-timeout: return true } } } return false } func vetExtras(extras []interface{}) (bool, string) { for i, extra := range extras { if extra != nil { zeroValue := reflect.Zero(reflect.TypeOf(extra)).Interface() if !reflect.DeepEqual(zeroValue, extra) { message := fmt.Sprintf("Unexpected non-nil/non-zero extra argument at index %d:\n\t<%T>: %#v", i+1, extra, extra) return false, message } } } return true, "" }