Skip to content

Commit 8ba2e97

Browse files
committed
Improve Query performance and return an error if the query includes multiple statements
This commit changes Query to return an error if more than one SQL statement is provided. Previously, this library would only execute the last query statement. It also improves query construction performance by ~15%. This is a breaking a change since existing programs may rely on the broken mattn/go-sqlite3 implementation. That said, any program relying on this is also broken / using sqlite3 incorrectly. ``` goos: darwin goarch: arm64 pkg: github.com/charlievieth/go-sqlite3 cpu: Apple M4 Pro │ x1.txt │ x2.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkQuery-14 2.255µ ± 1% 1.837µ ± 1% -18.56% (p=0.000 n=10) Suite/BenchmarkQuerySimple-14 1.322µ ± 9% 1.124µ ± 4% -15.02% (p=0.000 n=10) geomean 1.727µ 1.436µ -16.81% │ x1.txt │ x2.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkQuery-14 664.0 ± 0% 656.0 ± 0% -1.20% (p=0.000 n=10) Suite/BenchmarkQuerySimple-14 472.0 ± 0% 456.0 ± 0% -3.39% (p=0.000 n=10) geomean 559.8 546.9 -2.30% │ x1.txt │ x2.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkQuery-14 23.00 ± 0% 22.00 ± 0% -4.35% (p=0.000 n=10) Suite/BenchmarkQuerySimple-14 14.00 ± 0% 13.00 ± 0% -7.14% (p=0.000 n=10) geomean 17.94 16.91 -5.76% ```
1 parent 976152a commit 8ba2e97

File tree

2 files changed

+253
-36
lines changed

2 files changed

+253
-36
lines changed

sqlite3.go

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,59 @@ _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_
137137
}
138138
#endif
139139
140+
#define GO_SQLITE_MULTIPLE_QUERIES -1
141+
142+
// Our own implementation of ctype.h's isspace (for simplicity and to avoid
143+
// whatever locale shenanigans are involved with the Libc's isspace).
144+
static int _sqlite3_isspace(unsigned char c) {
145+
return c == ' ' || c - '\t' < 5;
146+
}
147+
148+
static int _sqlite3_prepare_query(sqlite3 *db, const char *zSql, int nBytes,
149+
sqlite3_stmt **ppStmt, int *paramCount) {
150+
151+
const char *tail;
152+
int rc = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
153+
if (rc != SQLITE_OK) {
154+
return rc;
155+
}
156+
*paramCount = sqlite3_bind_parameter_count(*ppStmt);
157+
158+
// Check if the SQL query contains multiple statements.
159+
160+
// Trim leading space to handle queries with trailing whitespace.
161+
// This can save us an additional call to sqlite3_prepare_v2.
162+
const char *end = zSql + nBytes;
163+
while (tail < end && _sqlite3_isspace(*tail)) {
164+
tail++;
165+
}
166+
nBytes -= (tail - zSql);
167+
168+
// Attempt to parse the remaining SQL, if any.
169+
if (nBytes > 0 && *tail) {
170+
sqlite3_stmt *stmt;
171+
rc = _sqlite3_prepare_v2_internal(db, tail, nBytes, &stmt, NULL);
172+
if (rc != SQLITE_OK) {
173+
// sqlite3 will return OK and a NULL statement if it was
174+
goto error;
175+
}
176+
if (stmt != NULL) {
177+
sqlite3_finalize(stmt);
178+
rc = GO_SQLITE_MULTIPLE_QUERIES;
179+
goto error;
180+
}
181+
}
182+
183+
// Ok, the SQL contained one valid statement.
184+
return SQLITE_OK;
185+
186+
error:
187+
if (*ppStmt) {
188+
sqlite3_finalize(*ppStmt);
189+
}
190+
return rc;
191+
}
192+
140193
static int _sqlite3_prepare_v2(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, int *oBytes) {
141194
const char *tail = NULL;
142195
int rv = _sqlite3_prepare_v2_internal(db, zSql, nBytes, ppStmt, &tail);
@@ -1123,46 +1176,42 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
11231176
return c.query(context.Background(), query, list)
11241177
}
11251178

