Skip to content

Commit 2c47b8a

Browse files
authored
Split code (#90)
* Fix INSERT RETURNING * Introduce a native cursor method * Custom cursor size * Custom cursor size test * Split code into multiple files * Remove public constructor for PreparedStatement Co-authored-by: hfhbd <[email protected]>
1 parent 94cba2b commit 2c47b8a

File tree

5 files changed

+240
-220
lines changed

5 files changed

+240
-220
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package app.softwork.sqldelight.postgresdriver
2+
3+
import kotlinx.cinterop.*
4+
import libpq.*
5+
6+
internal class NoCursor(
7+
result: CPointer<PGresult>
8+
) : PostgresCursor(result) {
9+
override fun close() {
10+
result.clear()
11+
}
12+
13+
private val maxRowIndex = PQntuples(result) - 1
14+
override var currentRowIndex = -1
15+
16+
override fun next(): Boolean {
17+
return if (currentRowIndex < maxRowIndex) {
18+
currentRowIndex += 1
19+
true
20+
} else {
21+
false
22+
}
23+
}
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package app.softwork.sqldelight.postgresdriver
2+
3+
import app.cash.sqldelight.db.*
4+
import kotlinx.cinterop.*
5+
import kotlinx.datetime.*
6+
import kotlinx.uuid.*
7+
import libpq.*
8+
import kotlin.time.*
9+
10+
public sealed class PostgresCursor(
11+
internal var result: CPointer<PGresult>
12+
) : SqlCursor, Closeable {
13+
internal abstract val currentRowIndex: Int
14+
15+
override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean()
16+
17+
override fun getBytes(index: Int): ByteArray? {
18+
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
19+
return if (isNull) {
20+
null
21+
} else {
22+
val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!!
23+
val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index)
24+
bytes.fromHex(length)
25+
}
26+
}
27+
28+
private inline fun Int.fromHex(): Int = if (this in 48..57) {
29+
this - 48
30+
} else {
31+
this - 87
32+
}
33+
34+
// because "normal" CPointer<ByteVar>.toByteArray() functions does not support hex (2 Bytes) bytes
35+
private fun CPointer<ByteVar>.fromHex(length: Int): ByteArray {
36+
val array = ByteArray((length - 2) / 2)
37+
var index = 0
38+
for (i in 2 until length step 2) {
39+
val first = this[i].toInt().fromHex()
40+
val second = this[i + 1].toInt().fromHex()
41+
val octet = first.shl(4).or(second)
42+
array[index] = octet.toByte()
43+
index++
44+
}
45+
return array
46+
}
47+
48+
override fun getDouble(index: Int): Double? = getString(index)?.toDouble()
49+
50+
override fun getLong(index: Int): Long? = getString(index)?.toLong()
51+
52+
override fun getString(index: Int): String? {
53+
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
54+
return if (isNull) {
55+
null
56+
} else {
57+
val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)
58+
value!!.toKString()
59+
}
60+
}
61+
62+
public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate()
63+
public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime()
64+
public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime()
65+
public fun getTimestamp(index: Int): Instant? = getString(index)?.let {
66+
Instant.parse(it.replace(" ", "T"))
67+
}
68+
69+
public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) }
70+
public fun getUUID(index: Int): UUID? = getString(index)?.toUUID()
71+
}

postgres-native-sqldelight-driver/src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt

Lines changed: 5 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ package app.softwork.sqldelight.postgresdriver
33
import app.cash.sqldelight.*
44
import app.cash.sqldelight.db.*
55
import kotlinx.cinterop.*
6-
import kotlinx.datetime.*
7-
import kotlinx.uuid.*
86
import libpq.*
9-
import kotlin.time.*
107

118
public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
129
private var transaction: Transacter.Transaction? = null
@@ -160,7 +157,7 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
160157
}
161158
}.check(conn)
162159

163-
val value = PostgresCursor.RealCursor(result, cursorName, conn, fetchSize).use(mapper)
160+
val value = RealCursor(result, cursorName, conn, fetchSize).use(mapper)
164161
return QueryResult.Value(value = value)
165162
}
166163

@@ -217,7 +214,7 @@ public class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDrive
217214
}
218215
}.check(conn)
219216

220-
val value = PostgresCursor.NoCursor(result).use(mapper)
217+
val value = NoCursor(result).use(mapper)
221218
return QueryResult.Value(value = value)
222219
}
223220

