Skip to content

Split code #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package app.softwork.sqldelight.postgresdriver

import kotlinx.cinterop.*
import libpq.*

internal class NoCursor(
result: CPointer<PGresult>
) : PostgresCursor(result) {
override fun close() {
result.clear()
}

private val maxRowIndex = PQntuples(result) - 1
override var currentRowIndex = -1

override fun next(): Boolean {
return if (currentRowIndex < maxRowIndex) {
currentRowIndex += 1
true
} else {
false
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package app.softwork.sqldelight.postgresdriver

import app.cash.sqldelight.db.*
import kotlinx.cinterop.*
import kotlinx.datetime.*
import kotlinx.uuid.*
import libpq.*
import kotlin.time.*

public sealed class PostgresCursor(
internal var result: CPointer<PGresult>
) : SqlCursor, Closeable {
internal abstract val currentRowIndex: Int

override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean()

override fun getBytes(index: Int): ByteArray? {
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
return if (isNull) {
null
} else {
val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!!
val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index)
bytes.fromHex(length)
}
}

private inline fun Int.fromHex(): Int = if (this in 48..57) {
this - 48
} else {
this - 87
}

// because "normal" CPointer<ByteVar>.toByteArray() functions does not support hex (2 Bytes) bytes
private fun CPointer<ByteVar>.fromHex(length: Int): ByteArray {
val array = ByteArray((length - 2) / 2)
var index = 0
for (i in 2 until length step 2) {
val first = this[i].toInt().fromHex()
val second = this[i + 1].toInt().fromHex()
val octet = first.shl(4).or(second)
array[index] = octet.toByte()
index++
}
return array
}

override fun getDouble(index: Int): Double? = getString(index)?.toDouble()

override fun getLong(index: Int): Long? = getString(index)?.toLong()

override fun getString(index: Int): String? {
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
return if (isNull) {
null
} else {
val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)
value!!.toKString()
}
}

public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate()
public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime()
public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime()
public fun getTimestamp(index: Int): Instant? = getString(index)?.let {
Instant.parse(it.replace(" ", "T"))
}

public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) }
public fun getUUID(index: Int): UUID? = getString(index)?.toUUID()
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ package app.softwork.sqldelight.postgresdriver
import app.cash.sqldelight.*
import app.cash.sqldelight.db.*
import kotlinx.cinterop.*
import kotlinx.datetime.*
import kotlinx.uuid.*
import libpq.*
import kotlin.time.*

public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
private var transaction: Transacter.Transaction? = null
Expand Down Expand Up @@ -160,7 +157,7 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
}
}.check(conn)

val value = PostgresCursor.RealCursor(result, cursorName, conn, fetchSize).use(mapper)
val value = RealCursor(result, cursorName, conn, fetchSize).use(mapper)
return QueryResult.Value(value = value)
}

Expand Down Expand Up @@ -217,7 +214,7 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
}
}.check(conn)

val value = PostgresCursor.NoCursor(result).use(mapper)
val value = NoCursor(result).use(mapper)
return QueryResult.Value(value = value)
}

Expand Down Expand Up @@ -271,236 +268,24 @@ private fun CPointer<PGconn>?.error(): String {
return errorMessage
}

