-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
overload.go
424 lines (379 loc) · 11.6 KB
/
overload.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
// Copyright 2015 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
//
// Author: Nathan VanBenschoten (nvanbenschoten@gmail.com)
package parser
import "fmt"
// overloadImpl is an implementation of an overloaded function. It provides
// access to the parameter type list and the return type of the implementation.
type overloadImpl interface {
params() typeList
returnType() Datum
}
// typeList is a list of types representing a function parameter list.
type typeList interface {
match(types ArgTypes) bool
matchAt(typ Datum, i int) bool
matchLen(l int) bool
getAt(i int) Datum
}
// Compile-time checks that every parameter-list shape satisfies typeList.
var (
	_ typeList = ArgTypes{}
	_ typeList = AnyType{}
	_ typeList = VariadicType{}
	_ typeList = SingleType{}
)
// ArgTypes is a typeList implementation that accepts a specific number of
// arguments, with one declared type per position.
type ArgTypes []Datum

// match reports whether types has exactly len(a) entries and every entry is
// acceptable at its position according to matchAt.
func (a ArgTypes) match(types ArgTypes) bool {
	if len(a) != len(types) {
		return false
	}
	for idx, t := range types {
		if !a.matchAt(t, idx) {
			return false
		}
	}
	return true
}

// matchAt reports whether typ is acceptable at position i. Positions past the
// end of the list never match. A *DTuple argument is compared as dummyTuple
// (declared elsewhere in the package) instead of as itself.
func (a ArgTypes) matchAt(typ Datum, i int) bool {
	if i >= len(a) {
		return false
	}
	want := a[i]
	if _, isTuple := typ.(*DTuple); isTuple {
		return want.TypeEqual(dummyTuple)
	}
	return want.TypeEqual(typ)
}

// matchLen reports whether exactly l arguments are expected.
func (a ArgTypes) matchLen(l int) bool {
	return l == len(a)
}

// getAt returns the declared type at position i; it panics (index out of
// range) when i is not a valid position.
func (a ArgTypes) getAt(i int) Datum {
	return a[i]
}
// AnyType is a typeList implementation that accepts any arguments, in any
// number. It declares no per-position type, so getAt must never be called on
// it.
type AnyType struct{}

// match accepts every argument list.
func (AnyType) match(_ ArgTypes) bool { return true }

// matchAt accepts every type at every position.
func (AnyType) matchAt(_ Datum, _ int) bool { return true }

// matchLen accepts every argument count.
func (AnyType) matchLen(_ int) bool { return true }

// getAt always panics: AnyType has no declared parameter types.
func (AnyType) getAt(_ int) Datum {
	panic("getAt called on AnyType")
}
// VariadicType is a typeList implementation which accepts any number of
// arguments. Each argument matches when it is NULL or has type Typ.
type VariadicType struct {
	Typ Datum
}

// match reports whether every entry of types is NULL or of type v.Typ.
func (v VariadicType) match(types ArgTypes) bool {
	for pos, t := range types {
		if !v.matchAt(t, pos) {
			return false
		}
	}
	return true
}

// matchAt ignores the position: NULL always matches, otherwise typ must be
// type-equal to v.Typ.
func (v VariadicType) matchAt(typ Datum, _ int) bool {
	if typ == DNull {
		return true
	}
	return typ.TypeEqual(v.Typ)
}

// matchLen accepts every argument count.
func (v VariadicType) matchLen(_ int) bool {
	return true
}

// getAt returns v.Typ for every position.
func (v VariadicType) getAt(_ int) Datum {
	return v.Typ
}
// SingleType is a typeList implementation which accepts a single
// argument of type Typ. It is logically identical to an ArgTypes
// implementation with length 1, but avoids the slice allocation.
type SingleType struct {
	Typ Datum
}

// match reports whether types holds exactly one entry that matchAt accepts.
func (s SingleType) match(types ArgTypes) bool {
	if len(types) != 1 {
		return false
	}
	return s.matchAt(types[0], 0)
}

