From f1f3a26e85b4ef4f1ed78b91513e6b8ead1db185 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Mon, 17 Oct 2022 21:53:14 +0200 Subject: [PATCH 1/6] Fix INSERT RETURNING --- .../postgresdriver/PostgresNativeDriver.kt | 162 ++++++++++++++---- .../sqldelight/postgresdriver/Users.sq | 5 +- .../PostgresNativeDriverTest.kt | 5 +- 3 files changed, 134 insertions(+), 38 deletions(-) diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt index 7f22250..a0685a5 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt @@ -66,7 +66,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive paramValues = preparedStatement?.values(this), paramFormats = preparedStatement?.formats?.refTo(0), paramLengths = preparedStatement?.lengths?.refTo(0), - resultFormat = BINARY_RESULT_FORMAT + resultFormat = TEXT_RESULT_FORMAT ) } } else { @@ -78,7 +78,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive paramValues = preparedStatement?.values(this), paramFormats = preparedStatement?.formats?.refTo(0), paramLengths = preparedStatement?.lengths?.refTo(0), - resultFormat = BINARY_RESULT_FORMAT, + resultFormat = TEXT_RESULT_FORMAT, paramTypes = preparedStatement?.types?.refTo(0) ) } @@ -106,33 +106,105 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive private fun Int.escapeNegative(): String = if(this < 0) "_${toString().substring(1)}" else toString() + private fun preparedStatement( + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)? + ): PostgresPreparedStatement? = if (parameters != 0) { + PostgresPreparedStatement(parameters).apply { + if (binders != null) { + binders() + } + } + } else null + override fun executeQuery( identifier: Int?, sql: String, mapper: (SqlCursor) -> R, parameters: Int, binders: (SqlPreparedStatement.() -> Unit)? + ): QueryResult.Value = if (sql.startsWith("INSERT")) { + executeInsertReturningQuery( + identifier = identifier, sql = sql, mapper = mapper, parameters = parameters, binders = binders + ) + } else { + executeNormalQuery( + identifier = identifier, sql = sql, mapper = mapper, parameters = parameters, binders = binders + ) + } + + private fun executeNormalQuery( + identifier: Int?, + sql: String, + mapper: (SqlCursor) -> R, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)? ): QueryResult.Value { val cursorName = if (identifier == null) "myCursor" else "cursor${identifier.escapeNegative()}" val cursor = "DECLARE $cursorName CURSOR FOR" - val preparedStatement = if (parameters != 0) { - PostgresPreparedStatement(parameters).apply { - if (binders != null) { - binders() - } - } - } else null + + val preparedStatement = preparedStatement(parameters, binders) val result = if (identifier != null) { - if (!preparedStatementExists(identifier)) { - PQprepare( + checkPreparedStatement(identifier, sql, parameters, preparedStatement) + conn.exec("BEGIN") + memScoped { + PQexecPrepared( conn, stmtName = identifier.toString(), - query = "$cursor $sql", nParams = parameters, - paramTypes = preparedStatement?.types?.refTo(0) - ).check(conn).clear() + paramValues = preparedStatement?.values(this), + paramLengths = preparedStatement?.lengths?.refTo(0), + paramFormats = preparedStatement?.formats?.refTo(0), + resultFormat = TEXT_RESULT_FORMAT + ) } + } else { conn.exec("BEGIN") + memScoped { + PQexecParams( + conn, + command = "$cursor $sql", + nParams = parameters, + paramValues = preparedStatement?.values(this), + paramLengths = preparedStatement?.lengths?.refTo(0), + paramFormats = preparedStatement?.formats?.refTo(0), + paramTypes = preparedStatement?.types?.refTo(0), + resultFormat = TEXT_RESULT_FORMAT + ) + } + }.check(conn) + + val value = PostgresCursor.MultiCursor(result, cursorName, conn).use(mapper) + return QueryResult.Value(value = value) + } + + private fun checkPreparedStatement( + identifier: Int, + sql: String, + parameters: Int, + preparedStatement: PostgresPreparedStatement? + ) { + if (!preparedStatementExists(identifier)) { + PQprepare( + conn, + stmtName = identifier.toString(), + query = sql, + nParams = parameters, + paramTypes = preparedStatement?.types?.refTo(0) + ).check(conn).clear() + } + } + + private fun executeInsertReturningQuery( + identifier: Int?, + sql: String, + mapper: (SqlCursor) -> R, + parameters: Int, + binders: (SqlPreparedStatement.() -> Unit)? + ): QueryResult.Value { + val preparedStatement = preparedStatement(parameters, binders) + val result = if (identifier != null) { + checkPreparedStatement(identifier, sql, parameters, preparedStatement) memScoped { PQexecPrepared( conn, @@ -141,26 +213,25 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive paramValues = preparedStatement?.values(this), paramLengths = preparedStatement?.lengths?.refTo(0), paramFormats = preparedStatement?.formats?.refTo(0), - resultFormat = BINARY_RESULT_FORMAT + resultFormat = TEXT_RESULT_FORMAT ) } } else { - conn.exec("BEGIN") memScoped { PQexecParams( conn, - command = "$cursor $sql", + command = sql, nParams = parameters, paramValues = preparedStatement?.values(this), paramLengths = preparedStatement?.lengths?.refTo(0), paramFormats = preparedStatement?.formats?.refTo(0), paramTypes = preparedStatement?.types?.refTo(0), - resultFormat = BINARY_RESULT_FORMAT + resultFormat = TEXT_RESULT_FORMAT ) } }.check(conn) - val value = PostgresCursor(result, cursorName, conn).use(mapper) + val value = PostgresCursor.SingleCursor(result).use(mapper) return QueryResult.Value(value = value) } @@ -232,18 +303,44 @@ private fun CPointer?.check(conn: CPointer): CPointer, - private val name: String, - private val conn: CPointer +public sealed class PostgresCursor( + internal var result: CPointer ) : SqlCursor, Closeable { - override fun close() { - result.clear() - conn.exec("CLOSE $name") - conn.exec("END") + + /** + * Must be inside a transaction! + */ + internal class MultiCursor( + result: CPointer, + private val name: String, + private val conn: CPointer + ) : PostgresCursor(result) { + override fun close() { + result.clear() + conn.exec("CLOSE $name") + conn.exec("END") + } + + override fun next(): Boolean { + result = PQexec(conn, "FETCH NEXT IN $name").check(conn) + return PQcmdTuples(result)!!.toKString().toInt() == 1 + } + } + + internal class SingleCursor( + result: CPointer + ) : PostgresCursor(result) { + override fun close() { + result.clear() + } + + private var hasNext = true + override fun next() = if (hasNext) { + hasNext = false + true + } else { + false + } } override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() @@ -302,11 +399,6 @@ public class PostgresCursor( public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) } public fun getUUID(index: Int): UUID? = getString(index)?.toUUID() - - override fun next(): Boolean { - result = PQexec(conn, "FETCH NEXT IN $name").check(conn) - return PQcmdTuples(result)!!.toKString().toInt() == 1 - } } public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement { diff --git a/testing/src/commonMain/sqldelight/app/softwork/sqldelight/postgresdriver/Users.sq b/testing/src/commonMain/sqldelight/app/softwork/sqldelight/postgresdriver/Users.sq index a51c967..a048400 100644 --- a/testing/src/commonMain/sqldelight/app/softwork/sqldelight/postgresdriver/Users.sq +++ b/testing/src/commonMain/sqldelight/app/softwork/sqldelight/postgresdriver/Users.sq @@ -1,6 +1,7 @@ -insert: +insertAndGet: INSERT INTO users(email, username, bio, image) -VALUES (?, ?, ?, ?); +VALUES (?, ?, ?, ?) +RETURNING id; selectByUsername: SELECT email, username, bio, image diff --git a/testing/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt b/testing/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt index 7fd6acf..fec9369 100644 --- a/testing/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt +++ b/testing/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt @@ -68,7 +68,10 @@ class PostgresNativeDriverTest { fun userTest() { val queries = NativePostgres(driver).usersQueries NativePostgres.Schema.migrate(driver, 0, NativePostgres.Schema.version) - queries.insert("test@test", "test", "bio", "") + val id = queries.insertAndGet("test@test", "test", "bio", "").executeAsOne() + assertEquals(1, id) + val id2 = queries.insertAndGet("test2@test", "test2", "bio2", "").executeAsOne() + assertEquals(2, id2) val testUser = queries.selectByUsername("test").executeAsOne() assertEquals( SelectByUsername( From 17633fc09a13cc63a3514a55e88006b3bcec5b7c Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 18 Oct 2022 09:17:18 +0200 Subject: [PATCH 2/6] Introduce a native cursor method --- .../postgresdriver/PostgresNativeDriver.kt | 63 ++++++++----------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt index a0685a5..63aed53 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt @@ -87,11 +87,12 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive return QueryResult.Value(value = result.rows) } - private val CPointer.rows: Long get() { - val rows = PQcmdTuples(this)!!.toKString() - clear() - return rows.toLongOrNull() ?: 0 - } + private val CPointer.rows: Long + get() { + val rows = PQcmdTuples(this)!!.toKString() + clear() + return rows.toLongOrNull() ?: 0 + } private fun preparedStatementExists(identifier: Int): Boolean { val result = @@ -104,7 +105,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive return result.value != null } - private fun Int.escapeNegative(): String = if(this < 0) "_${toString().substring(1)}" else toString() + private fun Int.escapeNegative(): String = if (this < 0) "_${toString().substring(1)}" else toString() private fun preparedStatement( parameters: Int, @@ -117,23 +118,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } } else null - override fun executeQuery( - identifier: Int?, - sql: String, - mapper: (SqlCursor) -> R, - parameters: Int, - binders: (SqlPreparedStatement.() -> Unit)? - ): QueryResult.Value = if (sql.startsWith("INSERT")) { - executeInsertReturningQuery( - identifier = identifier, sql = sql, mapper = mapper, parameters = parameters, binders = binders - ) - } else { - executeNormalQuery( - identifier = identifier, sql = sql, mapper = mapper, parameters = parameters, binders = binders - ) - } - - private fun executeNormalQuery( + public fun executeQueryWithNativeCursor( identifier: Int?, sql: String, mapper: (SqlCursor) -> R, @@ -174,7 +159,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.MultiCursor(result, cursorName, conn).use(mapper) + val value = PostgresCursor.RealCursor(result, cursorName, conn).use(mapper) return QueryResult.Value(value = value) } @@ -195,7 +180,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } } - private fun executeInsertReturningQuery( + override fun executeQuery( identifier: Int?, sql: String, mapper: (SqlCursor) -> R, @@ -231,7 +216,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.SingleCursor(result).use(mapper) + val value = PostgresCursor.NoCursor(result).use(mapper) return QueryResult.Value(value = value) } @@ -306,11 +291,12 @@ private fun CPointer?.check(conn: CPointer): CPointer ) : SqlCursor, Closeable { + internal abstract val currentRow: Int /** * Must be inside a transaction! */ - internal class MultiCursor( + internal class RealCursor( result: CPointer, private val name: String, private val conn: CPointer @@ -321,22 +307,27 @@ public sealed class PostgresCursor( conn.exec("END") } + override val currentRow = 0 + override fun next(): Boolean { result = PQexec(conn, "FETCH NEXT IN $name").check(conn) return PQcmdTuples(result)!!.toKString().toInt() == 1 } } - internal class SingleCursor( + internal class NoCursor( result: CPointer ) : PostgresCursor(result) { override fun close() { result.clear() } - private var hasNext = true - override fun next() = if (hasNext) { - hasNext = false + override val currentRow get() = row + + private val rows = PQntuples(result) + private var row = 0 + override fun next() = if (row < rows) { + row += 1 true } else { false @@ -346,12 +337,12 @@ public sealed class PostgresCursor( override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() override fun getBytes(index: Int): ByteArray? { - val isNull = PQgetisnull(result, tup_num = 0, field_num = index) == 1 + val isNull = PQgetisnull(result, tup_num = currentRow, field_num = index) == 1 return if (isNull) { null } else { - val bytes = PQgetvalue(result, tup_num = 0, field_num = index)!! - val length = PQgetlength(result, tup_num = 0, field_num = index) + val bytes = PQgetvalue(result, tup_num = currentRow, field_num = index)!! + val length = PQgetlength(result, tup_num = currentRow, field_num = index) bytes.fromHex(length) } } @@ -381,11 +372,11 @@ public sealed class PostgresCursor( override fun getLong(index: Int): Long? = getString(index)?.toLong() override fun getString(index: Int): String? { - val isNull = PQgetisnull(result, tup_num = 0, field_num = index) == 1 + val isNull = PQgetisnull(result, tup_num = currentRow, field_num = index) == 1 return if (isNull) { null } else { - val value = PQgetvalue(result, tup_num = 0, field_num = index) + val value = PQgetvalue(result, tup_num = currentRow, field_num = index) value!!.toKString() } } From 0e0c5f7dd4213d3b1c37bd67c1cba110e9734f21 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 18 Oct 2022 11:48:02 +0200 Subject: [PATCH 3/6] Custom cursor size --- .../postgresdriver/PostgresNativeDriver.kt | 51 ++++++++++++------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt index 63aed53..3cbddbd 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt @@ -123,6 +123,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive sql: String, mapper: (SqlCursor) -> R, parameters: Int, + fetchSize: Int = 0, binders: (SqlPreparedStatement.() -> Unit)? ): QueryResult.Value { val cursorName = if (identifier == null) "myCursor" else "cursor${identifier.escapeNegative()}" @@ -159,7 +160,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.RealCursor(result, cursorName, conn).use(mapper) + val value = PostgresCursor.RealCursor(result, cursorName, conn, fetchSize).use(mapper) return QueryResult.Value(value = value) } @@ -291,7 +292,7 @@ private fun CPointer?.check(conn: CPointer): CPointer ) : SqlCursor, Closeable { - internal abstract val currentRow: Int + internal abstract val currentRowIndex: Int /** * Must be inside a transaction! @@ -299,7 +300,8 @@ public sealed class PostgresCursor( internal class RealCursor( result: CPointer, private val name: String, - private val conn: CPointer + private val conn: CPointer, + private val fetchSize: Int ) : PostgresCursor(result) { override fun close() { result.clear() @@ -307,11 +309,21 @@ public sealed class PostgresCursor( conn.exec("END") } - override val currentRow = 0 + override var currentRowIndex = -1 + private var maxRowIndex = -1 override fun next(): Boolean { - result = PQexec(conn, "FETCH NEXT IN $name").check(conn) - return PQcmdTuples(result)!!.toKString().toInt() == 1 + if (currentRowIndex == maxRowIndex) { + currentRowIndex = -1 + } + if (currentRowIndex == -1) { + result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn) + maxRowIndex = PQntuples(result) - 1 + } + return if (currentRowIndex < maxRowIndex) { + currentRowIndex += 1 + true + } else false } } @@ -322,27 +334,28 @@ public sealed class PostgresCursor( result.clear() } - override val currentRow get() = row + private val maxRowIndex = PQntuples(result) - 1 + override var currentRowIndex = -1 - private val rows = PQntuples(result) - private var row = 0 - override fun next() = if (row < rows) { - row += 1 - true - } else { - false + override fun next(): Boolean { + return if (currentRowIndex < maxRowIndex) { + currentRowIndex += 1 + true + } else { + false + } } } override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() override fun getBytes(index: Int): ByteArray? { - val isNull = PQgetisnull(result, tup_num = currentRow, field_num = index) == 1 + val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 return if (isNull) { null } else { - val bytes = PQgetvalue(result, tup_num = currentRow, field_num = index)!! - val length = PQgetlength(result, tup_num = currentRow, field_num = index) + val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!! + val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index) bytes.fromHex(length) } } @@ -372,11 +385,11 @@ public sealed class PostgresCursor( override fun getLong(index: Int): Long? = getString(index)?.toLong() override fun getString(index: Int): String? { - val isNull = PQgetisnull(result, tup_num = currentRow, field_num = index) == 1 + val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 return if (isNull) { null } else { - val value = PQgetvalue(result, tup_num = currentRow, field_num = index) + val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index) value!!.toKString() } } From 41145b01b51d9dc404ac53a57ce4d76d21282d15 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 18 Oct 2022 12:06:27 +0200 Subject: [PATCH 4/6] Custom cursor size test --- .../postgresdriver/PostgresNativeDriver.kt | 4 +- .../PostgresNativeDriverTest.kt | 77 ++++++++++++++++++- 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt index 3cbddbd..958dd8a 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt @@ -123,7 +123,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive sql: String, mapper: (SqlCursor) -> R, parameters: Int, - fetchSize: Int = 0, + fetchSize: Int = 1, binders: (SqlPreparedStatement.() -> Unit)? ): QueryResult.Value { val cursorName = if (identifier == null) "myCursor" else "cursor${identifier.escapeNegative()}" @@ -131,7 +131,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive val preparedStatement = preparedStatement(parameters, binders) val result = if (identifier != null) { - checkPreparedStatement(identifier, sql, parameters, preparedStatement) + checkPreparedStatement(identifier, "$cursor $sql", parameters, preparedStatement) conn.exec("BEGIN") memScoped { PQexecPrepared( diff --git a/postgres-native-sqldelight-driver/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt b/postgres-native-sqldelight-driver/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt index fbb7f89..6153b7f 100644 --- a/postgres-native-sqldelight-driver/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt +++ b/postgres-native-sqldelight-driver/src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt @@ -13,7 +13,10 @@ class PostgresNativeDriverTest { password = "password" ) assertEquals(0, driver.execute(null, "DROP TABLE IF EXISTS baz;", parameters = 0).value) - assertEquals(0, driver.execute(null, "CREATE TABLE baz(a int primary key, foo text, b bytea);", parameters = 0).value) + assertEquals( + 0, + driver.execute(null, "CREATE TABLE baz(a INT PRIMARY KEY, foo TEXT, b BYTEA);", parameters = 0).value + ) repeat(5) { val result = driver.execute(null, "INSERT INTO baz VALUES ($it)", parameters = 0) assertEquals(1, result.value) @@ -29,7 +32,7 @@ class PostgresNativeDriverTest { bindBytes(5, byteArrayOf(16.toByte(), 12.toByte())) }.value assertEquals(2, result) - val notPrepared = driver.executeQuery(null, "SELECT * from baz limit 1;", parameters = 0, mapper = { + val notPrepared = driver.executeQuery(null, "SELECT * FROM baz LIMIT 1;", parameters = 0, mapper = { assertTrue(it.next()) Simple( index = it.getLong(0)!!.toInt(), @@ -57,6 +60,7 @@ class PostgresNativeDriverTest { } ).value + assertEquals(7, preparedStatement.size) assertEquals( List(5) { Simple(it, null, null) @@ -66,6 +70,75 @@ class PostgresNativeDriverTest { ), preparedStatement ) + + expect(7) { + val cursorList = driver.executeQueryWithNativeCursor( + -99, + "SELECT * FROM baz", + fetchSize = 4, + parameters = 0, + binders = null, + mapper = { + buildList { + while (it.next()) { + add( + Simple( + index = it.getLong(0)!!.toInt(), + name = it.getString(1), + byteArray = it.getBytes(2) + ) + ) + } + } + }).value + cursorList.size + } + + expect(7) { + val cursorList = driver.executeQueryWithNativeCursor( + -5, + "SELECT * FROM baz", + fetchSize = 1, + parameters = 0, + binders = null, + mapper = { + buildList { + while (it.next()) { + add( + Simple( + index = it.getLong(0)!!.toInt(), + name = it.getString(1), + byteArray = it.getBytes(2) + ) + ) + } + } + }).value + cursorList.size + } + + expect(0) { + val cursorList = driver.executeQueryWithNativeCursor( + -100, + "SELECT * FROM baz WHERE a = -1", + fetchSize = 1, + parameters = 0, + binders = null, + mapper = { + buildList { + while (it.next()) { + add( + Simple( + index = it.getLong(0)!!.toInt(), + name = it.getString(1), + byteArray = it.getBytes(2) + ) + ) + } + } + }).value + cursorList.size + } } private data class Simple(val index: Int, val name: String?, val byteArray: ByteArray?) { From 3614509cf9978dab734ee2907e926717986d4513 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 18 Oct 2022 12:28:55 +0200 Subject: [PATCH 5/6] Split code into multiple files --- .../sqldelight/postgresdriver/NoCursor.kt | 24 ++ .../postgresdriver/PostgresCursor.kt | 71 ++++++ .../postgresdriver/PostgresNativeDriver.kt | 225 +----------------- .../PostgresPreparedStatement.kt | 103 ++++++++ .../sqldelight/postgresdriver/RealCursor.kt | 37 +++ 5 files changed, 240 insertions(+), 220 deletions(-) create mode 100644 postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt create mode 100644 postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt create mode 100644 postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt create mode 100644 postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt new file mode 100644 index 0000000..bcb6e19 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt @@ -0,0 +1,24 @@ +package app.softwork.sqldelight.postgresdriver + +import kotlinx.cinterop.* +import libpq.* + +internal class NoCursor( + result: CPointer +) : PostgresCursor(result) { + override fun close() { + result.clear() + } + + private val maxRowIndex = PQntuples(result) - 1 + override var currentRowIndex = -1 + + override fun next(): Boolean { + return if (currentRowIndex < maxRowIndex) { + currentRowIndex += 1 + true + } else { + false + } + } +} diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt new file mode 100644 index 0000000..2055ee4 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt @@ -0,0 +1,71 @@ +package app.softwork.sqldelight.postgresdriver + +import app.cash.sqldelight.db.* +import kotlinx.cinterop.* +import kotlinx.datetime.* +import kotlinx.uuid.* +import libpq.* +import kotlin.time.* + +public sealed class PostgresCursor( + internal var result: CPointer +) : SqlCursor, Closeable { + internal abstract val currentRowIndex: Int + + override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() + + override fun getBytes(index: Int): ByteArray? { + val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 + return if (isNull) { + null + } else { + val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!! + val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index) + bytes.fromHex(length) + } + } + + private inline fun Int.fromHex(): Int = if (this in 48..57) { + this - 48 + } else { + this - 87 + } + + // because "normal" CPointer.toByteArray() functions does not support hex (2 Bytes) bytes + private fun CPointer.fromHex(length: Int): ByteArray { + val array = ByteArray((length - 2) / 2) + var index = 0 + for (i in 2 until length step 2) { + val first = this[i].toInt().fromHex() + val second = this[i + 1].toInt().fromHex() + val octet = first.shl(4).or(second) + array[index] = octet.toByte() + index++ + } + return array + } + + override fun getDouble(index: Int): Double? = getString(index)?.toDouble() + + override fun getLong(index: Int): Long? = getString(index)?.toLong() + + override fun getString(index: Int): String? { + val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 + return if (isNull) { + null + } else { + val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index) + value!!.toKString() + } + } + + public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate() + public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime() + public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime() + public fun getTimestamp(index: Int): Instant? = getString(index)?.let { + Instant.parse(it.replace(" ", "T")) + } + + public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) } + public fun getUUID(index: Int): UUID? = getString(index)?.toUUID() +} diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt index 958dd8a..f667b0f 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt @@ -3,10 +3,7 @@ package app.softwork.sqldelight.postgresdriver import app.cash.sqldelight.* import app.cash.sqldelight.db.* import kotlinx.cinterop.* -import kotlinx.datetime.* -import kotlinx.uuid.* import libpq.* -import kotlin.time.* public class PostgresNativeDriver(private var conn: CPointer) : SqlDriver { private var transaction: Transacter.Transaction? = null @@ -160,7 +157,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.RealCursor(result, cursorName, conn, fetchSize).use(mapper) + val value = RealCursor(result, cursorName, conn, fetchSize).use(mapper) return QueryResult.Value(value = value) } @@ -217,7 +214,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.NoCursor(result).use(mapper) + val value = NoCursor(result).use(mapper) return QueryResult.Value(value = value) } @@ -271,17 +268,17 @@ private fun CPointer?.error(): String { return errorMessage } -private fun CPointer?.clear() { +internal fun CPointer?.clear() { PQclear(this) } -private fun CPointer.exec(sql: String) { +internal fun CPointer.exec(sql: String) { val result = PQexec(this, sql) result.check(this) result.clear() } -private fun CPointer?.check(conn: CPointer): CPointer { +internal fun CPointer?.check(conn: CPointer): CPointer { val status = PQresultStatus(this) check(status == PGRES_TUPLES_OK || status == PGRES_COMMAND_OK || status == PGRES_COPY_IN) { conn.error() @@ -289,218 +286,6 @@ private fun CPointer?.check(conn: CPointer): CPointer -) : SqlCursor, Closeable { - internal abstract val currentRowIndex: Int - - /** - * Must be inside a transaction! - */ - internal class RealCursor( - result: CPointer, - private val name: String, - private val conn: CPointer, - private val fetchSize: Int - ) : PostgresCursor(result) { - override fun close() { - result.clear() - conn.exec("CLOSE $name") - conn.exec("END") - } - - override var currentRowIndex = -1 - private var maxRowIndex = -1 - - override fun next(): Boolean { - if (currentRowIndex == maxRowIndex) { - currentRowIndex = -1 - } - if (currentRowIndex == -1) { - result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn) - maxRowIndex = PQntuples(result) - 1 - } - return if (currentRowIndex < maxRowIndex) { - currentRowIndex += 1 - true - } else false - } - } - - internal class NoCursor( - result: CPointer - ) : PostgresCursor(result) { - override fun close() { - result.clear() - } - - private val maxRowIndex = PQntuples(result) - 1 - override var currentRowIndex = -1 - - override fun next(): Boolean { - return if (currentRowIndex < maxRowIndex) { - currentRowIndex += 1 - true - } else { - false - } - } - } - - override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() - - override fun getBytes(index: Int): ByteArray? { - val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 - return if (isNull) { - null - } else { - val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!! - val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index) - bytes.fromHex(length) - } - } - - private inline fun Int.fromHex(): Int = if (this in 48..57) { - this - 48 - } else { - this - 87 - } - - // because "normal" CPointer.toByteArray() functions does not support hex (2 Bytes) bytes - private fun CPointer.fromHex(length: Int): ByteArray { - val array = ByteArray((length - 2) / 2) - var index = 0 - for (i in 2 until length step 2) { - val first = this[i].toInt().fromHex() - val second = this[i + 1].toInt().fromHex() - val octet = first.shl(4).or(second) - array[index] = octet.toByte() - index++ - } - return array - } - - override fun getDouble(index: Int): Double? = getString(index)?.toDouble() - - override fun getLong(index: Int): Long? = getString(index)?.toLong() - - override fun getString(index: Int): String? { - val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 - return if (isNull) { - null - } else { - val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index) - value!!.toKString() - } - } - - public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate() - public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime() - public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime() - public fun getTimestamp(index: Int): Instant? = getString(index)?.let { - Instant.parse(it.replace(" ", "T")) - } - - public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) } - public fun getUUID(index: Int): UUID? = getString(index)?.toUUID() -} - -public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement { - internal fun values(scope: AutofreeScope): CValuesRef> = createValues(parameters) { - value = when (val value = _values[it]) { - null -> null - is Data.Bytes -> value.bytes.refTo(0).getPointer(scope) - is Data.Text -> value.text.cstr.getPointer(scope) - } - } - - private sealed interface Data { - value class Bytes(val bytes: ByteArray) : Data - value class Text(val text: String) : Data - } - - private val _values = arrayOfNulls(parameters) - internal val lengths = IntArray(parameters) - internal val formats = IntArray(parameters) - internal val types = UIntArray(parameters) - - private fun bind(index: Int, value: String?, oid: UInt) { - lengths[index] = if (value != null) { - _values[index] = Data.Text(value) - value.length - } else 0 - formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT - types[index] = oid - } - - override fun bindBoolean(index: Int, boolean: Boolean?) { - bind(index, boolean?.toString(), boolOid) - } - - override fun bindBytes(index: Int, bytes: ByteArray?) { - lengths[index] = if (bytes != null && bytes.isNotEmpty()) { - _values[index] = Data.Bytes(bytes) - bytes.size - } else 0 - formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT - types[index] = byteaOid - } - - override fun bindDouble(index: Int, double: Double?) { - bind(index, double?.toString(), doubleOid) - } - - override fun bindLong(index: Int, long: Long?) { - bind(index, long?.toString(), longOid) - } - - override fun bindString(index: Int, string: String?) { - bind(index, string, textOid) - } - - public fun bindDate(index: Int, value: LocalDate?) { - bind(index, value?.toString(), dateOid) - } - - - public fun bindTime(index: Int, value: LocalTime?) { - bind(index, value?.toString(), timeOid) - } - - public fun bindLocalTimestamp(index: Int, value: LocalDateTime?) { - bind(index, value?.toString(), timestampOid) - } - - public fun bindTimestamp(index: Int, value: Instant?) { - bind(index, value?.toString(), timestampTzOid) - } - - public fun bindInterval(index: Int, value: Duration?) { - bind(index, value?.toIsoString(), intervalOid) - } - - public fun bindUUID(index: Int, value: UUID?) { - bind(index, value?.toString(), uuidOid) - } - - private companion object { - // Hardcoded, because not provided in libpq-fe.h for unknown reasons... - // select * from pg_type; - private const val boolOid = 16u - private const val byteaOid = 17u - private const val longOid = 20u - private const val textOid = 25u - private const val doubleOid = 701u - - private const val dateOid = 1082u - private const val timeOid = 1083u - private const val intervalOid = 1186u - private const val timestampOid = 1114u - private const val timestampTzOid = 1184u - private const val uuidOid = 2950u - } -} - public fun PostgresNativeDriver( host: String, database: String, user: String, password: String, port: Int = 5432, options: String? = null ): PostgresNativeDriver { diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt new file mode 100644 index 0000000..21539c9 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt @@ -0,0 +1,103 @@ +package app.softwork.sqldelight.postgresdriver + +import app.cash.sqldelight.db.* +import kotlinx.cinterop.* +import kotlinx.datetime.* +import kotlinx.uuid.* +import kotlin.time.* + +public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement { + internal fun values(scope: AutofreeScope): CValuesRef> = createValues(parameters) { + value = when (val value = _values[it]) { + null -> null + is Data.Bytes -> value.bytes.refTo(0).getPointer(scope) + is Data.Text -> value.text.cstr.getPointer(scope) + } + } + + private sealed interface Data { + value class Bytes(val bytes: ByteArray) : Data + value class Text(val text: String) : Data + } + + private val _values = arrayOfNulls(parameters) + internal val lengths = IntArray(parameters) + internal val formats = IntArray(parameters) + internal val types = UIntArray(parameters) + + private fun bind(index: Int, value: String?, oid: UInt) { + lengths[index] = if (value != null) { + _values[index] = Data.Text(value) + value.length + } else 0 + formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT + types[index] = oid + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + bind(index, boolean?.toString(), boolOid) + } + + override fun bindBytes(index: Int, bytes: ByteArray?) { + lengths[index] = if (bytes != null && bytes.isNotEmpty()) { + _values[index] = Data.Bytes(bytes) + bytes.size + } else 0 + formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT + types[index] = byteaOid + } + + override fun bindDouble(index: Int, double: Double?) { + bind(index, double?.toString(), doubleOid) + } + + override fun bindLong(index: Int, long: Long?) { + bind(index, long?.toString(), longOid) + } + + override fun bindString(index: Int, string: String?) { + bind(index, string, textOid) + } + + public fun bindDate(index: Int, value: LocalDate?) { + bind(index, value?.toString(), dateOid) + } + + + public fun bindTime(index: Int, value: LocalTime?) { + bind(index, value?.toString(), timeOid) + } + + public fun bindLocalTimestamp(index: Int, value: LocalDateTime?) { + bind(index, value?.toString(), timestampOid) + } + + public fun bindTimestamp(index: Int, value: Instant?) { + bind(index, value?.toString(), timestampTzOid) + } + + public fun bindInterval(index: Int, value: Duration?) { + bind(index, value?.toIsoString(), intervalOid) + } + + public fun bindUUID(index: Int, value: UUID?) { + bind(index, value?.toString(), uuidOid) + } + + private companion object { + // Hardcoded, because not provided in libpq-fe.h for unknown reasons... + // select * from pg_type; + private const val boolOid = 16u + private const val byteaOid = 17u + private const val longOid = 20u + private const val textOid = 25u + private const val doubleOid = 701u + + private const val dateOid = 1082u + private const val timeOid = 1083u + private const val intervalOid = 1186u + private const val timestampOid = 1114u + private const val timestampTzOid = 1184u + private const val uuidOid = 2950u + } +} diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt new file mode 100644 index 0000000..caf4070 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt @@ -0,0 +1,37 @@ +package app.softwork.sqldelight.postgresdriver + +import kotlinx.cinterop.* +import libpq.* + +/** + * Must be inside a transaction! + */ +internal class RealCursor( + result: CPointer, + private val name: String, + private val conn: CPointer, + private val fetchSize: Int +) : PostgresCursor(result) { + override fun close() { + result.clear() + conn.exec("CLOSE $name") + conn.exec("END") + } + + override var currentRowIndex = -1 + private var maxRowIndex = -1 + + override fun next(): Boolean { + if (currentRowIndex == maxRowIndex) { + currentRowIndex = -1 + } + if (currentRowIndex == -1) { + result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn) + maxRowIndex = PQntuples(result) - 1 + } + return if (currentRowIndex < maxRowIndex) { + currentRowIndex += 1 + true + } else false + } +} From 2049ff076e451ae99046531e33faa613ffa014b6 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 18 Oct 2022 12:46:46 +0200 Subject: [PATCH 6/6] Remove public constructor for PreparedStatement --- .../sqldelight/postgresdriver/PostgresPreparedStatement.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt index 21539c9..2b6cd5c 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt @@ -6,7 +6,7 @@ import kotlinx.datetime.* import kotlinx.uuid.* import kotlin.time.* -public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement { +public class PostgresPreparedStatement internal constructor(private val parameters: Int) : SqlPreparedStatement { internal fun values(scope: AutofreeScope): CValuesRef> = createValues(parameters) { value = when (val value = _values[it]) { null -> null