Skip to content

Commit 4c3b219

Browse files
committed
Reduce QueryContext allocations by reusing the channel
This commit reduces the number of allocations and memory usage of QueryContext by inverting the goroutine: instead of processing the request in the goroutine and having it send the result, we now process the request in the method itself and goroutine is only used to interrupt the query if the context is canceled. The advantage of this approach is that we no longer need to send anything on the channel, but instead can treat the channel as a semaphore (this reduces the amount of memory allocated by this method). Additionally, we now reuse the channel used to communicate with the goroutine which reduces the number of allocations. This commit also adds a test that actually exercises the sqlite3_interrupt logic since the existing tests did not. Those tests cancelled the context before scanning any of the rows and could be made to pass without ever calling sqlite3_interrupt. The below version of SQLiteRows.Next passes the previous tests: ```go func (rc *SQLiteRows) Next(dest []driver.Value) error { rc.s.mu.Lock() defer rc.s.mu.Unlock() if rc.s.closed { return io.EOF } if err := rc.ctx.Err(); err != nil { return err } return rc.nextSyncLocked(dest) } ``` Benchmark results: ``` goos: darwin goarch: arm64 pkg: github.com/mattn/go-sqlite3 cpu: Apple M1 Max │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkQueryContext/Background-10 3.994µ ± 2% 4.034µ ± 1% ~ (p=0.289 n=10) Suite/BenchmarkQueryContext/WithCancel-10 12.02µ ± 3% 11.56µ ± 4% -3.87% (p=0.003 n=10) geomean 6.930µ 6.829µ -1.46% │ old.txt │ new.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkQueryContext/Background-10 400.0 ± 0% 400.0 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 2.376Ki ± 0% 1.025Ki ± 0% -56.87% (p=0.000 n=10) geomean 986.6 647.9 -34.33% ¹ all samples are equal │ old.txt │ new.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkQueryContext/Background-10 12.00 ± 0% 12.00 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 38.00 ± 0% 28.00 ± 0% -26.32% (p=0.000 n=10) geomean 21.35 18.33 -14.16% ¹ all samples are equal ```
1 parent 7658c06 commit 4c3b219

File tree

3 files changed

+255
-21
lines changed

3 files changed

+255
-21
lines changed

sqlite3.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,9 @@ type SQLiteRows struct {
399399
decltype []string
400400
ctx context.Context // no better alternative to pass context into Next() method
401401
closemu sync.Mutex
402+
// semaphore to signal the goroutine used to interrupt queries when a
403+
// cancellable context is passed to QueryContext
404+
sema chan struct{}
402405
}
403406

404407
type functionInfo struct {
@@ -2117,6 +2120,9 @@ func (rc *SQLiteRows) Close() error {
21172120
return nil
21182121
}
21192122
rc.s = nil // remove reference to SQLiteStmt
2123+
if rc.sema != nil {
2124+
close(rc.sema)
2125+
}
21202126
s.mu.Lock()
21212127
if s.closed {
21222128
s.mu.Unlock()
@@ -2174,27 +2180,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
21742180
return io.EOF
21752181
}
21762182

2177-
if rc.ctx.Done() == nil {
2183+
done := rc.ctx.Done()
2184+
if done == nil {
21782185
return rc.nextSyncLocked(dest)
21792186
}
2180-
resultCh := make(chan error)
2181-
defer close(resultCh)
2187+
if err := rc.ctx.Err(); err != nil {
2188+
return err // Fast check if the channel is closed
2189+
}
2190+
2191+
if rc.sema == nil {
2192+
rc.sema = make(chan struct{})
2193+
}
21822194
go func() {
2183-
resultCh <- rc.nextSyncLocked(dest)
2184-
}()
2185-
select {
2186-
case err := <-resultCh:
2187-
return err
2188-
case <-rc.ctx.Done():
21892195
select {
2190-
case <-resultCh: // no need to interrupt
2191-
default:
2192-
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
2196+
case <-done:
21932197
C.sqlite3_interrupt(rc.s.c.db)
2194-
<-resultCh // ensure goroutine completed
2198+
// Wait until signaled. We need to ensure that this goroutine
2199+
// will not call interrupt after this method returns.
2200+
<-rc.sema
2201+
case <-rc.sema:
21952202
}
2196-
return rc.ctx.Err()
2203+
}()
2204+
2205+
err := rc.nextSyncLocked(dest)
2206+
// Signal the goroutine to exit. This send will only succeed at a point
2207+
// where it is impossible for the goroutine to call sqlite3_interrupt.
2208+
//
2209+
// This is necessary to ensure the goroutine does not interrupt an
2210+
// unrelated query if the context is cancelled after this method returns
2211+
// but before the goroutine exits (we don't wait for it to exit).
2212+
rc.sema <- struct{}{}
2213+
if err != nil && isInterruptErr(err) {
2214+
err = rc.ctx.Err()
21972215
}
2216+
return err
21982217
}
21992218

22002219
// nextSyncLocked moves cursor to next; must be called with locked mutex.

sqlite3_go18_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ package sqlite3
1111
import (
1212
"context"
1313
"database/sql"
14+
"errors"
1415
"fmt"
1516
"io/ioutil"
1617
"math/rand"
1718
"os"
19+
"strings"
1820
"sync"
1921
"testing"
2022
"time"
@@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
268270
}
269271
}
270272

