Skip to content

Commit 062d7aa

Browse files
committed
v1 add time.Time support
1 parent 56ec9a6 commit 062d7aa

File tree

4 files changed

+167
-39
lines changed

4 files changed

+167
-39
lines changed

readers_priors.go

+83-35
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,63 @@
1717
package rtl
1818

1919
import (
20+
"encoding"
21+
"errors"
2022
"fmt"
2123
"reflect"
2224
)
2325

2426
type headerValueReader func(th TypeHeader, length int, vr ValueReader, value reflect.Value, nesting int) error
2527

2628
var (
27-
_priorStructReaders = map[reflect.Type]headerValueReader{
28-
typeOfBigInt: bigIntReader0,
29-
typeOfBigRat: bigRatReader0,
30-
typeOfBigFloat: bigFloatReader0,
29+
_priorStructReaders = map[reflect.Type]map[TypeHeader]typeReaderFunc{
30+
typeOfBigInt: bigIntReaders,
31+
typeOfBigRat: bigRatReaders,
32+
typeOfBigFloat: bigFloatReaders,
33+
typeOfTime: binaryUnmarshalerReaders,
34+
}
35+
36+
binaryUnmarshalerReaders = map[TypeHeader]typeReaderFunc{
37+
THSingleByte: func(length int, vr ValueReader, value reflect.Value, nesting int) error {
38+
return setToBinaryUnmarshaler(value, []byte{byte(length)})
39+
},
40+
THZeroValue: func(length int, vr ValueReader, value reflect.Value, nesting int) error {
41+
value.Set(reflect.Zero(value.Type()))
42+
return nil
43+
},
44+
THStringSingle: func(length int, vr ValueReader, value reflect.Value, nesting int) error {
45+
buf, err := vr.ReadBytes(length, nil)
46+
if err != nil {
47+
return err
48+
}
49+
return setToBinaryUnmarshaler(value, buf)
50+
},
51+
THStringMulti: func(length int, vr ValueReader, value reflect.Value, nesting int) error {
52+
l, err := vr.ReadMultiLength(length)
53+
if err != nil {
54+
return err
55+
}
56+
buf, err := vr.ReadBytes(int(l), nil)
57+
if err != nil {
58+
return err
59+
}
60+
return setToBinaryUnmarshaler(value, buf)
61+
},
3162
}
3263
)
3364

3465
func checkPriorStructsReader(th TypeHeader, length int, vr ValueReader, value reflect.Value, nesting int) (matched bool, err error) {
3566
typ := value.Type()
3667
for _, prior := range _priorStructOrder {
3768
if typ.AssignableTo(prior) || typ.AssignableTo(reflect.PtrTo(prior)) {
38-
fn, exist := _priorStructReaders[prior]
69+
readers, exist := _priorStructReaders[prior]
3970
if exist {
40-
err = fn(th, length, vr, value, nesting)
71+
fn := getFunc(typ, readers, th)
72+
if typ.AssignableTo(prior) {
73+
err = fn(length, vr, value.Addr(), nesting)
74+
} else {
75+
err = fn(length, vr, value, nesting)
76+
}
4177
return true, err
4278
}
4379
}
@@ -62,36 +98,48 @@ func bigIntReader0(th TypeHeader, length int, vr ValueReader, value reflect.Valu
6298
return fmt.Errorf("rtl: should be big.Int or *big.Int, but %s", typ.Name())
6399
}
64100

65-
func bigRatReader0(th TypeHeader, length int, vr ValueReader, value reflect.Value, nesting int) error {
66-
typ := value.Type()
67-
68-
// big.Rat
69-
if typ.AssignableTo(typeOfBigRat) {
70-
f := getFunc(typ, bigRatReaders, th)
71-
return f(length, vr, value.Addr(), nesting)
72-
}
73-
// *big.Rat
74-
if typ.AssignableTo(reflect.PtrTo(typeOfBigRat)) {
75-
f := getFunc(typ, bigRatReaders, th)
76-
return f(length, vr, value, nesting)
77-
}
78-
79-
return fmt.Errorf("rtl: should be big.Rat or *big.Rat, but %s", typ.Name())
80-
}
81-
82-
func bigFloatReader0(th TypeHeader, length int, vr ValueReader, value reflect.Value, nesting int) error {
83-
typ := value.Type()
101+
// func bigRatReader0(th TypeHeader, length int, vr ValueReader, value reflect.Value, nesting int) error {
102+
// typ := value.Type()
103+
//
104+
// // big.Rat
105+
// if typ.AssignableTo(typeOfBigRat) {
106+
// f := getFunc(typ, bigRatReaders, th)
107+
// return f(length, vr, value.Addr(), nesting)
108+
// }
109+
// // *big.Rat
110+
// if typ.AssignableTo(reflect.PtrTo(typeOfBigRat)) {
111+
// f := getFunc(typ, bigRatReaders, th)
112+
// return f(length, vr, value, nesting)
113+
// }
114+
//
115+
// return fmt.Errorf("rtl: should be big.Rat or *big.Rat, but %s", typ.Name())
116+
// }
117+
//
118+
// func bigFloatReader0(th TypeHeader, length int, vr ValueReader, value reflect.Value, nesting int) error {
119+
// typ := value.Type()
120+
//
121+
// // big.Float
122+
// if typ.AssignableTo(typeOfBigFloat) {
123+
// f := getFunc(typ, bigFloatReaders, th)
124+
// return f(length, vr, value.Addr(), nesting)
125+
// }
126+
// // *big.Float
127+
// if typ.AssignableTo(reflect.PtrTo(typeOfBigFloat)) {
128+
// f := getFunc(typ, bigFloatReaders, th)
129+
// return f(length, vr, value, nesting)
130+
// }
131+
//
132+
// return fmt.Errorf("rtl: should be big.Float or *big.Float, but %s", typ.Name())
133+
// }
84134

85-
// big.Float
86-
if typ.AssignableTo(typeOfBigFloat) {
87-
f := getFunc(typ, bigFloatReaders, th)
88-
return f(length, vr, value.Addr(), nesting)
135+
// value must be a pointer of a type, and implemented encoding.BinaryUnmarshaler
136+
func setToBinaryUnmarshaler(value reflect.Value, bs []byte) error {
137+
if value.Kind() != reflect.Pointer {
138+
return errors.New("rtl: BinaryUnmarshaler need a pointer")
89139
}
90-
// *big.Float
91-
if typ.AssignableTo(reflect.PtrTo(typeOfBigFloat)) {
92-
f := getFunc(typ, bigFloatReaders, th)
93-
return f(length, vr, value, nesting)
140+
if value.IsNil() {
141+
value.Set(reflect.New(value.Type().Elem()))
94142
}
95-
96-
return fmt.Errorf("rtl: should be big.Float or *big.Float, but %s", typ.Name())
143+
bu := value.Interface().(encoding.BinaryUnmarshaler)
144+
return bu.UnmarshalBinary(bs)
97145
}

types.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"strconv"
2727
"strings"
2828
"sync"
29+
"time"
2930
)
3031

3132
type THValueType byte
@@ -187,6 +188,8 @@ var (
187188
typeOfBigRat = reflect.TypeOf(big.Rat{})
188189
typeOfBigFloat = reflect.TypeOf(big.Float{})
189190

191+
typeOfTime = reflect.TypeOf((*time.Time)(nil)).Elem()
192+
190193
// []interface{} type
191194
typeOfInterfaceSlice = reflect.TypeOf([]interface{}{})
192195
typeOfInterface = reflect.TypeOf((*interface{})(nil)).Elem()
@@ -325,8 +328,8 @@ func (h headMaker) numeric(isNegative bool, length int) ([]byte, error) {
325328

326329
if length <= 8 {
327330
r := make([]byte, 1)
328-
h.numericBuf(isNegative, length, r)
329-
return r, nil
331+
_, err := h.numericBuf(isNegative, length, r)
332+
return r, err
330333
}
331334

332335
r := make([]byte, 9)

writers_priors.go

+12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package rtl
1818

1919
import (
20+
"encoding"
2021
"encoding/gob"
2122
"io"
2223
"math/big"
@@ -28,12 +29,14 @@ var (
2829
typeOfBigInt: bigIntWriter,
2930
typeOfBigRat: gobEncoderNumberWriter,
3031
typeOfBigFloat: gobEncoderNumberWriter,
32+
typeOfTime: binaryMarshalerBytesWriter,
3133
}
3234

3335
_priorStructOrder = []reflect.Type{
3436
typeOfBigInt,
3537
typeOfBigRat,
3638
typeOfBigFloat,
39+
typeOfTime,
3740
}
3841
)
3942

@@ -85,3 +88,12 @@ func gobEncoderNumberWriter(w io.Writer, v reflect.Value) (int, error) {
8588
}
8689
return _writeNumberBytes(w, false, b)
8790
}
91+
92+
func binaryMarshalerBytesWriter(w io.Writer, v reflect.Value) (int, error) {
93+
bm := v.Interface().(encoding.BinaryMarshaler)
94+
b, err := bm.MarshalBinary()
95+
if err != nil {
96+
return 0, err
97+
}
98+
return bytesWriter(w, b)
99+
}

writers_test.go

+67-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"reflect"
2626
"strings"
2727
"testing"
28+
"time"
2829
)
2930

3031
func Unhex(str string) []byte {
@@ -119,7 +120,8 @@ type hasIgnoredField struct {
119120
}
120121

121122
type param struct {
122-
val interface{}
123+
val interface{}
124+
equaler func(a interface{}, b reflect.Value) bool
123125
}
124126

125127
type mapstruct struct {
@@ -143,11 +145,21 @@ type arrayAndSlice struct {
143145
B []byte
144146
}
145147

148+
type timeObjects struct {
149+
A time.Time
150+
B *time.Time
151+
}
152+
146153
var (
147154
string1 = "string1"
148155
string2 = "string2"
149156
)
150157

158+
func timePointer() *time.Time {
159+
n := time.Now()
160+
return &n
161+
}
162+
151163
var encTests = []param{
152164

153165
{val: float32(111.3)},
@@ -338,6 +350,31 @@ var encTests = []param{
338350

339351
{val: &arrayAndSlice{A: [32]byte{8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}, B: []byte("sssssssss")}},
340352
{val: &arrayAndSlice{A: [32]byte{8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}, B: nil}},
353+
354+
{val: time.Now()},
355+
{val: timePointer()},
356+
357+
{val: &timeObjects{A: time.Now(), B: timePointer()},
358+
equaler: func(a interface{}, b reflect.Value) bool {
359+
p := a.(*timeObjects)
360+
if b.Kind() != reflect.Pointer || b.IsNil() {
361+
return false
362+
}
363+
o := b.Interface().(*timeObjects)
364+
return p.A.Equal(o.A) && o.B != nil && (*p.B).Equal(*o.B)
365+
}},
366+
{val: &timeObjects{A: time.Now()},
367+
equaler: func(a interface{}, b reflect.Value) bool {
368+
p := a.(*timeObjects)
369+
if b.Kind() != reflect.Pointer || b.IsNil() {
370+
return false
371+
}
372+
o := b.Interface().(*timeObjects)
373+
if o.B != nil {
374+
return false
375+
}
376+
return p.A.Equal(o.A)
377+
}},
341378
}
342379

343380
func TestEncode(t *testing.T) {
@@ -364,14 +401,42 @@ func TestEncode(t *testing.T) {
364401
// }
365402

366403
fmt.Printf("%v: %#v\n\t%X\n%v: %#v\n", typ, test.val, bs, nvv.Type(), nvv)
367-
if reflect.DeepEqual(test.val, nvv.Interface()) {
404+
equaler := _valueEqualer
405+
if test.equaler != nil {
406+
equaler = test.equaler
407+
}
408+
if equaler(test.val, nvv) {
368409
t.Log(test.val, "check")
369410
} else {
370411
t.Error(test.val, "error")
371412
}
372413
}
373414
}
374415

416+
func _valueEqualer(a interface{}, b reflect.Value) bool {
417+
switch s := a.(type) {
418+
case *time.Time:
419+
if b.Kind() != reflect.Pointer {
420+
return false
421+
}
422+
if s == nil && b.IsNil() {
423+
return true
424+
}
425+
if b.IsNil() {
426+
return false
427+
}
428+
return (*s).Equal(b.Elem().Interface().(time.Time))
429+
case time.Time:
430+
return s.Equal(b.Interface().(time.Time))
431+
default:
432+
if reflect.DeepEqual(a, b.Interface()) {
433+
return true
434+
} else {
435+
return false
436+
}
437+
}
438+
}
439+
375440
// string -> []int
376441
func TestStringArray(t *testing.T) {
377442
s := "this is a string"

0 commit comments

Comments
 (0)