@@ -271,236 +268,24 @@ private fun CPointer<PGconn>?.error(): String {
271268
return errorMessage
272269
}
273270

274-
private fun CPointer<PGresult>?.clear() {
271+
internal fun CPointer<PGresult>?.clear() {
275272
PQclear(this)
276273
}
277274

278-
private fun CPointer<PGconn>.exec(sql: String) {
275+
internal fun CPointer<PGconn>.exec(sql: String) {
279276
val result = PQexec(this, sql)
280277
result.check(this)
281278
result.clear()
282279
}
283280

284-
private fun CPointer<PGresult>?.check(conn: CPointer<PGconn>): CPointer<PGresult> {
281+
internal fun CPointer<PGresult>?.check(conn: CPointer<PGconn>): CPointer<PGresult> {
285282
val status = PQresultStatus(this)
286283
check(status == PGRES_TUPLES_OK || status == PGRES_COMMAND_OK || status == PGRES_COPY_IN) {
287284
conn.error()
288285
}
289286
return this!!
290287
}
291288

292-
public sealed class PostgresCursor(
293-
internal var result: CPointer<PGresult>
294-
) : SqlCursor, Closeable {
295-
internal abstract val currentRowIndex: Int
296-
297-
/**
298-
* Must be inside a transaction!
299-
*/
300-
internal class RealCursor(
301-
result: CPointer<PGresult>,
302-
private val name: String,
303-
private val conn: CPointer<PGconn>,
304-
private val fetchSize: Int
305-
) : PostgresCursor(result) {
306-
override fun close() {
307-
result.clear()
308-
conn.exec("CLOSE $name")
309-
conn.exec("END")
310-
}
311-
312-
override var currentRowIndex = -1
313-
private var maxRowIndex = -1
314-
315-
override fun next(): Boolean {
316-
if (currentRowIndex == maxRowIndex) {
317-
currentRowIndex = -1
318-
}
319-
if (currentRowIndex == -1) {
320-
result = PQexec(conn, "FETCH $fetchSize IN $name").check(conn)
321-
maxRowIndex = PQntuples(result) - 1
322-
}
323-
return if (currentRowIndex < maxRowIndex) {
324-
currentRowIndex += 1
325-
true
326-
} else false
327-
}
328-
}
329-
330-
internal class NoCursor(
331-
result: CPointer<PGresult>
332-
) : PostgresCursor(result) {
333-
override fun close() {
334-
result.clear()
335-
}
336-
337-
private val maxRowIndex = PQntuples(result) - 1
338-
override var currentRowIndex = -1
339-
340-
override fun next(): Boolean {
341-
return if (currentRowIndex < maxRowIndex) {
342-
currentRowIndex += 1
343-
true
344-
} else {
345-
false
346-
}
347-
}
348-
}
349-
350-
override fun getBoolean(index: Int): Boolean? = getString(index)?.toBoolean()
351-
352-
override fun getBytes(index: Int): ByteArray? {
353-
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
354-
return if (isNull) {
355-
null
356-
} else {
357-
val bytes = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)!!
358-
val length = PQgetlength(result, tup_num = currentRowIndex, field_num = index)
359-
bytes.fromHex(length)
360-
}
361-
}
362-
363-
private inline fun Int.fromHex(): Int = if (this in 48..57) {
364-
this - 48
365-
} else {
366-
this - 87
367-
}
368-
369-
// because "normal" CPointer<ByteVar>.toByteArray() functions does not support hex (2 Bytes) bytes
370-
private fun CPointer<ByteVar>.fromHex(length: Int): ByteArray {
371-
val array = ByteArray((length - 2) / 2)
372-
var index = 0
373-
for (i in 2 until length step 2) {
374-
val first = this[i].toInt().fromHex()
375-
val second = this[i + 1].toInt().fromHex()
376-
val octet = first.shl(4).or(second)
377-
array[index] = octet.toByte()
378-
index++
379-
}
380-
return array
381-
}
382-
383-
override fun getDouble(index: Int): Double? = getString(index)?.toDouble()
384-
385-
override fun getLong(index: Int): Long? = getString(index)?.toLong()
386-
387-
override fun getString(index: Int): String? {
388-
val isNull = PQgetisnull(result, tup_num = currentRowIndex, field_num = index) == 1
389-
return if (isNull) {
390-
null
391-
} else {
392-
val value = PQgetvalue(result, tup_num = currentRowIndex, field_num = index)
393-
value!!.toKString()
394-
}
395-
}
396-
397-
public fun getDate(index: Int): LocalDate? = getString(index)?.toLocalDate()
398-
public fun getTime(index: Int): LocalTime? = getString(index)?.toLocalTime()
399-
public fun getLocalTimestamp(index: Int): LocalDateTime? = getString(index)?.replace(" ", "T")?.toLocalDateTime()
400-
public fun getTimestamp(index: Int): Instant? = getString(index)?.let {
401-
Instant.parse(it.replace(" ", "T"))
402-
}
403-
404-
public fun getInterval(index: Int): Duration? = getString(index)?.let { Duration.parseIsoString(it) }
405-
public fun getUUID(index: Int): UUID? = getString(index)?.toUUID()
406-
}
407-
408-
public class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement {
409-
internal fun values(scope: AutofreeScope): CValuesRef<CPointerVar<ByteVar>> = createValues(parameters) {
410-
value = when (val value = _values[it]) {
411-
null -> null
412-
is Data.Bytes -> value.bytes.refTo(0).getPointer(scope)
413-
is Data.Text -> value.text.cstr.getPointer(scope)
414-
}
415-
}
416-
417-
private sealed interface Data {
418-
value class Bytes(val bytes: ByteArray) : Data
419-
value class Text(val text: String) : Data
420-
}
421-
422-
private val _values = arrayOfNulls<Data>(parameters)
423-
internal val lengths = IntArray(parameters)
424-
internal val formats = IntArray(parameters)
425-
internal val types = UIntArray(parameters)
426-
427-
private fun bind(index: Int, value: String?, oid: UInt) {
428-
lengths[index] = if (value != null) {
429-
_values[index] = Data.Text(value)
430-
value.length
431-
} else 0
432-
formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT
433-
types[index] = oid
434-
}
435-
436-
override fun bindBoolean(index: Int, boolean: Boolean?) {
437-
bind(index, boolean?.toString(), boolOid)
438-
}
439-
440-
override fun bindBytes(index: Int, bytes: ByteArray?) {
441-
lengths[index] = if (bytes != null && bytes.isNotEmpty()) {
442-
_values[index] = Data.Bytes(bytes)
443-
bytes.size
444-
} else 0
445-
formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT
446-
types[index] = byteaOid
447-
}
448-
449-
override fun bindDouble(index: Int, double: Double?) {
450-
bind(index, double?.toString(), doubleOid)
451-
}
452-
453-
override fun bindLong(index: Int, long: Long?) {
454-
bind(index, long?.toString(), longOid)
455-
}
456-
457-
override fun bindString(index: Int, string: String?) {
458-
bind(index, string, textOid)
459-
}
460-
461-
public fun bindDate(index: Int, value: LocalDate?) {
462-
bind(index, value?.toString(), dateOid)
463-
}
464-
465-
466-
public fun bindTime(index: Int, value: LocalTime?) {
467-
bind(index, value?.toString(), timeOid)
468-
}
469-
470-
public fun bindLocalTimestamp(index: Int, value: LocalDateTime?) {
471-
bind(index, value?.toString(), timestampOid)
472-
}
473-
474-
public fun bindTimestamp(index: Int, value: Instant?) {
475-
bind(index, value?.toString(), timestampTzOid)
476-
}
477-
478-
public fun bindInterval(index: Int, value: Duration?) {
479-
bind(index, value?.toIsoString(), intervalOid)
480-
}
481-
482-
public fun bindUUID(index: Int, value: UUID?) {
483-
bind(index, value?.toString(), uuidOid)
484-
}
485-
486-
private companion object {
487-
// Hardcoded, because not provided in libpq-fe.h for unknown reasons...
488-
// select * from pg_type;
489-
private const val boolOid = 16u
490-
private const val byteaOid = 17u
491-
private const val longOid = 20u
492-
private const val textOid = 25u
493-
private const val doubleOid = 701u
494-
495-
private const val dateOid = 1082u
496-
private const val timeOid = 1083u
497-
private const val intervalOid = 1186u
498-
private const val timestampOid = 1114u
499-
private const val timestampTzOid = 1184u
500-
private const val uuidOid = 2950u
501-
}
502-
}
503-
504289
public fun PostgresNativeDriver(
505290
host: String, database: String, user: String, password: String, port: Int = 5432, options: String? = null
506291
): PostgresNativeDriver {

0 commit comments

Comments
 (0)