Skip to content

Commit 94cba2b

Browse files
authored
Fix INSERT RETURNING (#88)
* Fix negative identifier in cursor * Fix INSERT RETURNING Co-authored-by: hfhbd <[email protected]>
1 parent 550388f commit 94cba2b

File tree

4 files changed

+225
-52
lines changed

4 files changed

+225
-52
lines changed

postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt

Lines changed: 143 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
6666
paramValues = preparedStatement?.values(this),
6767
paramFormats = preparedStatement?.formats?.refTo(0),
6868
paramLengths = preparedStatement?.lengths?.refTo(0),
69-
resultFormat = BINARY_RESULT_FORMAT
69+
resultFormat = TEXT_RESULT_FORMAT
7070
)
7171
}
7272
} else {
@@ -78,7 +78,7 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
7878
paramValues = preparedStatement?.values(this),
7979
paramFormats = preparedStatement?.formats?.refTo(0),
8080
paramLengths = preparedStatement?.lengths?.refTo(0),
81-
resultFormat = BINARY_RESULT_FORMAT,
81+
resultFormat = TEXT_RESULT_FORMAT,
8282
paramTypes = preparedStatement?.types?.refTo(0)
8383
)
8484
}
@@ -87,11 +87,12 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
8787
return QueryResult.Value(value = result.rows)
8888
}
8989

90-
private val CPointer<PGresult>.rows: Long get() {
91-
val rows = PQcmdTuples(this)!!.toKString()
92-
clear()
93-
return rows.toLongOrNull() ?: 0
94-
}
90+
private val CPointer<PGresult>.rows: Long
91+
get() {
92+
val rows = PQcmdTuples(this)!!.toKString()
93+
clear()
94+
return rows.toLongOrNull() ?: 0
95+
}
9596

9697
private fun preparedStatementExists(identifier: Int): Boolean {
9798
val result =
@@ -104,35 +105,92 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
104105
return result.value != null
105106
}
106107

107-
private fun Int.escapeNegative(): String = if(this < 0) "_${toString().substring(1)}" else toString()
108+
private fun Int.escapeNegative(): String = if (this < 0) "_${toString().substring(1)}" else toString()
108109