273+
// Test that we can successfully interrupt a long running query when
274+
// the context is canceled. The previous two QueryRowContext tests
275+
// only test that we handle a previously cancelled context and thus
276+
// do not call sqlite3_interrupt.
277+
func TestQueryRowContextCancelInterrupt(t *testing.T) {
278+
db, err := sql.Open("sqlite3", ":memory:")
279+
if err != nil {
280+
t.Fatal(err)
281+
}
282+
defer db.Close()
283+
284+
// Test that we have the unixepoch function and if not skip the test.
285+
if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
286+
libVersion, libVersionNumber, sourceID := Version()
287+
if strings.Contains(err.Error(), "no such function: unixepoch") {
288+
t.Skip("Skipping the 'unixepoch' function is not implemented in "+
289+
"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
290+
}
291+
t.Fatal(err)
292+
}
293+
294+
const createTableStmt = `
295+
CREATE TABLE timestamps (
296+
ts TIMESTAMP NOT NULL
297+
);`
298+
if _, err := db.Exec(createTableStmt); err != nil {
299+
t.Fatal(err)
300+
}
301+
302+
stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
303+
if err != nil {
304+
t.Fatal(err)
305+
}
306+
defer stmt.Close()
307+
308+
// Computationally expensive query that consumes many rows. This is needed
309+
// to test cancellation because queries are not interrupted immediately.
310+
// Instead, queries are only halted at certain checkpoints where the
311+
// sqlite3.isInterrupted is checked and true.
312+
queryStmt := `
313+
SELECT
314+
SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
315+
SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
316+
SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
317+
SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
318+
FROM
319+
timestamps
320+
WHERE datetime(ts, 'unixepoch', 'localtime')
321+
LIKE
322+
?;`
323+
324+
query := func(t *testing.T, timeout time.Duration) (int, error) {
325+
// Create a complicated pattern to match timestamps
326+
const pattern = "%2%0%2%4%-%-%:%:%"
327+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
328+
defer cancel()
329+
rows, err := db.QueryContext(ctx, queryStmt, pattern)
330+
if err != nil {
331+
return 0, err
332+
}
333+
var count int
334+
for rows.Next() {
335+
var n int64
336+
if err := rows.Scan(&n, &n, &n, &n); err != nil {
337+
return count, err
338+
}
339+
count++
340+
}
341+
return count, rows.Err()
342+
}
343+
344+
average := func(n int, fn func()) time.Duration {
345+
start := time.Now()
346+
for i := 0; i < n; i++ {
347+
fn()
348+
}
349+
return time.Since(start) / time.Duration(n)
350+
}
351+
352+
createRows := func(n int) {
353+
t.Logf("Creating %d rows", n)
354+
if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
355+
t.Fatal(err)
356+
}
357+
ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
358+
rr := rand.New(rand.NewSource(1234))
359+
for i := 0; i < n; i++ {
360+
if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
361+
t.Fatal(err)
362+
}
363+
}
364+
}
365+
366+
const TargetRuntime = 200 * time.Millisecond
367+
const N = 5_000 // Number of rows to insert at a time
368+
369+
// Create enough rows that the query takes ~200ms to run.
370+
start := time.Now()
371+
createRows(N)
372+
baseAvg := average(4, func() {
373+
if _, err := query(t, time.Hour); err != nil {
374+
t.Fatal(err)
375+
}
376+
})
377+
t.Log("Base average:", baseAvg)
378+
rowCount := N * (int(TargetRuntime/baseAvg) + 1)
379+
createRows(rowCount)
380+
t.Log("Table setup time:", time.Since(start))
381+
382+
// Set the timeout to 1/10 of the average query time.
383+
avg := average(2, func() {
384+
n, err := query(t, time.Hour)
385+
if err != nil {
386+
t.Fatal(err)
387+
}
388+
if n == 0 {
389+
t.Fatal("scanned zero rows")
390+
}
391+
})
392+
// Guard against the timeout being too short to reliably test.
393+
if avg < TargetRuntime/2 {
394+
t.Fatalf("Average query runtime should be around %s got: %s ",
395+
TargetRuntime, avg)
396+
}
397+
timeout := (avg / 10).Round(100 * time.Microsecond)
398+
t.Logf("Average: %s Timeout: %s", avg, timeout)
399+
400+
for i := 0; i < 10; i++ {
401+
tt := time.Now()
402+
n, err := query(t, timeout)
403+
if !errors.Is(err, context.DeadlineExceeded) {
404+
fn := t.Errorf
405+
if err != nil {
406+
fn = t.Fatalf
407+
}
408+
fn("expected error %v got %v", context.DeadlineExceeded, err)
409+
}
410+
d := time.Since(tt)
411+
t.Logf("%d: rows: %d duration: %s", i, n, d)
412+
if d > timeout*4 {
413+
t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
414+
}
415+
}
416+
}
417+
271418
func TestExecCancel(t *testing.T) {
272419
db, err := sql.Open("sqlite3", ":memory:")
273420
if err != nil {

sqlite3_test.go

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package sqlite3
1010

1111
import (
1212
"bytes"
13+
"context"
1314
"database/sql"
1415
"database/sql/driver"
1516
"errors"
@@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) {
20302031
}
20312032

20322033
func TestSuite(t *testing.T) {
2033-
initializeTestDB(t)
2034+
initializeTestDB(t, false)
20342035
defer freeTestDB()
20352036

20362037
for _, test := range tests {
@@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) {
20392040
}
20402041

20412042
func BenchmarkSuite(b *testing.B) {
2042-
initializeTestDB(b)
2043+
initializeTestDB(b, true)
20432044
defer freeTestDB()
20442045

20452046
for _, benchmark := range benchmarks {
@@ -2068,8 +2069,13 @@ type TestDB struct {
20682069

20692070
var db *TestDB
20702071

2071-
func initializeTestDB(t testing.TB) {
2072-
tempFilename := TempFilename(t)
2072+
func initializeTestDB(t testing.TB, memory bool) {
2073+
var tempFilename string
2074+
if memory {
2075+
tempFilename = ":memory:"
2076+
} else {
2077+
tempFilename = TempFilename(t)
2078+
}
20732079
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
20742080
if err != nil {
20752081
os.Remove(tempFilename)
@@ -2084,9 +2090,11 @@ func freeTestDB() {
20842090
if err != nil {
20852091
panic(err)
20862092
}
2087-
err = os.Remove(db.tempFilename)
2088-
if err != nil {
2089-
panic(err)
2093+
if db.tempFilename != "" && db.tempFilename != ":memory:" {
2094+
err := os.Remove(db.tempFilename)
2095+
if err != nil {
2096+
panic(err)
2097+
}
20902098
}
20912099
}
20922100

@@ -2107,6 +2115,7 @@ var tests = []testing.InternalTest{
21072115
var benchmarks = []testing.InternalBenchmark{
21082116
{Name: "BenchmarkExec", F: benchmarkExec},
21092117
{Name: "BenchmarkQuery", F: benchmarkQuery},
2118+
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
21102119
{Name: "BenchmarkParams", F: benchmarkParams},
21112120
{Name: "BenchmarkStmt", F: benchmarkStmt},
21122121
{Name: "BenchmarkRows", F: benchmarkRows},
@@ -2480,6 +2489,65 @@ func benchmarkQuery(b *testing.B) {
24802489
}
24812490
}
24822491

2492+
// benchmarkQueryContext is benchmark for QueryContext
2493+
func benchmarkQueryContext(b *testing.B) {
2494+
const createTableStmt = `
2495+
CREATE TABLE IF NOT EXISTS query_context(
2496+
id INTEGER PRIMARY KEY
2497+
);
2498+
DELETE FROM query_context;
2499+
VACUUM;`
2500+
test := func(ctx context.Context, b *testing.B) {
2501+
if _, err := db.Exec(createTableStmt); err != nil {
2502+
b.Fatal(err)
2503+
}
2504+
for i := 0; i < 10; i++ {
2505+
_, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i))
2506+
if err != nil {
2507+
db.Fatal(err)
2508+
}
2509+
}
2510+
stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`)
2511+
if err != nil {
2512+
b.Fatal(err)
2513+
}
2514+
b.Cleanup(func() { stmt.Close() })
2515+
2516+
var n int
2517+
for i := 0; i < b.N; i++ {
2518+
rows, err := stmt.QueryContext(ctx)
2519+
if err != nil {
2520+
b.Fatal(err)
2521+
}
2522+
for rows.Next() {
2523+
if err := rows.Scan(&n); err != nil {
2524+
b.Fatal(err)
2525+
}
2526+
}
2527+
if err := rows.Err(); err != nil {
2528+
b.Fatal(err)
2529+
}
2530+
}
2531+
}
2532+
2533+
// When the context does not have a Done channel we should use
2534+
// the fast path that directly handles the query instead of
2535+
// handling it in a goroutine. This benchmark also serves to
2536+
// highlight the performance impact of using a cancelable
2537+
// context.
2538+
b.Run("Background", func(b *testing.B) {
2539+
test(context.Background(), b)
2540+
})
2541+
2542+
// Benchmark a query with a context that can be canceled. This
2543+
// requires using a goroutine and is thus much slower.
2544+
b.Run("WithCancel", func(b *testing.B) {
2545+
ctx, cancel := context.WithCancel(context.Background())
2546+
defer cancel()
2547+
test(ctx, b)
2548+
})
2549+
}
2550+
24832551
// benchmarkParams is benchmark for params
24842552
func benchmarkParams(b *testing.B) {
24852553
for i := 0; i < b.N; i++ {

0 commit comments

Comments
 (0)