Skip to content

Commit

Permalink
Merge pull request #6358 from nvanbenschoten/nvanbenschoten/typeCheck
Browse files Browse the repository at this point in the history
sql: Implement Summer, a smarter typing system
  • Loading branch information
nvanbenschoten committed May 3, 2016
2 parents 8219c82 + ad3d609 commit b8f272e
Show file tree
Hide file tree
Showing 56 changed files with 4,751 additions and 2,684 deletions.
5 changes: 5 additions & 0 deletions sql/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ func simplifyExpr(e parser.Expr) (simplified parser.Expr, equivalent bool) {
return parser.MakeDBool(true), false
}

func simplifyTypedExpr(e parser.TypedExpr) (simplified parser.TypedExpr, equivalent bool) {
expr, eq := simplifyExpr(e)
return expr.(parser.TypedExpr), eq
}

func simplifyNotExpr(n *parser.NotExpr) (parser.Expr, bool) {
switch t := n.Expr.(type) {
case *parser.ComparisonExpr:
Expand Down
33 changes: 16 additions & 17 deletions sql/analyze_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,33 @@ func testTableDesc() *TableDescriptor {
}
}

func parseAndNormalizeExpr(t *testing.T, sql string) (parser.Expr, qvalMap) {
func parseAndNormalizeExpr(t *testing.T, sql string) (parser.TypedExpr, qvalMap) {
expr, err := parser.ParseExprTraditional(sql)
if err != nil {
t.Fatalf("%s: %v", sql, err)
}
expr, err = (parser.EvalContext{}).NormalizeExpr(expr)
if err != nil {
t.Fatalf("%s: %v", sql, err)
}

// Perform qualified name resolution because {analyze,simplify}Expr want
// expressions containing qvalues.
desc := testTableDesc()
sel := testInitDummySelectNode(desc)
if err := desc.AllocateIDs(); err != nil {
if err = desc.AllocateIDs(); err != nil {
t.Fatal(err)
}
expr, nErr := sel.resolveQNames(expr)
if nErr != nil {
t.Fatalf("%s: %v", sql, nErr)
if expr, err = sel.resolveQNames(expr); err != nil {
t.Fatalf("%s: %v", sql, err)
}
typedExpr, err := parser.TypeCheck(expr, nil, nil)
if err != nil {
t.Fatalf("%s: %v", sql, err)
}
if _, err := parser.PerformTypeChecking(expr, nil); err != nil {
if typedExpr, err = (parser.EvalContext{}).NormalizeExpr(typedExpr); err != nil {
t.Fatalf("%s: %v", sql, err)
}
return expr, sel.qvals
return typedExpr, sel.qvals
}

func checkEquivExpr(a, b parser.Expr, qvals qvalMap) error {
func checkEquivExpr(a, b parser.TypedExpr, qvals qvalMap) error {
// The expressions above only use the values 1 and 2. Verify that the
// simplified expressions evaluate to the same value as the original
// expression for interesting values.
Expand Down Expand Up @@ -251,7 +250,7 @@ func TestSimplifyExpr(t *testing.T) {
}
for _, d := range testData {
expr, _ := parseAndNormalizeExpr(t, d.expr)
expr, equiv := simplifyExpr(expr)
expr, equiv := simplifyTypedExpr(expr)
if s := expr.String(); d.expected != s {
t.Errorf("%s: expected %s, but found %s", d.expr, d.expected, s)
}
Expand Down Expand Up @@ -287,7 +286,7 @@ func TestSimplifyNotExpr(t *testing.T) {
}
for _, d := range testData {
expr1, qvals := parseAndNormalizeExpr(t, d.expr)
expr2, equiv := simplifyExpr(expr1)
expr2, equiv := simplifyTypedExpr(expr1)
if s := expr2.String(); d.expected != s {
t.Errorf("%s: expected %s, but found %s", d.expr, d.expected, s)
}
Expand Down Expand Up @@ -421,7 +420,7 @@ func TestSimplifyAndExprCheck(t *testing.T) {
{`a > 1 AND a IS NOT NULL`, `a > 1`, true},
{`a IS NOT NULL AND a > 1`, `a > 1`, true},
{`a > 1.0 AND a = 2`, `a = 2`, true},
{`a > 1 AND a = 2.0`, `a = 2.0`, true},
{`a > 1 AND a = 2.1`, `a = 2.1`, true},

{`a >= 1 AND a = 1`, `a = 1`, true},
{`a >= 1 AND a = 2`, `a = 2`, true},
Expand Down Expand Up @@ -520,7 +519,7 @@ func TestSimplifyAndExprCheck(t *testing.T) {
}
for _, d := range testData {
expr1, qvals := parseAndNormalizeExpr(t, d.expr)
expr2, equiv := simplifyExpr(expr1)
expr2, equiv := simplifyTypedExpr(expr1)
if s := expr2.String(); d.expected != s {
t.Errorf("%s: expected %s, but found %s", d.expr, d.expected, s)
}
Expand Down Expand Up @@ -737,7 +736,7 @@ func TestSimplifyOrExprCheck(t *testing.T) {
}
for _, d := range testData {
expr1, qvals := parseAndNormalizeExpr(t, d.expr)
expr2, equiv := simplifyExpr(expr1)
expr2, equiv := simplifyTypedExpr(expr1)
if s := expr2.String(); d.expected != s {
t.Errorf("%s: expected %s, but found %s", d.expr, d.expected, s)
}
Expand Down
2 changes: 1 addition & 1 deletion sql/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (sc *SchemaChanger) truncateAndBacfillColumnsChunk(
added []ColumnDescriptor,
dropped []ColumnDescriptor,
nonNullableColumn string,
defaultExprs []parser.Expr,
defaultExprs []parser.TypedExpr,
evalCtx parser.EvalContext,
sp span,
) (roachpb.Key, bool, *roachpb.Error) {
Expand Down
8 changes: 4 additions & 4 deletions sql/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ type deleteNode struct {
// Privileges: DELETE and SELECT on table. We currently always use a SELECT statement.
// Notes: postgres requires DELETE. Also requires SELECT for "USING" and "WHERE" with tables.
// mysql requires DELETE. Also requires SELECT if a table is used in the "WHERE" clause.
func (p *planner) Delete(n *parser.Delete, autoCommit bool) (planNode, *roachpb.Error) {
en, pErr := p.makeEditNode(n.Table, n.Returning, autoCommit, privilege.DELETE)
func (p *planner) Delete(n *parser.Delete, desiredTypes []parser.Datum, autoCommit bool) (planNode, *roachpb.Error) {
en, pErr := p.makeEditNode(n.Table, n.Returning, desiredTypes, autoCommit, privilege.DELETE)
if pErr != nil {
return nil, pErr
}
Expand All @@ -59,7 +59,7 @@ func (p *planner) Delete(n *parser.Delete, autoCommit bool) (planNode, *roachpb.
Exprs: en.tableDesc.allColumnsSelector(),
From: []parser.TableExpr{n.Table},
Where: n.Where,
})
}, nil)
if pErr != nil {
return nil, pErr
}
Expand All @@ -80,7 +80,7 @@ func (d *deleteNode) Start() *roachpb.Error {
Exprs: d.tableDesc.allColumnsSelector(),
From: []parser.TableExpr{d.n.Table},
Where: d.n.Where,
})
}, nil)
if pErr != nil {
return pErr
}
Expand Down
8 changes: 5 additions & 3 deletions sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,9 @@ func (e *Executor) execStmtsInCurrentTxn(
txnState.schemaChangers.curStatementIdx = i

var stmtStrBefore string
if e.ctx.TestingKnobs.CheckStmtStringChange {
// TODO(nvanbenschoten) Constant literals can change their representation (1.0000 -> 1) when type checking,
// so we need to reconsider how this works.
if e.ctx.TestingKnobs.CheckStmtStringChange && false {
stmtStrBefore = stmt.String()
}
var res Result
Expand All @@ -595,7 +597,7 @@ func (e *Executor) execStmtsInCurrentTxn(
default:
panic(fmt.Sprintf("unexpected txn state: %s", txnState.State))
}
if e.ctx.TestingKnobs.CheckStmtStringChange {
if e.ctx.TestingKnobs.CheckStmtStringChange && false {
if after := stmt.String(); after != stmtStrBefore {
panic(fmt.Sprintf("statement changed after exec; before:\n %s\nafter:\n %s",
stmtStrBefore, after))
Expand Down Expand Up @@ -908,7 +910,7 @@ func (e *Executor) execStmt(
stmt parser.Statement, planMaker *planner, autoCommit bool,
) (Result, *roachpb.Error) {
var result Result
plan, pErr := planMaker.makePlan(stmt, autoCommit)
plan, pErr := planMaker.makePlan(stmt, nil, autoCommit)
if pErr != nil {
return result, pErr
}
Expand Down
2 changes: 1 addition & 1 deletion sql/explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (p *planner) Explain(n *parser.Explain, autoCommit bool) (planNode, *roachp
p.txn.Context = opentracing.ContextWithSpan(p.txn.Context, sp)
}

plan, err := p.makePlan(n.Statement, autoCommit)
plan, err := p.makePlan(n.Statement, nil, autoCommit)
if err != nil {
return nil, err
}
Expand Down
22 changes: 15 additions & 7 deletions sql/expr_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,30 @@ func splitBoolExpr(expr parser.Expr, conv varConvertFunc, weaker bool) (restrict
// - the implementation is best-effort (it tries to get as much of the expression into RES as
// possible, and make REM as small as possible).
// - the original expression is modified in-place and should not be used again.
func splitFilter(expr parser.Expr, conv varConvertFunc) (restricted, remainder parser.Expr) {
func splitFilter(
expr parser.TypedExpr, conv varConvertFunc,
) (restricted, remainder parser.TypedExpr) {
if expr == nil {
return nil, nil
}
restricted, remainder = splitBoolExpr(expr, conv, true)
if restricted == parser.DBoolTrue {
restricted = nil
res, rem := splitBoolExpr(expr, conv, true)
if res == parser.DBoolTrue {
res = nil
}
if remainder == parser.DBoolTrue {
remainder = nil
if res != nil {
restricted = res.(parser.TypedExpr)
}
if rem == parser.DBoolTrue {
rem = nil
}
if rem != nil {
remainder = rem.(parser.TypedExpr)
}
return restricted, remainder
}

// runFilter runs a filter expression and returs whether the filter passes.
func runFilter(filter parser.Expr, evalCtx parser.EvalContext) (bool, error) {
func runFilter(filter parser.TypedExpr, evalCtx parser.EvalContext) (bool, error) {
if filter == nil {
return true, nil
}
Expand Down
61 changes: 37 additions & 24 deletions sql/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ func (p *planner) groupBy(n *parser.SelectClause, s *selectNode) (*groupNode, *r

// We could potentially skip this, since it will be checked in addRender,
// but checking now allows early err return.
if _, err := parser.PerformTypeChecking(resolved, p.evalCtx.Args); err != nil {
typedExpr, err := parser.TypeCheck(resolved, p.evalCtx.Args, nil /* no preference */)
if err != nil {
return nil, roachpb.NewError(err)
}

norm, err := p.parser.NormalizeExpr(p.evalCtx, resolved)
norm, err := p.parser.NormalizeExpr(p.evalCtx, typedExpr)
if err != nil {
return nil, roachpb.NewError(err)
}
Expand All @@ -89,24 +90,27 @@ func (p *planner) groupBy(n *parser.SelectClause, s *selectNode) (*groupNode, *r
}

// Normalize and check the HAVING expression too if it exists.
var typedHaving parser.TypedExpr
if n.Having != nil {
having, err := s.resolveQNames(n.Having.Expr)
if err != nil {
return nil, roachpb.NewError(err)
}

havingType, err := parser.PerformTypeChecking(having, p.evalCtx.Args)
typedHaving, err = parser.TypeCheck(having, p.evalCtx.Args, parser.DummyBool)
if err != nil {
return nil, roachpb.NewError(err)
}
if !(havingType.TypeEqual(parser.DummyBool) || havingType == parser.DNull) {
return nil, roachpb.NewUErrorf("argument of HAVING must be type %s, not type %s", parser.DummyBool.Type(), havingType.Type())
if typ := typedHaving.ReturnType(); !(typ.TypeEqual(parser.DummyBool) || typ == parser.DNull) {
return nil, roachpb.NewUErrorf("argument of HAVING must be type %s, not type %s",
parser.DummyBool.Type(), typ.Type())
}

if having, err = p.parser.NormalizeExpr(p.evalCtx, having); err != nil {
typedHaving, err = p.parser.NormalizeExpr(p.evalCtx, typedHaving)
if err != nil {
return nil, roachpb.NewError(err)
}
n.Having.Expr = having
n.Having.Expr = typedHaving
}

group := &groupNode{
Expand Down Expand Up @@ -137,19 +141,20 @@ func (p *planner) groupBy(n *parser.SelectClause, s *selectNode) (*groupNode, *r
// After extraction, group.render will be entirely rendered from aggregateFuncs,
// and group.funcs will contain all the functions which need to be fed values.
for i := range group.render {
expr, err := visitor.extract(group.render[i])
typedExpr, err := visitor.extract(group.render[i])
if err != nil {
return nil, roachpb.NewError(err)
}
group.render[i] = expr
group.render[i] = typedExpr
}

if n.Having != nil {
having, err := visitor.extract(n.Having.Expr)
if typedHaving != nil {
var err error
typedHaving, err = visitor.extract(typedHaving)
if err != nil {
return nil, roachpb.NewError(err)
}
group.having = having
group.having = typedHaving
}

// Queries like `SELECT MAX(n) FROM t` expect a row of NULLs if nothing was aggregated.
Expand All @@ -167,14 +172,14 @@ func (p *planner) groupBy(n *parser.SelectClause, s *selectNode) (*groupNode, *r

// Replace the render expressions in the scanNode with expressions that
// compute only the arguments to the aggregate expressions.
s.render = make([]parser.Expr, len(group.funcs))
s.render = make([]parser.TypedExpr, len(group.funcs))
for i, f := range group.funcs {
s.render[i] = f.arg
}

// Add the group-by expressions so they are available for bucketing.
for _, g := range groupBy {
if err := s.addRender(parser.SelectExpr{Expr: g}); err != nil {
if err := s.addRender(parser.SelectExpr{Expr: g}, nil); err != nil {
return nil, err
}
}
Expand All @@ -191,8 +196,8 @@ type groupNode struct {
// The "wrapped" node (which returns ungrouped results).
plan planNode

render []parser.Expr
having parser.Expr
render []parser.TypedExpr
having parser.TypedExpr

funcs []*aggregateFunc
// The set of bucket keys.
Expand Down Expand Up @@ -477,7 +482,7 @@ func (v *extractAggregatesVisitor) VisitPre(expr parser.Expr) (recurse bool, new

f := &aggregateFunc{
expr: t,
arg: t.Exprs[0],
arg: t.Exprs[0].(parser.TypedExpr),
create: impl,
group: v.n,
buckets: make(map[string]aggregateImpl),
Expand Down Expand Up @@ -528,9 +533,12 @@ func (*extractAggregatesVisitor) VisitPost(expr parser.Expr) parser.Expr { retur
// - `k` appears in GROUP BY, so `UPPER(k)` is OK, but...
// Invalid: `SELECT k, SUM(v) FROM kv GROUP BY UPPER(k)`
// - `k` does not appear in GROUP BY; UPPER(k) does nothing to help here.
func (v extractAggregatesVisitor) extract(expr parser.Expr) (parser.Expr, error) {
expr, _ = parser.WalkExpr(&v, expr)
return expr, v.err
func (v extractAggregatesVisitor) extract(typedExpr parser.TypedExpr) (parser.TypedExpr, error) {
expr, _ := parser.WalkExpr(&v, typedExpr)
if v.err != nil {
return nil, v.err
}
return expr.(parser.TypedExpr), nil
}

var _ parser.Visitor = &isAggregateVisitor{}
Expand Down Expand Up @@ -585,11 +593,12 @@ func (p *planner) isAggregate(n *parser.SelectClause) bool {
return false
}

var _ parser.TypedExpr = &aggregateFunc{}
var _ parser.VariableExpr = &aggregateFunc{}

type aggregateFunc struct {
expr parser.Expr
arg parser.Expr
expr parser.TypedExpr
arg parser.TypedExpr
create func() aggregateImpl
group *groupNode
buckets map[string]aggregateImpl
Expand Down Expand Up @@ -629,8 +638,8 @@ func (a *aggregateFunc) String() string {

func (a *aggregateFunc) Walk(v parser.Visitor) parser.Expr { return a }

func (a *aggregateFunc) TypeCheck(args parser.MapArgs) (parser.Datum, error) {
return a.expr.TypeCheck(args)
func (a *aggregateFunc) TypeCheck(args parser.MapArgs, desired parser.Datum) (parser.TypedExpr, error) {
return a, nil
}

func (a *aggregateFunc) Eval(ctx parser.EvalContext) (parser.Datum, error) {
Expand All @@ -655,6 +664,10 @@ func (a *aggregateFunc) Eval(ctx parser.EvalContext) (parser.Datum, error) {
return datum.Eval(ctx)
}

func (a *aggregateFunc) ReturnType() parser.Datum {
return a.expr.ReturnType()
}

type aggregateImpl interface {
add(parser.Datum) error
result() (parser.Datum, error)
Expand Down
Loading

0 comments on commit b8f272e

Please sign in to comment.