1179+
var closedRows = &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}
1180+
11261181
func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
1127-
start := 0
1128-
for {
1129-
stmtArgs := make([]driver.NamedValue, 0, len(args))
1130-
s, err := c.prepare(ctx, query)
1131-
if err != nil {
1132-
return nil, err
1133-
}
1134-
s.(*SQLiteStmt).cls = true
1135-
na := s.NumInput()
1136-
if len(args)-start < na {
1137-
s.Close()
1138-
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
1139-
}
1140-
// consume the number of arguments used in the current
1141-
// statement and append all named arguments not contained
1142-
// therein
1143-
stmtArgs = append(stmtArgs, args[start:start+na]...)
1144-
for i := range args {
1145-
if (i < start || i >= na) && args[i].Name != "" {
1146-
stmtArgs = append(stmtArgs, args[i])
1147-
}
1148-
}
1149-
for i := range stmtArgs {
1150-
stmtArgs[i].Ordinal = i + 1
1151-
}
1152-
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
1153-
if err != nil && err != driver.ErrSkip {
1154-
s.Close()
1155-
return rows, err
1182+
s := SQLiteStmt{c: c, cls: true}
1183+
p := stringData(query)
1184+
var paramCount C.int
1185+
rv := C._sqlite3_prepare_query(c.db, (*C.char)(unsafe.Pointer(p)), C.int(len(query)), &s.s, &paramCount)
1186+
if rv != C.SQLITE_OK {
1187+
if rv == C.GO_SQLITE_MULTIPLE_QUERIES {
1188+
return nil, errors.New("query contains multiple SQL statements")
11561189
}
1157-
start += na
1158-
tail := s.(*SQLiteStmt).t
1159-
if tail == "" {
1160-
return rows, nil
1190+
return nil, c.lastError()
1191+
}
1192+
1193+
// The sqlite3_stmt will be nil if the SQL was valid but did not
1194+
// contain a query. For now we're supporting this for the sake of
1195+
// backwards compatibility, but that may change in the future.
1196+
if s.s == nil {
1197+
return closedRows, nil
1198+
}
1199+
1200+
na := int(paramCount)
1201+
if n := len(args); n != na {
1202+
s.finalize()
1203+
if n < na {
1204+
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
11611205
}
1162-
rows.Close()
1163-
s.Close()
1164-
query = tail
1206+
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
11651207
}
1208+
1209+
rows, err := s.query(ctx, args)
1210+
if err != nil && err != driver.ErrSkip {
1211+
s.finalize()
1212+
return rows, err
1213+
}
1214+
return rows, nil
11661215
}
11671216

11681217
// Begin transaction.

sqlite3_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"math/rand"
2020
"net/url"
2121
"os"
22+
"path/filepath"
2223
"reflect"
2324
"regexp"
2425
"runtime"
@@ -1203,6 +1204,163 @@ func TestQueryer(t *testing.T) {
12031204
}
12041205
}
12051206

1207+
func testQuery(t *testing.T, test func(t *testing.T, db *sql.DB)) {
1208+
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
1209+
if err != nil {
1210+
t.Fatal("Failed to open database:", err)
1211+
}
1212+
defer db.Close()
1213+
1214+
_, err = db.Exec(`
1215+
CREATE TABLE FOO (id INTEGER);
1216+
INSERT INTO foo(id) VALUES(?);
1217+
INSERT INTO foo(id) VALUES(?);
1218+
INSERT INTO foo(id) VALUES(?);
1219+
`, 3, 2, 1)
1220+
if err != nil {
1221+
t.Fatal(err)
1222+
}
1223+
1224+
// Capture panic so tests can continue
1225+
defer func() {
1226+
if e := recover(); e != nil {
1227+
buf := make([]byte, 32*1024)
1228+
n := runtime.Stack(buf, false)
1229+
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
1230+
}
1231+
}()
1232+
test(t, db)
1233+
}
1234+
1235+
func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
1236+
var values []interface{}
1237+
testQuery(t, func(t *testing.T, db *sql.DB) {
1238+
rows, err := db.Query(query, args...)
1239+
if err != nil {
1240+
t.Fatal(err)
1241+
}
1242+
if rows == nil {
1243+
t.Fatal("nil rows")
1244+
}
1245+
for i := 0; rows.Next(); i++ {
1246+
if i > 1_000 {
1247+
t.Fatal("To many iterations of rows.Next():", i)
1248+
}
1249+
var v interface{}
1250+
if err := rows.Scan(&v); err != nil {
1251+
t.Fatal(err)
1252+
}
1253+
values = append(values, v)
1254+
}
1255+
if err := rows.Err(); err != nil {
1256+
t.Fatal(err)
1257+
}
1258+
if err := rows.Close(); err != nil {
1259+
t.Fatal(err)
1260+
}
1261+
})
1262+
return values
1263+
}
1264+
1265+
func TestQuery(t *testing.T) {
1266+
queries := []struct {
1267+
query string
1268+
args []interface{}
1269+
}{
1270+
{"SELECT id FROM foo ORDER BY id;", nil},
1271+
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
1272+
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
1273+
1274+
// Comments
1275+
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
1276+
{"SELECT id FROM foo ORDER BY id -- comment", nil}, // Not terminated
1277+
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
1278+
{
1279+
`-- FOO
1280+
SELECT id FROM foo ORDER BY id; -- BAR
1281+
/* BAZ */`,
1282+
nil,
1283+
},
1284+
}
1285+
want := []interface{}{
1286+
int64(1),
1287+
int64(2),
1288+
int64(3),
1289+
}
1290+
for _, q := range queries {
1291+
t.Run("", func(t *testing.T) {
1292+
got := testQueryValues(t, q.query, q.args...)
1293+
if !reflect.DeepEqual(got, want) {
1294+
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
1295+
}
1296+
})
1297+
}
1298+
}
1299+
1300+
func TestQueryNoSQL(t *testing.T) {
1301+
got := testQueryValues(t, "")
1302+
if got != nil {
1303+
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
1304+
}
1305+
}
1306+
1307+
func testQueryError(t *testing.T, query string, args ...interface{}) {
1308+
testQuery(t, func(t *testing.T, db *sql.DB) {
1309+
rows, err := db.Query(query, args...)
1310+
if err == nil {
1311+
t.Error("Expected an error got:", err)
1312+
}
1313+
if rows != nil {
1314+
t.Error("Returned rows should be nil on error!")
1315+
// Attempt to iterate over rows to make sure they don't panic.
1316+
for i := 0; rows.Next(); i++ {
1317+
if i > 1_000 {
1318+
t.Fatal("To many iterations of rows.Next():", i)
1319+
}
1320+
}
1321+
if err := rows.Err(); err != nil {
1322+
t.Error(err)
1323+
}
1324+
rows.Close()
1325+
}
1326+
})
1327+
}
1328+
1329+
func TestQueryNotEnoughArgs(t *testing.T) {
1330+
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
1331+
}
1332+
1333+
func TestQueryTooManyArgs(t *testing.T) {
1334+
// TODO: test error message / kind
1335+
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
1336+
}
1337+
1338+
func TestQueryMultipleStatements(t *testing.T) {
1339+
testQueryError(t, "SELECT 1; SELECT 2;")
1340+
testQueryError(t, "SELECT 1; SELECT 2; SELECT 3;")
1341+
testQueryError(t, "SELECT 1; ; SELECT 2;") // Empty statement in between
1342+
testQueryError(t, "SELECT 1; FOOBAR 2;") // Error in second statement
1343+
1344+
// Test that multiple trailing semicolons (";;") are not an error
1345+
noError := func(t *testing.T, query string, args ...any) {
1346+
testQuery(t, func(t *testing.T, db *sql.DB) {
1347+
var n int64
1348+
if err := db.QueryRow(query, args...).Scan(&n); err != nil {
1349+
t.Fatal(err)
1350+
}
1351+
if n != 1 {
1352+
t.Fatalf("got: %d want: %d", n, 1)
1353+
}
1354+
})
1355+
}
1356+
noError(t, "SELECT 1; ;")
1357+
noError(t, "SELECT ?; ;", 1)
1358+
}
1359+
1360+
func TestQueryInvalidTable(t *testing.T) {
1361+
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
1362+
}
1363+
12061364
func TestStress(t *testing.T) {
12071365
tempFilename := TempFilename(t)
12081366
defer os.Remove(tempFilename)
@@ -2180,6 +2338,7 @@ var benchmarks = []testing.InternalBenchmark{
21802338
{Name: "BenchmarkExecContextStep", F: benchmarkExecContextStep},
21812339
{Name: "BenchmarkExecTx", F: benchmarkExecTx},
21822340
{Name: "BenchmarkQuery", F: benchmarkQuery},
2341+
{Name: "BenchmarkQuerySimple", F: benchmarkQuerySimple},
21832342
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
21842343
{Name: "BenchmarkParams", F: benchmarkParams},
21852344
{Name: "BenchmarkStmt", F: benchmarkStmt},
@@ -2619,6 +2778,15 @@ func benchmarkQuery(b *testing.B) {
26192778
}
26202779
}
26212780

2781+
func benchmarkQuerySimple(b *testing.B) {
2782+
for i := 0; i < b.N; i++ {
2783+
var n int
2784+
if err := db.QueryRow("select 1;").Scan(&n); err != nil {
2785+
panic(err)
2786+
}
2787+
}
2788+
}
2789+
26222790
// benchmarkQueryContext is benchmark for QueryContext
26232791
func benchmarkQueryContext(b *testing.B) {
26242792
const createTableStmt = `

0 commit comments

Comments
 (0)