Skip to content

Commit 087b655

Browse files
torkelrogstadboyan-soubachov
authored andcommitted
assert: allow comparing time.Time
1 parent 7bcf74e commit 087b655

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

assert/assertion_compare.go

+24
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package assert
33
import (
44
"fmt"
55
"reflect"
6+
"time"
67
)
78

89
type CompareType int
@@ -30,6 +31,8 @@ var (
3031
float64Type = reflect.TypeOf(float64(1))
3132

3233
stringType = reflect.TypeOf("")
34+
35+
timeType = reflect.TypeOf(time.Time{})
3336
)
3437

3538
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -299,6 +302,27 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
299302
return compareLess, true
300303
}
301304
}
305+
// Check for known struct types we can check for compare results.
306+
case reflect.Struct:
307+
{
308+
// All structs enter here. We're not interested in most types.
309+
if !obj1Value.CanConvert(timeType) {
310+
break
311+
}
312+
313+
// time.Time can compared!
314+
timeObj1, ok := obj1.(time.Time)
315+
if !ok {
316+
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
317+
}
318+
319+
timeObj2, ok := obj2.(time.Time)
320+
if !ok {
321+
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
322+
}
323+
324+
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
325+
}
302326
}
303327

304328
return compareEqual, false

assert/assertion_compare_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"reflect"
77
"runtime"
88
"testing"
9+
"time"
910
)
1011

1112
func TestCompare(t *testing.T) {
@@ -22,6 +23,7 @@ func TestCompare(t *testing.T) {
2223
type customFloat32 float32
2324
type customFloat64 float64
2425
type customString string
26+
type customTime time.Time
2527
for _, currCase := range []struct {
2628
less interface{}
2729
greater interface{}
@@ -52,14 +54,17 @@ func TestCompare(t *testing.T) {
5254
{less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"},
5355
{less: float64(1.23), greater: float64(2.34), cType: "float64"},
5456
{less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"},
57+
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
58+
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
5559
} {
5660
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
5761
if !isComparable {
5862
t.Error("object should be comparable for type " + currCase.cType)
5963
}
6064

6165
if resLess != compareLess {
62-
t.Errorf("object less should be less than greater for type " + currCase.cType)
66+
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
67+
currCase.less, currCase.greater)
6368
}
6469

6570
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())

0 commit comments

Comments
 (0)