From 7445b495e8b005f5185ad36e96721f4ba5d62ade Mon Sep 17 00:00:00 2001 From: Abhishek Goyal Date: Sun, 7 Jun 2026 12:11:47 +0530 Subject: [PATCH] fix(sqlite): correct parameter binding when mixing sqlc.arg() with bare ? When a SQLite query uses both sqlc.arg() named parameters and bare ? placeholders, the generated SQL has numbered ?N for named params but leaves bare ? unnumbered. SQLite's auto-numbering for bare ? then conflicts with the explicit ?N values, silently binding arguments to wrong columns. Fix by numbering all placeholders sequentially in text order when the mixed case is detected, ensuring positional argument passing matches the generated ?N values. --- .../testdata/mix_param_types/sqlite/go/db.go | 31 ++++++++++ .../mix_param_types/sqlite/go/models.go | 11 ++++ .../mix_param_types/sqlite/go/test.sql.go | 59 +++++++++++++++++++ .../mix_param_types/sqlite/schema.sql | 5 ++ .../testdata/mix_param_types/sqlite/sqlc.json | 12 ++++ .../testdata/mix_param_types/sqlite/test.sql | 8 +++ internal/sql/named/param_set.go | 8 +++ internal/sql/rewrite/parameters.go | 31 ++++++++++ 8 files changed, 165 insertions(+) create mode 100644 internal/endtoend/testdata/mix_param_types/sqlite/go/db.go create mode 100644 internal/endtoend/testdata/mix_param_types/sqlite/go/models.go create mode 100644 internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go create mode 100644 internal/endtoend/testdata/mix_param_types/sqlite/schema.sql create mode 100644 internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json create mode 100644 internal/endtoend/testdata/mix_param_types/sqlite/test.sql diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/go/db.go b/internal/endtoend/testdata/mix_param_types/sqlite/go/db.go new file mode 100644 index 0000000000..80dd6ab1f6 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/go/models.go b/internal/endtoend/testdata/mix_param_types/sqlite/go/models.go new file mode 100644 index 0000000000..d8136371e6 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type Bar struct { + ID int64 + Name string + Phone string +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go b/internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go new file mode 100644 index 0000000000..36ea9ce689 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/go/test.sql.go @@ -0,0 +1,59 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: test.sql + +package querytest + +import ( + "context" +) + +const countOne = `-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = ?1 AND name <> ?2 +` + +type CountOneParams struct { + ID int64 + Name string +} + +func (q *Queries) CountOne(ctx context.Context, arg CountOneParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countOne, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countThree = `-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > ?1 AND phone <> ?2 AND name <> ?3 +` + +type CountThreeParams struct { + ID int64 + Phone string + Name string +} + +func (q *Queries) CountThree(ctx context.Context, arg CountThreeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countThree, arg.ID, arg.Phone, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} + +const countTwo = `-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = ?1 AND name <> ?2 +` + +type CountTwoParams struct { + ID int64 + Name string +} + +func (q *Queries) CountTwo(ctx context.Context, arg CountTwoParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTwo, arg.ID, arg.Name) + var count int64 + err := row.Scan(&count) + return count, err +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/schema.sql b/internal/endtoend/testdata/mix_param_types/sqlite/schema.sql new file mode 100644 index 0000000000..5dba2166f9 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE bar ( + id integer not null, + name text not null, + phone text not null +); diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json b/internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json new file mode 100644 index 0000000000..b2a38f973a --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "querytest", + "schema": "schema.sql", + "queries": "test.sql", + "engine": "sqlite" + } + ] +} diff --git a/internal/endtoend/testdata/mix_param_types/sqlite/test.sql b/internal/endtoend/testdata/mix_param_types/sqlite/test.sql new file mode 100644 index 0000000000..384d830d42 --- /dev/null +++ b/internal/endtoend/testdata/mix_param_types/sqlite/test.sql @@ -0,0 +1,8 @@ +-- name: CountOne :one +SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> ?; + +-- name: CountTwo :one +SELECT count(1) FROM bar WHERE id = ? AND name <> sqlc.arg(name); + +-- name: CountThree :one +SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?; diff --git a/internal/sql/named/param_set.go b/internal/sql/named/param_set.go index d47617a39e..13321cc069 100644 --- a/internal/sql/named/param_set.go +++ b/internal/sql/named/param_set.go @@ -47,6 +47,14 @@ func (p *ParamSet) Add(param Param) int { return argn } +// AddAnonymous allocates the next available parameter position without +// associating a name. Used for bare ? placeholders that need explicit numbering. +func (p *ParamSet) AddAnonymous() int { + argn := p.nextArgNum() + p.positionToName[argn] = "" + return argn +} + // FetchMerge fetches an indexed parameter, and merges `mergeP` into it // Returns: the merged parameter and whether it was a named parameter func (p *ParamSet) FetchMerge(idx int, mergeP Param) (param Param, isNamed bool) { diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index d1ea1a22cc..78411d10c4 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -82,6 +82,18 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) hasNamedParameterSupport := engine != config.EngineMySQL + + // SQLite uses numbered ?N for sqlc.arg() but also accepts bare ?. When both + // coexist we must number all placeholders in text order so that positional + // argument binding matches the generated ?N values. + foundBare := astutils.Search(raw, isBareParamRef) + hasMixedParams := engine == config.EngineSQLite && + len(foundBare.Items) > 0 && + len(foundFunc.Items)+len(foundSign.Items) > 0 + if hasMixedParams { + numbs = nil + } + allParams := named.NewParamSet(numbs, hasNamedParameterSupport) if len(foundFunc.Items)+len(foundSign.Items) == 0 { @@ -183,6 +195,20 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, }) return false + case hasMixedParams && isBareParamRef(node): + ref := node.(*ast.ParamRef) + argn := allParams.AddAnonymous() + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: ref.Location, + }) + edits = append(edits, source.Edit{ + Location: ref.Location - raw.StmtLocation, + Old: "?", + New: fmt.Sprintf("?%d", argn), + }) + return false + default: return true } @@ -190,3 +216,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, return node.(*ast.RawStmt), allParams, edits } + +func isBareParamRef(node ast.Node) bool { + ref, ok := node.(*ast.ParamRef) + return ok && !ref.Dollar +}