diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 8c7a325c3c..62eb2fb02f 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -66,6 +66,9 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { case *ast.BetweenExpr: p.parent = node + case *ast.CallStmt: + p.parent = n.FuncCall + case *ast.FuncCall: p.parent = node diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 65a2d5853c..a01675645b 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -83,6 +83,8 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { if len(targets.Items) == 0 && n.Larg != nil { return outputColumns(qc, n.Larg) } + case *ast.CallStmt: + targets = &ast.List{} case *ast.TruncateStmt: targets = &ast.List{} case *ast.UpdateStmt: @@ -386,6 +388,8 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { list = &ast.List{ Items: append(n.FromClause.Items, n.Relations.Items...), } + case *ast.CallStmt: + list = &ast.List{} default: return nil, fmt.Errorf("sourceTables: unsupported node type: %T", n) } diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 9beaead85d..017a326797 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -48,6 +48,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } var table *ast.TableName switch n := raw.Stmt.(type) { + case *ast.CallStmt: case *ast.SelectStmt: case *ast.DeleteStmt: case *ast.InsertStmt: diff --git a/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/go/query.sql.go index c40d6f2561..7aee5bd63c 100644 --- a/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/go/query.sql.go @@ -7,11 +7,53 @@ import ( "context" ) -const placeholder = `-- name: Placeholder :exec -SELECT 1 +const callInsertData = `-- name: CallInsertData :exec +CALL insert_data($1, $2) ` -func (q *Queries) Placeholder(ctx context.Context) error { - _, err := q.db.Exec(ctx, placeholder) +type CallInsertDataParams struct { + A int32 + B int32 +} + +func (q *Queries) CallInsertData(ctx context.Context, arg CallInsertDataParams) error { + _, err := q.db.Exec(ctx, callInsertData, arg.A, arg.B) + return err +} + +const callInsertDataNamed = `-- name: CallInsertDataNamed :exec +CALL insert_data(b => $1, a => $2) +` + +type CallInsertDataNamedParams struct { + B int32 + A int32 +} + +func (q *Queries) CallInsertDataNamed(ctx context.Context, arg CallInsertDataNamedParams) error { + _, err := q.db.Exec(ctx, callInsertDataNamed, arg.B, arg.A) + return err +} + +const callInsertDataNoArgs = `-- name: CallInsertDataNoArgs :exec +CALL insert_data(1, 2) +` + +func (q *Queries) CallInsertDataNoArgs(ctx context.Context) error { + _, err := q.db.Exec(ctx, callInsertDataNoArgs) + return err +} + +const callInsertDataSqlcArgs = `-- name: CallInsertDataSqlcArgs :exec +CALL insert_data($1, $2) +` + +type CallInsertDataSqlcArgsParams struct { + Foo int32 + Bar int32 +} + +func (q *Queries) CallInsertDataSqlcArgs(ctx context.Context, arg CallInsertDataSqlcArgsParams) error { + _, err := q.db.Exec(ctx, callInsertDataSqlcArgs, arg.Foo, arg.Bar) return err } diff --git a/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/query.sql b/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/query.sql index eea032f74a..2d8d86ad0c 100644 --- a/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/ddl_create_procedure/postgresql/pgx/query.sql @@ -1,6 +1,11 @@ --- name: Placeholder :exec -SELECT 1; +-- name: CallInsertData :exec +CALL insert_data($1, $2); --- FIXME: Implement CALL --- name: CallInsertData :select +-- name: CallInsertDataNoArgs :exec CALL insert_data(1, 2); + +-- name: CallInsertDataNamed :exec +CALL insert_data(b => $1, a => $2); + +-- name: CallInsertDataSqlcArgs :exec +CALL insert_data(sqlc.arg('foo'), sqlc.arg('bar')); diff --git a/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/go/query.sql.go index 203f0fe62c..ca07b67007 100644 --- a/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/go/query.sql.go @@ -7,11 +7,53 @@ import ( "context" ) -const placeholder = `-- name: Placeholder :exec -SELECT 1 +const callInsertData = `-- name: CallInsertData :exec +CALL insert_data($1, $2) ` -func (q *Queries) Placeholder(ctx context.Context) error { - _, err := q.db.ExecContext(ctx, placeholder) +type CallInsertDataParams struct { + A int32 + B int32 +} + +func (q *Queries) CallInsertData(ctx context.Context, arg CallInsertDataParams) error { + _, err := q.db.ExecContext(ctx, callInsertData, arg.A, arg.B) + return err +} + +const callInsertDataNamed = `-- name: CallInsertDataNamed :exec +CALL insert_data(b => $1, a => $2) +` + +type CallInsertDataNamedParams struct { + B int32 + A int32 +} + +func (q *Queries) CallInsertDataNamed(ctx context.Context, arg CallInsertDataNamedParams) error { + _, err := q.db.ExecContext(ctx, callInsertDataNamed, arg.B, arg.A) + return err +} + +const callInsertDataNoArgs = `-- name: CallInsertDataNoArgs :exec +CALL insert_data(1, 2) +` + +func (q *Queries) CallInsertDataNoArgs(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, callInsertDataNoArgs) + return err +} + +const callInsertDataSqlcArgs = `-- name: CallInsertDataSqlcArgs :exec +CALL insert_data($1, $2) +` + +type CallInsertDataSqlcArgsParams struct { + Foo int32 + Bar int32 +} + +func (q *Queries) CallInsertDataSqlcArgs(ctx context.Context, arg CallInsertDataSqlcArgsParams) error { + _, err := q.db.ExecContext(ctx, callInsertDataSqlcArgs, arg.Foo, arg.Bar) return err } diff --git a/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/query.sql b/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/query.sql index eea032f74a..2d8d86ad0c 100644 --- a/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/query.sql +++ b/internal/endtoend/testdata/ddl_create_procedure/postgresql/stdlib/query.sql @@ -1,6 +1,11 @@ --- name: Placeholder :exec -SELECT 1; +-- name: CallInsertData :exec +CALL insert_data($1, $2); --- FIXME: Implement CALL --- name: CallInsertData :select +-- name: CallInsertDataNoArgs :exec CALL insert_data(1, 2); + +-- name: CallInsertDataNamed :exec +CALL insert_data(b => $1, a => $2); + +-- name: CallInsertDataSqlcArgs :exec +CALL insert_data(sqlc.arg('foo'), sqlc.arg('bar')); diff --git a/internal/engine/postgresql/convert.go b/internal/engine/postgresql/convert.go index c6e8976b66..ce420a502e 100644 --- a/internal/engine/postgresql/convert.go +++ b/internal/engine/postgresql/convert.go @@ -639,6 +639,33 @@ func convertBooleanTest(n *pg.BooleanTest) *ast.BooleanTest { } } +func convertCallStmt(n *pg.CallStmt) *ast.CallStmt { + if n == nil { + return nil + } + rel, err := parseRelationFromNodes(n.Funccall.Funcname) + if err != nil { + // TODO: How should we handle errors? + panic(err) + } + + return &ast.CallStmt{ + FuncCall: &ast.FuncCall{ + Func: rel.FuncName(), + Funcname: convertSlice(n.Funccall.Funcname), + Args: convertSlice(n.Funccall.Args), + AggOrder: convertSlice(n.Funccall.AggOrder), + AggFilter: convertNode(n.Funccall.AggFilter), + AggWithinGroup: n.Funccall.AggWithinGroup, + AggStar: n.Funccall.AggStar, + AggDistinct: n.Funccall.AggDistinct, + FuncVariadic: n.Funccall.FuncVariadic, + Over: convertWindowDef(n.Funccall.Over), + Location: int(n.Funccall.Location), + }, + } +} + func convertCaseExpr(n *pg.CaseExpr) *ast.CaseExpr { if n == nil { return nil @@ -3092,6 +3119,9 @@ func convertNode(node *pg.Node) ast.Node { case *pg.Node_BooleanTest: return convertBooleanTest(n.BooleanTest) + + case *pg.Node_CallStmt: + return convertCallStmt(n.CallStmt) case *pg.Node_CaseExpr: return convertCaseExpr(n.CaseExpr) diff --git a/internal/sql/ast/call_stmt.go b/internal/sql/ast/call_stmt.go new file mode 100644 index 0000000000..252bfb3169 --- /dev/null +++ b/internal/sql/ast/call_stmt.go @@ -0,0 +1,12 @@ +package ast + +type CallStmt struct { + FuncCall *FuncCall +} + +func (n *CallStmt) Pos() int { + if n.FuncCall == nil { + return 0 + } + return n.FuncCall.Pos() +} diff --git a/internal/sql/astutils/rewrite.go b/internal/sql/astutils/rewrite.go index 209cfb382c..2735d6e3eb 100644 --- a/internal/sql/astutils/rewrite.go +++ b/internal/sql/astutils/rewrite.go @@ -414,6 +414,9 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast. a.apply(n, "Xpr", nil, n.Xpr) a.apply(n, "Arg", nil, n.Arg) + case *ast.CallStmt: + a.apply(n, "FuncCall", nil, n.FuncCall) + case *ast.CaseExpr: a.apply(n, "Xpr", nil, n.Xpr) a.apply(n, "Arg", nil, n.Arg) diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index eefad0ac03..d1b4ee6aa8 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -511,6 +511,11 @@ func Walk(f Visitor, node ast.Node) { Walk(f, n.Arg) } + case *ast.CallStmt: + if n.FuncCall != nil { + Walk(f, n.FuncCall) + } + case *ast.CaseExpr: if n.Xpr != nil { Walk(f, n.Xpr)