Skip to content

Commit

Permalink
feat(pkg/scale): Use New() receiver function for construction of cu…
Browse files Browse the repository at this point in the history
…stom `VaryingDataType` (#3315)
  • Loading branch information
timwu20 authored Jun 13, 2023
1 parent 1a34972 commit 9688f6c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,22 @@ func (ds *decodeState) decodeVaryingDataTypeSlice(dstv reflect.Value) (err error

func (ds *decodeState) decodeCustomVaryingDataType(dstv reflect.Value) (err error) {
initialType := dstv.Type()

methodVal := dstv.MethodByName("New")
if methodVal.IsValid() && !methodVal.IsZero() {
if methodVal.Type().Out(0).String() != dstv.Type().String() {
return fmt.Errorf("%s.New() returns %s instead of %s", dstv.Type(), methodVal.Type().Out(0), dstv.Type())
}

values := methodVal.Call(nil)
if len(values) > 1 {
return fmt.Errorf("%s.New() returns too many values", dstv.Type())
} else if len(values) == 0 {
return fmt.Errorf("%s.New() does not return a value", dstv.Type())
}
dstv.Set(values[0])
}

converted := dstv.Convert(reflect.TypeOf(VaryingDataType{}))
tempVal := reflect.New(converted.Type())
tempVal.Elem().Set(converted)
Expand Down
39 changes: 39 additions & 0 deletions pkg/scale/varying_data_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ func mustNewVaryingDataTypeAndSet(value VaryingDataTypeValue, values ...VaryingD

type customVDT VaryingDataType

type customVDTWithNew VaryingDataType

func (cvwn customVDTWithNew) New() customVDTWithNew {
return customVDTWithNew(mustNewVaryingDataType(VDTValue{}, VDTValue1{}, VDTValue2{}, VDTValue3(0)))
}

type VDTValue struct {
A *big.Int
B int
Expand Down Expand Up @@ -406,6 +412,39 @@ func Test_decodeState_decodeCustomVaryingDataType(t *testing.T) {
}
}

func Test_decodeState_decodeCustomVaryingDataTypeWithNew(t *testing.T) {
for _, tt := range varyingDataTypeTests {
t.Run(tt.name, func(t *testing.T) {
dst := customVDTWithNew{}
if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
return
}

dstVDT := reflect.ValueOf(tt.in).Convert(reflect.TypeOf(VaryingDataType{})).Interface().(VaryingDataType)
inVDT := reflect.ValueOf(tt.in).Convert(reflect.TypeOf(VaryingDataType{})).Interface().(VaryingDataType)
dstVDTVal, err := dstVDT.Value()
if err != nil {
t.Errorf("%v", err)
return
}
inVDTVal, err := inVDT.Value()
if err != nil {
t.Errorf("%v", err)
return
}
diff := cmp.Diff(dstVDTVal, inVDTVal,
cmpopts.IgnoreUnexported(big.Int{}, VDTValue2{}, MyStructWithIgnore{}))
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)
}
if reflect.TypeOf(dst) != reflect.TypeOf(customVDTWithNew{}) {
t.Errorf("types mismatch dst: %v expected: %v", reflect.TypeOf(dst), reflect.TypeOf(customVDT{}))
}
})
}
}

func TestNewVaryingDataType(t *testing.T) {
type args struct {
values []VaryingDataTypeValue
Expand Down

0 comments on commit 9688f6c

Please sign in to comment.