diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt new file mode 100644 index 0000000..bcb6e19 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/NoCursor.kt @@ -0,0 +1,24 @@ +package app.softwork.sqldelight.postgresdriver + +import kotlinx.cinterop.* +import libpq.* + +internal class NoCursor( + result: CPointer +) : PostgresCursor(result) { + override fun close() { + result.clear() + } + + private val maxRowIndex = PQntuples(result) - 1 + override var currentRowIndex = -1 + + override fun next(): Boolean { + return if (currentRowIndex < maxRowIndex) { + currentRowIndex += 1 + true + } else { + false + } + } +} diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt new file mode 100644 index 0000000..2055ee4 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresCursor.kt @@ -0,0 +1,71 @@ +package app.softwork.sqldelight.postgresdriver + +import app.cash.sqldelight.db.* +import kotlinx.cinterop.* +import kotlinx.datetime.* +import kotlinx.uuid.* +import libpq.* +import kotlin.time.* + +public sealed class PostgresCursor( + internal var result: CPointer +) : SqlCursor, Closeable { + internal abstract val currentRowIndex: Int + + override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() + + override fun getBytes(index: Int): ByteArray? { + val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 + return if (isNull) { + null + } else { + val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!! + val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index) + bytes.fromHex(length) + } + } + + private inline fun Int.fromHex(): Int = if (this in 48..57) { + this - 48 + } else { + this - 87 + } + + // because "normal" CPointer.toByteArray() functions does not support hex (2 Bytes) bytes + private fun CPointer.fromHex(length: Int): ByteArray { + val array = ByteArray((length - 2) / 2) + var index = 0 + for (i in 2 until length step 2) { + val first = this[i].toInt().fromHex() + val second = this[i + 1].toInt().fromHex() + val octet = first.shl(4).or(second) + array[index] = octet.toByte() + index++ + } + return array + } + + override fun getDouble(index: Int): Double? = getString(index)?.toDouble() + + override fun getLong(index: Int): Long? = getString(index)?.toLong() + + override fun getString(index: Int): String? { + val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 + return if (isNull) { + null + } else { + val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index) + value!!.toKString() + } + } + + public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate() + public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime() + public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime() + public fun getTimestamp(index: Int): Instant? = getString(index)?.let { + Instant.parse(it.replace(" ", "T")) + } + + public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) } + public fun getUUID(index: Int): UUID? = getString(index)?.toUUID() +} diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt index 958dd8a..f667b0f 100644 --- a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt @@ -3,10 +3,7 @@ package app.softwork.sqldelight.postgresdriver import app.cash.sqldelight.* import app.cash.sqldelight.db.* import kotlinx.cinterop.* -import kotlinx.datetime.* -import kotlinx.uuid.* import libpq.* -import kotlin.time.* public class PostgresNativeDriver(private var conn: CPointer) : SqlDriver { private var transaction: Transacter.Transaction? = null @@ -160,7 +157,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.RealCursor(result, cursorName, conn, fetchSize).use(mapper) + val value = RealCursor(result, cursorName, conn, fetchSize).use(mapper) return QueryResult.Value(value = value) } @@ -217,7 +214,7 @@ public class PostgresNativeDriver(private var conn: CPointer) : SqlDrive } }.check(conn) - val value = PostgresCursor.NoCursor(result).use(mapper) + val value = NoCursor(result).use(mapper) return QueryResult.Value(value = value) } @@ -271,17 +268,17 @@ private fun CPointer?.error(): String { return errorMessage } -private fun CPointer?.clear() { +internal fun CPointer?.clear() { PQclear(this) } -private fun CPointer.exec(sql: String) { +internal fun CPointer.exec(sql: String) { val result = PQexec(this, sql) result.check(this) result.clear() } -private fun CPointer?.check(conn: CPointer): CPointer { +internal fun CPointer?.check(conn: CPointer): CPointer { val status = PQresultStatus(this) check(status == PGRES_TUPLES_OK || status == PGRES_COMMAND_OK || status == PGRES_COPY_IN) { conn.error() @@ -289,218 +286,6 @@ private fun CPointer?.check(conn: CPointer): CPointer -) : SqlCursor, Closeable { - internal abstract val currentRowIndex: Int - - /** - * Must be inside a transaction! - */ - internal class RealCursor( - result: CPointer, - private val name: String, - private val conn: CPointer, - private val fetchSize: Int - ) : PostgresCursor(result) { - override fun close() { - result.clear() - conn.exec("CLOSE $name") - conn.exec("END") - } - - override var currentRowIndex = -1 - private var maxRowIndex = -1 - - override fun next(): Boolean { - if (currentRowIndex == maxRowIndex) { - currentRowIndex = -1 - } - if (currentRowIndex == -1) { - result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn) - maxRowIndex = PQntuples(result) - 1 - } - return if (currentRowIndex < maxRowIndex) { - currentRowIndex += 1 - true - } else false - } - } - - internal class NoCursor( - result: CPointer - ) : PostgresCursor(result) { - override fun close() { - result.clear() - } - - private val maxRowIndex = PQntuples(result) - 1 - override var currentRowIndex = -1 - - override fun next(): Boolean { - return if (currentRowIndex < maxRowIndex) { - currentRowIndex += 1 - true - } else { - false - } - } - } - - override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean() - - override fun getBytes(index: Int): ByteArray? { - val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 - return if (isNull) { - null - } else { - val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!! - val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index) - bytes.fromHex(length) - } - } - - private inline fun Int.fromHex(): Int = if (this in 48..57) { - this - 48 - } else { - this - 87 - } - - // because "normal" CPointer.toByteArray() functions does not support hex (2 Bytes) bytes - private fun CPointer.fromHex(length: Int): ByteArray { - val array = ByteArray((length - 2) / 2) - var index = 0 - for (i in 2 until length step 2) { - val first = this[i].toInt().fromHex() - val second = this[i + 1].toInt().fromHex() - val octet = first.shl(4).or(second) - array[index] = octet.toByte() - index++ - } - return array - } - - override fun getDouble(index: Int): Double? = getString(index)?.toDouble() - - override fun getLong(index: Int): Long? = getString(index)?.toLong() - - override fun getString(index: Int): String? { - val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1 - return if (isNull) { - null - } else { - val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index) - value!!.toKString() - } - } - - public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate() - public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime() - public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime() - public fun getTimestamp(index: Int): Instant? = getString(index)?.let { - Instant.parse(it.replace(" ", "T")) - } - - public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) } - public fun getUUID(index: Int): UUID? = getString(index)?.toUUID() -} - -public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement { - internal fun values(scope: AutofreeScope): CValuesRef> = createValues(parameters) { - value = when (val value = _values[it]) { - null -> null - is Data.Bytes -> value.bytes.refTo(0).getPointer(scope) - is Data.Text -> value.text.cstr.getPointer(scope) - } - } - - private sealed interface Data { - value class Bytes(val bytes: ByteArray) : Data - value class Text(val text: String) : Data - } - - private val _values = arrayOfNulls(parameters) - internal val lengths = IntArray(parameters) - internal val formats = IntArray(parameters) - internal val types = UIntArray(parameters) - - private fun bind(index: Int, value: String?, oid: UInt) { - lengths[index] = if (value != null) { - _values[index] = Data.Text(value) - value.length - } else 0 - formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT - types[index] = oid - } - - override fun bindBoolean(index: Int, boolean: Boolean?) { - bind(index, boolean?.toString(), boolOid) - } - - override fun bindBytes(index: Int, bytes: ByteArray?) { - lengths[index] = if (bytes != null && bytes.isNotEmpty()) { - _values[index] = Data.Bytes(bytes) - bytes.size - } else 0 - formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT - types[index] = byteaOid - } - - override fun bindDouble(index: Int, double: Double?) { - bind(index, double?.toString(), doubleOid) - } - - override fun bindLong(index: Int, long: Long?) { - bind(index, long?.toString(), longOid) - } - - override fun bindString(index: Int, string: String?) { - bind(index, string, textOid) - } - - public fun bindDate(index: Int, value: LocalDate?) { - bind(index, value?.toString(), dateOid) - } - - - public fun bindTime(index: Int, value: LocalTime?) { - bind(index, value?.toString(), timeOid) - } - - public fun bindLocalTimestamp(index: Int, value: LocalDateTime?) { - bind(index, value?.toString(), timestampOid) - } - - public fun bindTimestamp(index: Int, value: Instant?) { - bind(index, value?.toString(), timestampTzOid) - } - - public fun bindInterval(index: Int, value: Duration?) { - bind(index, value?.toIsoString(), intervalOid) - } - - public fun bindUUID(index: Int, value: UUID?) { - bind(index, value?.toString(), uuidOid) - } - - private companion object { - // Hardcoded, because not provided in libpq-fe.h for unknown reasons... - // select * from pg_type; - private const val boolOid = 16u - private const val byteaOid = 17u - private const val longOid = 20u - private const val textOid = 25u - private const val doubleOid = 701u - - private const val dateOid = 1082u - private const val timeOid = 1083u - private const val intervalOid = 1186u - private const val timestampOid = 1114u - private const val timestampTzOid = 1184u - private const val uuidOid = 2950u - } -} - public fun PostgresNativeDriver( host: String, database: String, user: String, password: String, port: Int = 5432, options: String? = null ): PostgresNativeDriver { diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt new file mode 100644 index 0000000..2b6cd5c --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresPreparedStatement.kt @@ -0,0 +1,103 @@ +package app.softwork.sqldelight.postgresdriver + +import app.cash.sqldelight.db.* +import kotlinx.cinterop.* +import kotlinx.datetime.* +import kotlinx.uuid.* +import kotlin.time.* + +public class PostgresPreparedStatement internal constructor(private val parameters: Int) : SqlPreparedStatement { + internal fun values(scope: AutofreeScope): CValuesRef> = createValues(parameters) { + value = when (val value = _values[it]) { + null -> null + is Data.Bytes -> value.bytes.refTo(0).getPointer(scope) + is Data.Text -> value.text.cstr.getPointer(scope) + } + } + + private sealed interface Data { + value class Bytes(val bytes: ByteArray) : Data + value class Text(val text: String) : Data + } + + private val _values = arrayOfNulls(parameters) + internal val lengths = IntArray(parameters) + internal val formats = IntArray(parameters) + internal val types = UIntArray(parameters) + + private fun bind(index: Int, value: String?, oid: UInt) { + lengths[index] = if (value != null) { + _values[index] = Data.Text(value) + value.length + } else 0 + formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT + types[index] = oid + } + + override fun bindBoolean(index: Int, boolean: Boolean?) { + bind(index, boolean?.toString(), boolOid) + } + + override fun bindBytes(index: Int, bytes: ByteArray?) { + lengths[index] = if (bytes != null && bytes.isNotEmpty()) { + _values[index] = Data.Bytes(bytes) + bytes.size + } else 0 + formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT + types[index] = byteaOid + } + + override fun bindDouble(index: Int, double: Double?) { + bind(index, double?.toString(), doubleOid) + } + + override fun bindLong(index: Int, long: Long?) { + bind(index, long?.toString(), longOid) + } + + override fun bindString(index: Int, string: String?) { + bind(index, string, textOid) + } + + public fun bindDate(index: Int, value: LocalDate?) { + bind(index, value?.toString(), dateOid) + } + + + public fun bindTime(index: Int, value: LocalTime?) { + bind(index, value?.toString(), timeOid) + } + + public fun bindLocalTimestamp(index: Int, value: LocalDateTime?) { + bind(index, value?.toString(), timestampOid) + } + + public fun bindTimestamp(index: Int, value: Instant?) { + bind(index, value?.toString(), timestampTzOid) + } + + public fun bindInterval(index: Int, value: Duration?) { + bind(index, value?.toIsoString(), intervalOid) + } + + public fun bindUUID(index: Int, value: UUID?) { + bind(index, value?.toString(), uuidOid) + } + + private companion object { + // Hardcoded, because not provided in libpq-fe.h for unknown reasons... + // select * from pg_type; + private const val boolOid = 16u + private const val byteaOid = 17u + private const val longOid = 20u + private const val textOid = 25u + private const val doubleOid = 701u + + private const val dateOid = 1082u + private const val timeOid = 1083u + private const val intervalOid = 1186u + private const val timestampOid = 1114u + private const val timestampTzOid = 1184u + private const val uuidOid = 2950u + } +} diff --git a/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt new file mode 100644 index 0000000..caf4070 --- /dev/null +++ b/postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/RealCursor.kt @@ -0,0 +1,37 @@ +package app.softwork.sqldelight.postgresdriver + +import kotlinx.cinterop.* +import libpq.* + +/** + * Must be inside a transaction! + */ +internal class RealCursor( + result: CPointer, + private val name: String, + private val conn: CPointer, + private val fetchSize: Int +) : PostgresCursor(result) { + override fun close() { + result.clear() + conn.exec("CLOSE $name") + conn.exec("END") + } + + override var currentRowIndex = -1 + private var maxRowIndex = -1 + + override fun next(): Boolean { + if (currentRowIndex == maxRowIndex) { + currentRowIndex = -1 + } + if (currentRowIndex == -1) { + result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn) + maxRowIndex = PQntuples(result) - 1 + } + return if (currentRowIndex < maxRowIndex) { + currentRowIndex += 1 + true + } else false + } +}