private fun CPointer<PGresult>?.clear() {
internal fun CPointer<PGresult>?.clear() {
PQclear(this)
}

private fun CPointer<PGconn>.exec(sql: String) {
internal fun CPointer<PGconn>.exec(sql: String) {
val result = PQexec(this, sql)
result.check(this)
result.clear()
}

private fun CPointer<PGresult>?.check(conn: CPointer<PGconn>): CPointer<PGresult> {
internal fun CPointer<PGresult>?.check(conn: CPointer<PGconn>): CPointer<PGresult> {
val status = PQresultStatus(this)
check(status == PGRES_TUPLES_OK || status == PGRES_COMMAND_OK || status == PGRES_COPY_IN) {
conn.error()
}
return this!!
}

public sealed class PostgresCursor(
internal var result: CPointer<PGresult>
) : SqlCursor, Closeable {
internal abstract val currentRowIndex: Int

/**
* Must be inside a transaction!
*/
internal class RealCursor(
result: CPointer<PGresult>,
private val name: String,
private val conn: CPointer<PGconn>,
private val fetchSize: Int
) : PostgresCursor(result) {
override fun close() {
result.clear()
conn.exec("CLOSE $name")
conn.exec("END")
}

override var currentRowIndex = -1
private var maxRowIndex = -1

override fun next(): Boolean {
if (currentRowIndex == maxRowIndex) {
currentRowIndex = -1
}
if (currentRowIndex == -1) {
result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn)
maxRowIndex = PQntuples(result) - 1
}
return if (currentRowIndex < maxRowIndex) {
currentRowIndex += 1
true
} else false
}
}

internal class NoCursor(
result: CPointer<PGresult>
) : PostgresCursor(result) {
override fun close() {
result.clear()
}

private val maxRowIndex = PQntuples(result) - 1
override var currentRowIndex = -1

override fun next(): Boolean {
return if (currentRowIndex < maxRowIndex) {
currentRowIndex += 1
true
} else {
false
}
}
}

override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean()

override fun getBytes(index: Int): ByteArray? {
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
return if (isNull) {
null
} else {
val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!!
val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index)
bytes.fromHex(length)
}
}

private inline fun Int.fromHex(): Int = if (this in 48..57) {
this - 48
} else {
this - 87
}

// because "normal" CPointer<ByteVar>.toByteArray() functions does not support hex (2 Bytes) bytes
private fun CPointer<ByteVar>.fromHex(length: Int): ByteArray {
val array = ByteArray((length - 2) / 2)
var index = 0
for (i in 2 until length step 2) {
val first = this[i].toInt().fromHex()
val second = this[i + 1].toInt().fromHex()
val octet = first.shl(4).or(second)
array[index] = octet.toByte()
index++
}
return array
}

override fun getDouble(index: Int): Double? = getString(index)?.toDouble()

override fun getLong(index: Int): Long? = getString(index)?.toLong()

override fun getString(index: Int): String? {
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
return if (isNull) {
null
} else {
val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)
value!!.toKString()
}
}

public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate()
public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime()
public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime()
public fun getTimestamp(index: Int): Instant? = getString(index)?.let {
Instant.parse(it.replace(" ", "T"))
}

public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) }
public fun getUUID(index: Int): UUID? = getString(index)?.toUUID()
}

public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement {
internal fun values(scope: AutofreeScope): CValuesRef<CPointerVar<ByteVar>> = createValues(parameters) {
value = when (val value = _values[it]) {
null -> null
is Data.Bytes -> value.bytes.refTo(0).getPointer(scope)
is Data.Text -> value.text.cstr.getPointer(scope)
}
}

private sealed interface Data {
value class Bytes(val bytes: ByteArray) : Data
value class Text(val text: String) : Data
}

private val _values = arrayOfNulls<Data>(parameters)
internal val lengths = IntArray(parameters)
internal val formats = IntArray(parameters)
internal val types = UIntArray(parameters)

private fun bind(index: Int, value: String?, oid: UInt) {
lengths[index] = if (value != null) {
_values[index] = Data.Text(value)
value.length
} else 0
formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT
types[index] = oid
}

override fun bindBoolean(index: Int, boolean: Boolean?) {
bind(index, boolean?.toString(), boolOid)
}

override fun bindBytes(index: Int, bytes: ByteArray?) {
lengths[index] = if (bytes != null && bytes.isNotEmpty()) {
_values[index] = Data.Bytes(bytes)
bytes.size
} else 0
formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT
types[index] = byteaOid
}

override fun bindDouble(index: Int, double: Double?) {
bind(index, double?.toString(), doubleOid)
}

override fun bindLong(index: Int, long: Long?) {
bind(index, long?.toString(), longOid)
}

override fun bindString(index: Int, string: String?) {
bind(index, string, textOid)
}

public fun bindDate(index: Int, value: LocalDate?) {
bind(index, value?.toString(), dateOid)
}


public fun bindTime(index: Int, value: LocalTime?) {
bind(index, value?.toString(), timeOid)
}

public fun bindLocalTimestamp(index: Int, value: LocalDateTime?) {
bind(index, value?.toString(), timestampOid)
}

public fun bindTimestamp(index: Int, value: Instant?) {
bind(index, value?.toString(), timestampTzOid)
}

public fun bindInterval(index: Int, value: Duration?) {
bind(index, value?.toIsoString(), intervalOid)
}

public fun bindUUID(index: Int, value: UUID?) {
bind(index, value?.toString(), uuidOid)
}

private companion object {
// Hardcoded, because not provided in libpq-fe.h for unknown reasons...
// select * from pg_type;
private const val boolOid = 16u
private const val byteaOid = 17u
private const val longOid = 20u
private const val textOid = 25u
private const val doubleOid = 701u

private const val dateOid = 1082u
private const val timeOid = 1083u
private const val intervalOid = 1186u
private const val timestampOid = 1114u
private const val timestampTzOid = 1184u
private const val uuidOid = 2950u
}
}

public fun PostgresNativeDriver(
host: String, database: String, user: String, password: String, port: Int = 5432, options: String? = null
): PostgresNativeDriver {
Expand Down
Loading