forked from cosmos/cosmos-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
provider_desc.go
164 lines (137 loc) · 3.93 KB
/
provider_desc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package depinject
import (
"reflect"
"strings"
"unicode"
"github.com/cockroachdb/errors"
"golang.org/x/exp/slices"
)
// providerDescriptor defines a special provider type that is defined by
// reflection rather than by an ordinary Go function signature.
//
// NOTE(review): an older comment said it "should be passed as a value to the
// Provide function", e.g.
//
//	option.Provide(providerDescriptor{ ... })
//
// Since the type is unexported, that can presumably only happen inside this
// package — confirm against the Provide implementation.
//
// Within this file, descriptors are produced from plain functions by
// doExtractProviderDescriptor and then finalized by postProcessProvider.
type providerDescriptor struct {
	// Inputs defines the in parameter types to Fn. As built by
	// doExtractProviderDescriptor there is one entry per function parameter,
	// in declaration order, matching the argument slice Fn expects.
	Inputs []providerInput

	// Outputs defines the out parameter types to Fn. As built by
	// doExtractProviderDescriptor, a trailing result of type error is not
	// listed here; it is reported through Fn's error return instead.
	Outputs []providerOutput

	// Fn defines the provider function. It is called with argument values
	// corresponding to Inputs and returns values corresponding to Outputs,
	// or a nil slice and a non-nil error.
	Fn func([]reflect.Value) ([]reflect.Value, error)

	// Location defines the source code location to be used for this provider
	// in error messages.
	Location Location
}
// providerInput describes one input of a providerDescriptor. As extracted by
// doExtractProviderDescriptor, it corresponds to one parameter of the
// underlying function.
type providerInput struct {
	// Type is the reflected type of the input value.
	Type reflect.Type

	// Optional is left false by doExtractProviderDescriptor and set to true for
	// every input by extractInvokerDescriptor.
	// NOTE(review): presumably an optional input may be left unresolved (zero
	// value) by the container — confirm against the resolver code.
	Optional bool
}
// providerOutput describes one value produced by a providerDescriptor's Fn.
// doExtractProviderDescriptor records one per non-error result of the
// underlying function, in declaration order.
type providerOutput struct {
	// Type is the reflected type of the produced value.
	Type reflect.Type
}
// extractProviderDescriptor turns provider, which must be a function value,
// into a providerDescriptor: it extracts the raw descriptor via reflection and
// then finalizes it with postProcessProvider. Any error from either step is
// returned with a zero descriptor.
func extractProviderDescriptor(provider interface{}) (providerDescriptor, error) {
	desc, extractErr := doExtractProviderDescriptor(provider)
	if extractErr != nil {
		return providerDescriptor{}, extractErr
	}

	return postProcessProvider(desc)
}
// extractInvokerDescriptor behaves like extractProviderDescriptor, except that
// every input of the extracted descriptor is marked Optional before the
// descriptor is handed to postProcessProvider.
func extractInvokerDescriptor(provider interface{}) (providerDescriptor, error) {
	desc, extractErr := doExtractProviderDescriptor(provider)
	if extractErr != nil {
		return providerDescriptor{}, extractErr
	}

	// Invokers take all of their inputs as optional. The Inputs slice was
	// freshly built by doExtractProviderDescriptor, so it is safe to update
	// its elements in place.
	for idx := range desc.Inputs {
		desc.Inputs[idx].Optional = true
	}

	return postProcessProvider(desc)
}
// doExtractProviderDescriptor builds a providerDescriptor for ctr using
// reflection. It validates the function's shape and name, but it does not
// post-process the descriptor; callers run postProcessProvider for that.
//
// ctr must be a non-nil function value that:
//   - has a name whose part after the last '.' is non-empty and does not
//     start with a lowercase letter;
//   - has no '-' in that part (the error message attributes this to bound
//     instance methods);
//   - is not declared in a package whose import path has an "internal"
//     element;
//   - is not variadic;
//   - if it has a result of exactly type error, has it as the last result.
//
// The error result, if any, is omitted from Outputs. The returned Fn calls the
// function. When that error result is non-nil, Fn returns (nil, err).
// Otherwise Fn returns the non-error results in declaration order.
func doExtractProviderDescriptor(ctr interface{}) (providerDescriptor, error) {
	val := reflect.ValueOf(ctr)
	// reflect.ValueOf(nil) is the zero Value, whose Type method panics;
	// report an untyped nil as an ordinary error instead.
	if !val.IsValid() {
		return providerDescriptor{}, errors.Errorf("expected a Func type, got %v", ctr)
	}
	typ := val.Type()
	if typ.Kind() != reflect.Func {
		return providerDescriptor{}, errors.Errorf("expected a Func type, got %v", typ)
	}
	// A typed nil func has no code behind it: there is nothing for
	// LocationFromPC to locate and nothing for Fn to call.
	if val.IsNil() {
		return providerDescriptor{}, errors.Errorf("expected a non-nil function, got a nil %v", typ)
	}

	loc := LocationFromPC(val.Pointer()).(*location)

	// The part of the name after the last '.' (the whole name if there is no
	// '.'). An empty name, or one ending in '.', leaves this empty; reject it
	// here rather than panicking when reading its first rune below.
	lastNamePart := loc.name[strings.LastIndex(loc.name, ".")+1:]
	if lastNamePart == "" {
		return providerDescriptor{}, errors.Errorf("missing function name %s", loc)
	}
	if unicode.IsLower([]rune(lastNamePart)[0]) {
		return providerDescriptor{}, errors.Errorf("function must be exported: %s", loc)
	}
	if strings.Contains(lastNamePart, "-") {
		return providerDescriptor{}, errors.Errorf("function can't be used as a provider (it might be a bound instance method): %s", loc)
	}

	pkgParts := strings.Split(loc.pkg, "/")
	if slices.Contains(pkgParts, "internal") {
		return providerDescriptor{}, errors.Errorf("function must not be in an internal package: %s", loc)
	}

	if typ.IsVariadic() {
		return providerDescriptor{}, errors.Errorf("variadic function can't be used as a provider: %s", loc)
	}

	numIn := typ.NumIn()
	in := make([]providerInput, numIn)
	for i := 0; i < numIn; i++ {
		in[i] = providerInput{
			Type: typ.In(i),
		}
	}

	// errIdx is the index of the trailing error result, or -1 if there is
	// none. Only the exact type error is recognized; other types that merely
	// implement error are treated as ordinary outputs.
	errIdx := -1
	numOut := typ.NumOut()
	var out []providerOutput
	for i := 0; i < numOut; i++ {
		t := typ.Out(i)
		if t == errType {
			if i != numOut-1 {
				return providerDescriptor{}, errors.Errorf("output error parameter is not last parameter in function %s", loc)
			}
			errIdx = i
		} else {
			out = append(out, providerOutput{Type: t})
		}
	}

	return providerDescriptor{
		Inputs:  in,
		Outputs: out,
		Fn: func(values []reflect.Value) ([]reflect.Value, error) {
			res := val.Call(values)
			if errIdx < 0 {
				return res, nil
			}
			// A nil error interface value reports IsZero.
			if errVal := res[errIdx]; !errVal.IsZero() {
				return nil, errVal.Interface().(error)
			}
			// The error is the last result, so everything before it is output.
			return res[:errIdx], nil
		},
		Location: loc,
	}, nil
}
// errType is the reflect.Type of the built-in error interface. Provider
// functions may return a value of exactly this type as their final result.
var errType = reflect.TypeOf(new(error)).Elem()
// postProcessProvider finalizes a descriptor produced by
// doExtractProviderDescriptor. It first passes the descriptor through
// expandStructArgsProvider, then uses checkInputAndOutputTypes to verify that
// every input and output type of the result is accepted by isExportedType.
//
// On any error the zero providerDescriptor is returned, matching the other
// constructors in this file, so callers never receive a half-validated
// descriptor alongside an error.
func postProcessProvider(descriptor providerDescriptor) (providerDescriptor, error) {
	expanded, err := expandStructArgsProvider(descriptor)
	if err != nil {
		return providerDescriptor{}, err
	}

	err = checkInputAndOutputTypes(expanded)
	if err != nil {
		return providerDescriptor{}, err
	}

	return expanded, nil
}
// checkInputAndOutputTypes runs isExportedType over every input type of
// descriptor (in order) and then every output type (in order). It returns the
// first error reported, or nil if every type passes.
func checkInputAndOutputTypes(descriptor providerDescriptor) error {
	for idx := range descriptor.Inputs {
		if checkErr := isExportedType(descriptor.Inputs[idx].Type); checkErr != nil {
			return checkErr
		}
	}

	for idx := range descriptor.Outputs {
		if checkErr := isExportedType(descriptor.Outputs[idx].Type); checkErr != nil {
			return checkErr
		}
	}

	return nil
}