// matchAt reports whether typ is acceptable as the sole argument (i == 0).
// To honor the "identical to a length-1 ArgTypes" contract, it mirrors
// ArgTypes.matchAt exactly: a *DTuple argument is compared as dummyTuple, and
// the comparison is made on the declared type (s.Typ.TypeEqual(arg)).
func (s SingleType) matchAt(typ Datum, i int) bool {
	if i != 0 {
		return false
	}
	if _, ok := typ.(*DTuple); ok {
		typ = dummyTuple
	}
	return s.Typ.TypeEqual(typ)
}

// matchLen reports whether exactly one argument is supplied.
func (s SingleType) matchLen(l int) bool {
	return l == 1
}

// getAt returns s.Typ for position 0. Unlike ArgTypes.getAt it does not
// panic for other positions; it returns nil instead.
func (s SingleType) getAt(i int) Datum {
	if i != 0 {
		return nil
	}
	return s.Typ
}
// typeCheckOverloadedExprs determines the correct overload to use for the given set of
// expression parameters, along with an optional desired return type. It returns the expression
// parameters after being type checked, along with the chosen overloadImpl. If an overloaded function
// implementation could not be determined, the overloadImpl return value will be nil.
//
// Resolution proceeds roughly as follows:
//   - An AnyType overload (which must be the only overload) short-circuits
//     resolution: all exprs are type checked together via typeCheckSameTypedExprs.
//   - The exprs are partitioned into numeric constants, unresolved ValArgs, and
//     everything else ("resolvable" exprs).
//   - Candidates are filtered by argument count, by whether each constant can
//     become the parameter type, and by the types of the resolvable exprs.
//   - While more than one candidate remains, preference filters are applied:
//     the desired return type, then homogeneous / natural / best-mutual constant
//     types, and finally a homogeneous type for ValArgs.
//
// The overloads slice is copied before filtering, so the caller's slice is never
// reordered.
func typeCheckOverloadedExprs(args MapArgs, desired Datum, overloads []overloadImpl, exprs ...Expr) ([]TypedExpr, overloadImpl, error) {
	// Special-case the AnyType overload. We determine its return type by checking that
	// all parameters have the same type.
	for _, overload := range overloads {
		// Only one overload can be provided if it has parameters with AnyType.
		if _, ok := overload.params().(AnyType); ok {
			if len(overloads) > 1 {
				return nil, nil, fmt.Errorf("only one overload can have parameters with AnyType")
			}
			typedExprs, _, err := typeCheckSameTypedExprs(args, desired, exprs...)
			if err != nil {
				return nil, nil, err
			}
			return typedExprs, overload, nil
		}
	}

	// Hold the resolved type expressions of the provided exprs, in order.
	typedExprs := make([]TypedExpr, len(exprs))
	var resolvableExprs, constExprs, valExprs []indexedExpr
	for i, expr := range exprs {
		idxExpr := indexedExpr{e: expr, i: i}
		switch {
		case isNumericConstant(expr):
			constExprs = append(constExprs, idxExpr)
		case isUnresolvedVariable(args, expr):
			valExprs = append(valExprs, idxExpr)
		default:
			resolvableExprs = append(resolvableExprs, idxExpr)
		}
	}

	// defaultTypeCheck type checks the constant and ValArg expressions without a preference
	// and adds them to the type checked slice. When errorOnArgs is set, ValArgs are type
	// checked with no desired type so that a failure to infer their type is reported;
	// otherwise they are recorded as untyped DValArgs.
	defaultTypeCheck := func(errorOnArgs bool) error {
		for _, expr := range constExprs {
			typ, err := expr.e.TypeCheck(args, nil)
			if err != nil {
				return fmt.Errorf("error type checking constant value: %v", err)
			}
			typedExprs[expr.i] = typ
		}
		for _, expr := range valExprs {
			if errorOnArgs {
				// Stop only on an actual error. On success, record the typed result and
				// keep going so that no ValArg is left with a nil typedExprs entry.
				typ, err := expr.e.TypeCheck(args, nil)
				if err != nil {
					return err
				}
				typedExprs[expr.i] = typ
				continue
			}
			// If we don't want to error on args, avoid type checking them without a desired type.
			typedExprs[expr.i] = &DValArg{name: expr.e.(ValArg).name}
		}
		return nil
	}

	// If no overloads are provided, just type check parameters and return.
	if len(overloads) == 0 {
		for _, expr := range resolvableExprs {
			typ, err := expr.e.TypeCheck(args, nil)
			if err != nil {
				return nil, nil, fmt.Errorf("error type checking resolved expression: %v", err)
			}
			typedExprs[expr.i] = typ
		}
		if err := defaultTypeCheck(false); err != nil {
			return nil, nil, err
		}
		return typedExprs, nil, nil
	}

	// Work on a private copy: filterOverloads reorders elements in place, and the
	// caller's slice must not be permuted as a side effect of type checking.
	overloads = append([]overloadImpl(nil), overloads...)

	// filterOverloads removes every candidate for which fn returns false. A rejected
	// candidate is swapped with the last one and the slice is shortened, so the
	// relative order of the surviving candidates is not preserved. The rejected
	// candidates stay in the backing array past the new length, which is what lets
	// `overloads = before` restore the full candidate set below.
	filterOverloads := func(fn func(overloadImpl) bool) {
		for i := 0; i < len(overloads); {
			if fn(overloads[i]) {
				i++
			} else {
				overloads[i], overloads[len(overloads)-1] = overloads[len(overloads)-1], overloads[i]
				overloads = overloads[:len(overloads)-1]
			}
		}
	}

	// Filter out incorrect parameter length overloads.
	filterOverloads(func(o overloadImpl) bool {
		return o.params().matchLen(len(exprs))
	})

	// Filter out overloads which constants cannot become.
	for _, expr := range constExprs {
		constExpr := expr.e.(*NumVal)
		filterOverloads(func(o overloadImpl) bool {
			return canConstantBecome(constExpr, o.params().getAt(expr.i))
		})
	}

	// TODO(nvanbenschoten) We should add a filtering step here to filter
	// out impossible candidates based on identical parameters. For instance,
	// f(int, float) is not a possible candidate for the expression f($1, $1).

	// Filter out overloads on resolved types.
	for _, expr := range resolvableExprs {
		var paramDesired Datum
		if len(overloads) == 1 {
			// Once we get down to a single overload candidate, begin desiring its
			// parameter types for the corresponding argument expressions.
			paramDesired = overloads[0].params().getAt(expr.i)
		}
		typ, err := expr.e.TypeCheck(args, paramDesired)
		if err != nil {
			return nil, nil, err
		}
		typedExprs[expr.i] = typ
		filterOverloads(func(o overloadImpl) bool {
			return o.params().matchAt(typ.ReturnType(), expr.i)
		})
	}

	// checkReturn checks the number of remaining overloaded function implementations, returning
	// if we should stop overload resolution, along with a nullable overloadImpl to return if
	// we should stop overload resolution. With exactly one candidate left, the constants and
	// ValArgs are type checked against that candidate's parameter types.
	checkReturn := func() (bool, overloadImpl, error) {
		switch len(overloads) {
		case 0:
			if err := defaultTypeCheck(false); err != nil {
				return true, nil, err
			}
			return true, nil, nil
		case 1:
			o := overloads[0]
			p := o.params()
			for _, expr := range constExprs {
				des := p.getAt(expr.i)
				typ, err := expr.e.TypeCheck(args, des)
				if err != nil {
					return true, nil, fmt.Errorf("error type checking constant value: %v", err)
				} else if des != nil && !typ.ReturnType().TypeEqual(des) {
					// Invariant: earlier filtering guaranteed the constant can become des.
					panic(fmt.Errorf("desired constant value type %s but set type %s", des.Type(), typ.ReturnType().Type()))
				}
				typedExprs[expr.i] = typ
			}
			for _, expr := range valExprs {
				des := p.getAt(expr.i)
				typ, err := expr.e.TypeCheck(args, des)
				if err != nil {
					return true, nil, err
				}
				typedExprs[expr.i] = typ
			}
			return true, o, nil
		default:
			return false, nil, nil
		}
	}

	// At this point, all remaining candidate overload implementations are valid, so we
	// begin checking for a single remaining implementation while applying a few "preference"
	// filters if more than one candidate remains.
	if ok, fn, err := checkReturn(); ok {
		return typedExprs, fn, err
	}

	// Keep only overloads which return the desired type.
	if desired != nil {
		filterOverloads(func(o overloadImpl) bool {
			return o.returnType().TypeEqual(desired)
		})
		if ok, fn, err := checkReturn(); ok {
			return typedExprs, fn, err
		}
	}

	// homogeneousTyp is the shared type of all resolvable exprs, or nil if they differ
	// (or there are none).
	var homogeneousTyp Datum
	if len(resolvableExprs) > 0 {
		homogeneousTyp = typedExprs[resolvableExprs[0].i].ReturnType()
		for _, resExprs := range resolvableExprs[1:] {
			if !homogeneousTyp.TypeEqual(typedExprs[resExprs.i].ReturnType()) {
				homogeneousTyp = nil
				break
			}
		}
	}

	var bestConstType Datum
	if len(constExprs) > 0 {
		numVals := make([]*NumVal, len(constExprs))
		for i, expr := range constExprs {
			numVals[i] = expr.e.(*NumVal)
		}
		before := overloads

		// Check if all constants can become the homogeneous type.
		if homogeneousTyp != nil {
			all := true
			for _, constExpr := range numVals {
				if !canConstantBecome(constExpr, homogeneousTyp) {
					all = false
					break
				}
			}
			if all {
				for _, expr := range constExprs {
					filterOverloads(func(o overloadImpl) bool {
						return o.params().getAt(expr.i).TypeEqual(homogeneousTyp)
					})
				}
			}
		}
		if len(overloads) == 1 {
			if ok, fn, err := checkReturn(); ok {
				return typedExprs, fn, err
			}
		}
		// Restore the expressions if this did not work.
		overloads = before

		// Check if an overload fits with the natural constant types.
		for i, expr := range constExprs {
			natural := naturalConstantType(numVals[i])
			if natural != nil {
				filterOverloads(func(o overloadImpl) bool {
					return o.params().getAt(expr.i).TypeEqual(natural)
				})
			}
		}
		if len(overloads) == 1 {
			if ok, fn, err := checkReturn(); ok {
				return typedExprs, fn, err
			}
		}
		// Restore the expressions if this did not work.
		overloads = before

		// Check if an overload fits with the "best" mutual constant types.
		bestConstType = commonNumericConstantType(numVals...)
		for _, expr := range constExprs {
			filterOverloads(func(o overloadImpl) bool {
				return o.params().getAt(expr.i).TypeEqual(bestConstType)
			})
		}
		if ok, fn, err := checkReturn(); ok {
			return typedExprs, fn, err
		}
		if homogeneousTyp != nil {
			if !homogeneousTyp.TypeEqual(bestConstType) {
				homogeneousTyp = nil
			}
		} else {
			homogeneousTyp = bestConstType
		}
	}

	// If all other parameters are homogeneous, we favor this type for ValArgs.
	if homogeneousTyp != nil && len(valExprs) > 0 {
		for _, expr := range valExprs {
			filterOverloads(func(o overloadImpl) bool {
				return o.params().getAt(expr.i).TypeEqual(homogeneousTyp)
			})
		}
		if ok, fn, err := checkReturn(); ok {
			return typedExprs, fn, err
		}
	}

	// Still ambiguous (or nothing left): when candidates remain, type checking the
	// ValArgs without a desired type surfaces an error for them.
	if err := defaultTypeCheck(len(overloads) > 0); err != nil {
		return nil, nil, err
	}
	return typedExprs, nil, nil
}