109-
override fun <R> executeQuery(
110+
private fun preparedStatement(
111+
parameters: Int,
112+
binders: (SqlPreparedStatement.() -> Unit)?
113+
): PostgresPreparedStatement? = if (parameters != 0) {
114+
PostgresPreparedStatement(parameters).apply {
115+
if (binders != null) {
116+
binders()
117+
}
118+
}
119+
} else null
120+
121+
public fun <R> executeQueryWithNativeCursor(
110122
identifier: Int?,
111123
sql: String,
112124
mapper: (SqlCursor) -> R,
113125
parameters: Int,
126+
fetchSize: Int = 1,
114127
binders: (SqlPreparedStatement.() -> Unit)?
115128
): QueryResult.Value<R> {
116129
val cursorName = if (identifier == null) "myCursor" else "cursor${identifier.escapeNegative()}"
117130
val cursor = "DECLARE $cursorName CURSOR FOR"
118-
val preparedStatement = if (parameters != 0) {
119-
PostgresPreparedStatement(parameters).apply {
120-
if (binders != null) {
121-
binders()
122-
}
123-
}
124-
} else null
131+
132+
val preparedStatement = preparedStatement(parameters, binders)
125133
val result = if (identifier != null) {
126-
if (!preparedStatementExists(identifier)) {
127-
PQprepare(
134+
checkPreparedStatement(identifier, "$cursor $sql", parameters, preparedStatement)
135+
conn.exec("BEGIN")
136+
memScoped {
137+
PQexecPrepared(
128138
conn,
129139
stmtName = identifier.toString(),
130-
query = "$cursor $sql",
131140
nParams = parameters,
132-
paramTypes = preparedStatement?.types?.refTo(0)
133-
).check(conn).clear()
141+
paramValues = preparedStatement?.values(this),
142+
paramLengths = preparedStatement?.lengths?.refTo(0),
143+
paramFormats = preparedStatement?.formats?.refTo(0),
144+
resultFormat = TEXT_RESULT_FORMAT
145+
)
134146
}
147+
} else {
135148
conn.exec("BEGIN")
149+
memScoped {
150+
PQexecParams(
151+
conn,
152+
command = "$cursor $sql",
153+
nParams = parameters,
154+
paramValues = preparedStatement?.values(this),
155+
paramLengths = preparedStatement?.lengths?.refTo(0),
156+
paramFormats = preparedStatement?.formats?.refTo(0),
157+
paramTypes = preparedStatement?.types?.refTo(0),
158+
resultFormat = TEXT_RESULT_FORMAT
159+
)
160+
}
161+
}.check(conn)
162+
163+
val value = PostgresCursor.RealCursor(result, cursorName, conn, fetchSize).use(mapper)
164+
return QueryResult.Value(value = value)
165+
}
166+
167+
private fun checkPreparedStatement(
168+
identifier: Int,
169+
sql: String,
170+
parameters: Int,
171+
preparedStatement: PostgresPreparedStatement?
172+
) {
173+
if (!preparedStatementExists(identifier)) {
174+
PQprepare(
175+
conn,
176+
stmtName = identifier.toString(),
177+
query = sql,
178+
nParams = parameters,
179+
paramTypes = preparedStatement?.types?.refTo(0)
180+
).check(conn).clear()
181+
}
182+
}
183+
184+
override fun <R> executeQuery(
185+
identifier: Int?,
186+
sql: String,
187+
mapper: (SqlCursor) -> R,
188+
parameters: Int,
189+
binders: (SqlPreparedStatement.() -> Unit)?
190+
): QueryResult.Value<R> {
191+
val preparedStatement = preparedStatement(parameters, binders)
192+
val result = if (identifier != null) {
193+
checkPreparedStatement(identifier, sql, parameters, preparedStatement)
136194
memScoped {
137195
PQexecPrepared(
138196
conn,
@@ -141,26 +199,25 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
141199
paramValues = preparedStatement?.values(this),
142200
paramLengths = preparedStatement?.lengths?.refTo(0),
143201
paramFormats = preparedStatement?.formats?.refTo(0),
144-
resultFormat = BINARY_RESULT_FORMAT
202+
resultFormat = TEXT_RESULT_FORMAT
145203
)
146204
}
147205
} else {
148-
conn.exec("BEGIN")
149206
memScoped {
150207
PQexecParams(
151208
conn,
152-
command = "$cursor $sql",
209+
command = sql,
153210
nParams = parameters,
154211
paramValues = preparedStatement?.values(this),
155212
paramLengths = preparedStatement?.lengths?.refTo(0),
156213
paramFormats = preparedStatement?.formats?.refTo(0),
157214
paramTypes = preparedStatement?.types?.refTo(0),
158-
resultFormat = BINARY_RESULT_FORMAT
215+
resultFormat = TEXT_RESULT_FORMAT
159216
)
160217
}
161218
}.check(conn)
162219

163-
val value = PostgresCursor(result, cursorName, conn).use(mapper)
220+
val value = PostgresCursor.NoCursor(result).use(mapper)
164221
return QueryResult.Value(value = value)
165222
}
166223

@@ -232,29 +289,73 @@ private fun CPointer<PGresult>?.check(conn: CPointer<PGconn>): CPointer<PGresult
232289
return this!!
233290
}
234291

235-
/**
236-
* Must be inside a transaction!
237-
*/
238-
public class PostgresCursor(
239-
private var result: CPointer<PGresult>,
240-
private val name: String,
241-
private val conn: CPointer<PGconn>
292+
public sealed class PostgresCursor(
293+
internal var result: CPointer<PGresult>
242294
) : SqlCursor, Closeable {
243-
override fun close() {
244-
result.clear()
245-
conn.exec("CLOSE $name")
246-
conn.exec("END")
295+
internal abstract val currentRowIndex: Int
296+
297+
/**
298+
* Must be inside a transaction!
299+
*/
300+
internal class RealCursor(
301+
result: CPointer<PGresult>,
302+
private val name: String,
303+
private val conn: CPointer<PGconn>,
304+
private val fetchSize: Int
305+
) : PostgresCursor(result) {
306+
override fun close() {
307+
result.clear()
308+
conn.exec("CLOSE $name")
309+
conn.exec("END")
310+
}
311+
312+
override var currentRowIndex = -1
313+
private var maxRowIndex = -1
314+
315+
override fun next(): Boolean {
316+
if (currentRowIndex == maxRowIndex) {
317+
currentRowIndex = -1
318+
}
319+
if (currentRowIndex == -1) {
320+
result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn)
321+
maxRowIndex = PQntuples(result) - 1
322+
}
323+
return if (currentRowIndex < maxRowIndex) {
324+
currentRowIndex += 1
325+
true
326+
} else false
327+
}
328+
}
329+
330+
internal class NoCursor(
331+
result: CPointer<PGresult>
332+
) : PostgresCursor(result) {
333+
override fun close() {
334+
result.clear()
335+
}
336+
337+
private val maxRowIndex = PQntuples(result) - 1
338+
override var currentRowIndex = -1
339+
340+
override fun next(): Boolean {
341+
return if (currentRowIndex < maxRowIndex) {
342+
currentRowIndex += 1
343+
true
344+
} else {
345+
false
346+
}
347+
}
247348
}
248349

249350
override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean()
250351

251352
override fun getBytes(index: Int): ByteArray? {
252-
val isNull = PQgetisnull(result, tup_num = 0, field_num = index) == 1
353+
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
253354
return if (isNull) {
254355
null
255356
} else {
256-
val bytes = PQgetvalue(result, tup_num = 0, field_num = index)!!
257-
val length = PQgetlength(result, tup_num = 0, field_num = index)
357+
val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!!
358+
val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index)
258359
bytes.fromHex(length)
259360
}
260361
}
@@ -284,11 +385,11 @@ public class PostgresCursor(
284385
override fun getLong(index: Int): Long? = getString(index)?.toLong()
285386

286387
override fun getString(index: Int): String? {
287-
val isNull = PQgetisnull(result, tup_num = 0, field_num = index) == 1
388+
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
288389
return if (isNull) {
289390
null
290391
} else {
291-
val value = PQgetvalue(result, tup_num = 0, field_num = index)
392+
val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)
292393
value!!.toKString()
293394
}
294395
}
@@ -302,11 +403,6 @@ public class PostgresCursor(
302403

303404
public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) }
304405
public fun getUUID(index: Int): UUID? = getString(index)?.toUUID()
305-
306-
override fun next(): Boolean {
307-
result = PQexec(conn, "FETCH NEXT IN $name").check(conn)
308-
return PQcmdTuples(result)!!.toKString().toInt() == 1
309-
}
310406
}
311407

