Skip to content

Commit

Permalink
replace some Interface() check when encode private embed struct
Browse files Browse the repository at this point in the history
  • Loading branch information
kkHAIKE committed Jun 17, 2022
1 parent eaf0d98 commit 0ae83fe
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
"\x7f", `\u007f`,
)

var (
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
)

// Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML.
type Marshaler interface {
Expand Down Expand Up @@ -154,12 +160,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
// If we can marshal the type to text, then we use that. This prevents the
// encoder for handling these types as generic structs (or whatever the
// underlying type of a TextMarshaler is).
switch t := rv.Interface().(type) {
case encoding.TextMarshaler, Marshaler:
switch {
case isMarshaler(rv):
enc.writeKeyValue(key, rv, false)
return
case Primitive: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(t.undecoded))
case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
return
}

Expand Down Expand Up @@ -429,11 +435,19 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
rt = rv.Type()
fieldsDirect, fieldsSub [][]int
addFields func(rt reflect.Type, rv reflect.Value, start []int)
ptrto func(t reflect.Type) reflect.Type
)
ptrto = func(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return ptrto(t.Elem())
}
return t
}
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
isEmbed := f.Anonymous && ptrto(f.Type).Kind() == reflect.Struct
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
continue
}
opts := getOptions(f.Tag)
Expand All @@ -447,7 +461,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// not anonymous, like encoding/json does.
//
// Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous {
if isEmbed {
if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
addFields(frv.Type(), frv, append(start, f.Index...))
continue
Expand Down Expand Up @@ -531,7 +545,7 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
}

if rv.Kind() == reflect.Struct {
if _, ok := rv.Interface().(time.Time); ok {
if rv.Type() == timeType {
return tomlDatetime
}
if isMarshaler(rv) {
Expand Down Expand Up @@ -572,13 +586,8 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
}

func isMarshaler(rv reflect.Value) bool {
switch rv.Interface().(type) {
case encoding.TextMarshaler:
return true
case Marshaler:
return true
}
return false
return rv.Type().Implements(marshalText) ||
rv.Type().Implements(marshalToml)
}

// isTableArray reports if all entries in the array or slice are a table.
Expand Down

0 comments on commit 0ae83fe

Please sign in to comment.