diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/go/db.go b/internal/endtoend/testdata/mix_param_types/sqlite/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/go/models.go b/internal/endtoend/testdata/mix_param_types/sqlite/go/models.go new file mode 100644 index 0000000000..d8136371e6 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type Bar struct { + ID int64 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go new file mode 100644 index 0000000000..36ea9ce689 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go @@ -0,0 +1,59 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = ?1 AND name <> ?2 +` + +type CountOneParams struct { + ID int64 + Name string +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > ?1 AND phone <> ?2 AND name <> ?3 +` + +type CountThreeParams struct { + ID int64 + Phone string + Name string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Phone, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = ?1 AND name <> ?2 +` + +type CountTwoParams struct { + ID int64 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/schema.sql b/internal/endtoend/testdata/mix_param_types/sqlite/schema.sql new file mode 100644 index 0000000000..5dba2166f9 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE bar ( + id integer not null, + name text not null, + phone text not null +); diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json b/internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json new file mode 100644 index 0000000000..b2a38f973a --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "test.sql", + "engine": "sqlite" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/test.sql b/internal/endtoend/testdata/mix_param_types/sqlite/test.sql new file mode 100644 index 0000000000..384d830d42 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/test.sql @@ -0,0 +1,8 @@ +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> ?; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = ? AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?; diff --git a/internal/sql/named/param_set.go b/internal/sql/named/param_set.go index d47617a39e..13321cc069 100644 --- a/internal/sql/named/param_set.go +++ b/internal/sql/named/param_set.go @@ -47,6 +47,14 @@ func (p *ParamSet) Add(param Param) int { return argn } +// AddAnonymous allocates the next available parameter position without +// associating a name. Used for bare ? placeholders that need explicit numbering. +func (p *ParamSet) AddAnonymous() int { + argn := p.nextArgNum() + p.positionToName[argn] = "" + return argn +} + // FetchMerge fetches an indexed parameter, and merges `mergeP` into it // Returns: the merged parameter and whether it was a named parameter func (p *ParamSet) FetchMerge(idx int, mergeP Param) (param Param, isNamed bool) { diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index d1ea1a22cc..78411d10c4 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -82,6 +82,18 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) hasNamedParameterSupport := engine != config.EngineMySQL + + // SQLite uses numbered ?N for sqlc.arg() but also accepts bare ?. When both + // coexist we must number all placeholders in text order so that positional + // argument binding matches the generated ?N values. + foundBare := astutils.Search(raw, isBareParamRef) + hasMixedParams := engine == config.EngineSQLite && + len(foundBare.Items) > 0 && + len(foundFunc.Items)+len(foundSign.Items) > 0 + if hasMixedParams { + numbs = nil + } + allParams := named.NewParamSet(numbs, hasNamedParameterSupport) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -183,6 +195,20 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, }) return false + case hasMixedParams && isBareParamRef(node): + ref := node.(*ast.ParamRef) + argn := allParams.AddAnonymous() + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: ref.Location, + }) + edits = append(edits, source.Edit{ + Location: ref.Location - raw.StmtLocation, + Old: "?", + New: fmt.Sprintf("?%d", argn), + }) + return false + default: return true } @@ -190,3 +216,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, return node.(*ast.RawStmt), allParams, edits } + +func isBareParamRef(node ast.Node) bool { + ref, ok := node.(*ast.ParamRef) + return ok && !ref.Dollar +}