312408
public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement {

postgres-native-sqldelight-driver/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ class PostgresNativeDriverTest {
1313
password = "password"
1414
)
1515
assertEquals(0, driver.execute(null, "DROP TABLE IF EXISTS baz;", parameters = 0).value)
16-
assertEquals(0, driver.execute(null, "CREATE TABLE baz(a int primary key, foo text, b bytea);", parameters = 0).value)
16+
assertEquals(
17+
0,
18+
driver.execute(null, "CREATE TABLE baz(a INT PRIMARY KEY, foo TEXT, b BYTEA);", parameters = 0).value
19+
)
1720
repeat(5) {
1821
val result = driver.execute(null, "INSERT INTO baz VALUES ($it)", parameters = 0)
1922
assertEquals(1, result.value)
@@ -29,7 +32,7 @@ class PostgresNativeDriverTest {
2932
bindBytes(5, byteArrayOf(16.toByte(), 12.toByte()))
3033
}.value
3134
assertEquals(2, result)
32-
val notPrepared = driver.executeQuery(null, "SELECT * from baz limit 1;", parameters = 0, mapper = {
35+
val notPrepared = driver.executeQuery(null, "SELECT * FROM baz LIMIT 1;", parameters = 0, mapper = {
3336
assertTrue(it.next())
3437
Simple(
3538
index = it.getLong(0)!!.toInt(),
@@ -57,6 +60,7 @@ class PostgresNativeDriverTest {
5760
}
5861
).value
5962

63+
assertEquals(7, preparedStatement.size)
6064
assertEquals(
6165
List(5) {
6266
Simple(it, null, null)
@@ -66,6 +70,75 @@ class PostgresNativeDriverTest {
6670
),
6771
preparedStatement
6872
)
73+
74+
expect(7) {
75+
val cursorList = driver.executeQueryWithNativeCursor(
76+
-99,
77+
"SELECT * FROM baz",
78+
fetchSize = 4,
79+
parameters = 0,
80+
binders = null,
81+
mapper = {
82+
buildList {
83+
while (it.next()) {
84+
add(
85+
Simple(
86+
index = it.getLong(0)!!.toInt(),
87+
name = it.getString(1),
88+
byteArray = it.getBytes(2)
89+
)
90+
)
91+
}
92+
}
93+
}).value
94+
cursorList.size
95+
}
96+
97+
expect(7) {
98+
val cursorList = driver.executeQueryWithNativeCursor(
99+
-5,
100+
"SELECT * FROM baz",
101+
fetchSize = 1,
102+
parameters = 0,
103+
binders = null,
104+
mapper = {
105+
buildList {
106+
while (it.next()) {
107+
add(
108+
Simple(
109+
index = it.getLong(0)!!.toInt(),
110+
name = it.getString(1),
111+
byteArray = it.getBytes(2)
112+
)
113+
)
114+
}
115+
}
116+
}).value
117+
cursorList.size
118+
}
119+
120+
expect(0) {
121+
val cursorList = driver.executeQueryWithNativeCursor(
122+
-100,
123+
"SELECT * FROM baz WHERE a = -1",
124+
fetchSize = 1,
125+
parameters = 0,
126+
binders = null,
127+
mapper = {
128+
buildList {
129+
while (it.next()) {
130+
add(
131+
Simple(
132+
index = it.getLong(0)!!.toInt(),
133+
name = it.getString(1),
134+
byteArray = it.getBytes(2)
135+
)
136+
)
137+
}
138+
}
139+
}).value
140+
cursorList.size
141+
}
69142
}
70143

71144
private data class Simple(val index: Int, val name: String?, val byteArray: ByteArray?) {

0 commit comments

Comments
 (0)