Skip to content

Commit 359a939

Browse files
authored
Fix binary type (#8)
Co-authored-by: hfhbd <[email protected]>
1 parent 6585411 commit 359a939

File tree

2 files changed

+147
-33
lines changed

2 files changed

+147
-33
lines changed

src/commonMain/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriver.kt

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
5656
paramValues = preparedStatement?.values(this),
5757
paramFormats = preparedStatement?.formats?.refTo(0),
5858
paramLengths = preparedStatement?.lengths?.refTo(0),
59-
resultFormat = TEXT_RESULT_FORMAT
59+
resultFormat = BINARY_RESULT_FORMAT
6060
)
6161
}
6262
} else {
@@ -68,7 +68,7 @@ class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
6868
paramValues = preparedStatement?.values(this),
6969
paramFormats = preparedStatement?.formats?.refTo(0),
7070
paramLengths = preparedStatement?.lengths?.refTo(0),
71-
resultFormat = TEXT_RESULT_FORMAT,
71+
resultFormat = BINARY_RESULT_FORMAT,
7272
paramTypes = preparedStatement?.types?.refTo(0)
7373
)
7474
}
@@ -85,7 +85,7 @@ class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
8585
parameters: Int,
8686
binders: (SqlPreparedStatement.() -> Unit)?
8787
): R {
88-
val cursorName = identifier?.toString() ?: "myCursor"
88+
val cursorName = if (identifier == null) "myCursor" else "cursor$identifier"
8989
val cursor = "DECLARE $cursorName CURSOR FOR"
9090
val preparedStatement = if (parameters != 0) {
9191
PostgresPreparedStatement(parameters).apply {
@@ -109,7 +109,7 @@ class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
109109
paramValues = preparedStatement?.values(this),
110110
paramLengths = preparedStatement?.lengths?.refTo(0),
111111
paramFormats = preparedStatement?.formats?.refTo(0),
112-
resultFormat = TEXT_RESULT_FORMAT
112+
resultFormat = BINARY_RESULT_FORMAT
113113
)
114114
}
115115
} else {
@@ -123,7 +123,7 @@ class PostgresNativeDriver(private var conn: CPointer<PGconn>) : SqlDriver {
123123
paramLengths = preparedStatement?.lengths?.refTo(0),
124124
paramFormats = preparedStatement?.formats?.refTo(0),
125125
paramTypes = preparedStatement?.types?.refTo(0),
126-
resultFormat = TEXT_RESULT_FORMAT
126+
resultFormat = BINARY_RESULT_FORMAT
127127
)
128128
}
129129
}.check(conn)
@@ -179,7 +179,7 @@ private fun CPointer<PGconn>.exec(sql: String) {
179179

180180
private fun CPointer<PGresult>?.check(conn: CPointer<PGconn>): CPointer<PGresult> {
181181
val status = PQresultStatus(this)
182-
require(status == PGRES_TUPLES_OK || status == PGRES_COMMAND_OK) {
182+
check(status == PGRES_TUPLES_OK || status == PGRES_COMMAND_OK) {
183183
conn.error()
184184
}
185185
return this!!
@@ -192,8 +192,7 @@ class PostgresCursor(
192192
private var result: CPointer<PGresult>,
193193
private val name: String,
194194
private val conn: CPointer<PGconn>
195-
) :
196-
SqlCursor, Closeable {
195+
) : SqlCursor, Closeable {
197196
override fun close() {
198197
result.clear()
199198
conn.exec("CLOSE $name")
@@ -202,7 +201,36 @@ class PostgresCursor(
202201

203202
override fun getBoolean(index: Int) = getString(index)?.toBoolean()
204203

205-
override fun getBytes(index: Int) = getString(index)?.encodeToByteArray()
204+
override fun getBytes(index: Int): ByteArray? {
205+
val isNull = PQgetisnull(result, tup_num = 0, field_num = index) == 1
206+
return if (isNull) {
207+
null
208+
} else {
209+
val bytes = PQgetvalue(result, tup_num = 0, field_num = index)!!
210+
val length = PQgetlength(result, tup_num = 0, field_num = index)
211+
bytes.fromHex(length)
212+
}
213+
}
214+
215+
private inline fun Int.fromHex(): Int = if (this in 48..57) {
216+
this - 48
217+
} else {
218+
this - 97
219+
}
220+
221+
// because "normal" CPointer<ByteVar>.toByteArray() functions does not support hex (2 Bytes) bytes
222+
private fun CPointer<ByteVar>.fromHex(length: Int): ByteArray {
223+
val array = ByteArray((length - 2) / 2)
224+
var index = 0
225+
for (i in 2 until length step 2) {
226+
val first = this[i].toInt().fromHex()
227+
val second = this[i + 1].toInt().fromHex()
228+
val octet = first.shl(4).or(second)
229+
array[index] = octet.toByte()
230+
index++
231+
}
232+
return array
233+
}
206234

207235
override fun getDouble(index: Int) = getString(index)?.toDouble()
208236

@@ -226,17 +254,26 @@ class PostgresCursor(
226254

227255
class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatement {
228256
fun values(scope: AutofreeScope): CValuesRef<CPointerVar<ByteVar>> = createValues(parameters) {
229-
value = _values[it]?.cstr?.getPointer(scope)
257+
value = when (val value = _values[it]) {
258+
null -> null
259+
is Data.Bytes -> value.bytes.refTo(0).getPointer(scope)
260+
is Data.Text -> value.text.cstr.getPointer(scope)
261+
}
230262
}
231263

232-
private val _values = arrayOfNulls<String>(parameters)
264+
private sealed interface Data {
265+
inline class Bytes(val bytes: ByteArray) : Data
266+
inline class Text(val text: String) : Data
267+
}
268+
269+
private val _values = arrayOfNulls<Data>(parameters)
233270
val lengths = IntArray(parameters)
234271
val formats = IntArray(parameters)
235272
val types = UIntArray(parameters)
236273

237274
private fun bind(index: Int, value: String?, oid: UInt) {
238275
lengths[index] = if (value != null) {
239-
_values[index] = value
276+
_values[index] = Data.Text(value)
240277
value.length
241278
} else 0
242279
formats[index] = PostgresNativeDriver.TEXT_RESULT_FORMAT
@@ -248,7 +285,12 @@ class PostgresPreparedStatement(private val parameters: Int) : SqlPreparedStatem
248285
}
249286

250287
override fun bindBytes(index: Int, bytes: ByteArray?) {
251-
bind(index, bytes?.decodeToString(), byteaOid)
288+
lengths[index] = if (bytes != null && bytes.isNotEmpty()) {
289+
_values[index] = Data.Bytes(bytes)
290+
bytes.size
291+
} else 0
292+
formats[index] = PostgresNativeDriver.BINARY_RESULT_FORMAT
293+
types[index] = byteaOid
252294
}
253295

254296
override fun bindDouble(index: Int, double: Double?) {
@@ -278,12 +320,12 @@ fun PostgresNativeDriver(
278320
database: String,
279321
user: String,
280322
password: String,
281-
port: Int? = null,
323+
port: Int = 5432,
282324
options: String? = null
283325
): PostgresNativeDriver {
284326
val conn = PQsetdbLogin(
285327
pghost = host,
286-
pgport = port?.toString(),
328+
pgport = port.toString(),
287329
pgtty = null,
288330
dbName = database,
289331
login = user,

src/commonTest/kotlin/app/softwork/sqldelight/postgresdriver/PostgresNativeDriverTest.kt

Lines changed: 90 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,109 @@ class PostgresNativeDriverTest {
1212
database = "postgres",
1313
password = "password"
1414
)
15-
driver.execute(null, "DROP TABLE IF EXISTS foo;", parameters = 0)
16-
driver.execute(null, "CREATE TABLE foo(a int primary key, bar text);", parameters = 0)
15+
assertEquals(0, driver.execute(null, "DROP TABLE IF EXISTS baz;", parameters = 0))
16+
assertEquals(0, driver.execute(null, "CREATE TABLE baz(a int primary key, foo text, b bytea);", parameters = 0))
1717
repeat(5) {
18-
driver.execute(null, "INSERT INTO foo VALUES ($it, 'a')", parameters = 0)
19-
}
20-
repeat(5) {
21-
driver.execute(null, "INSERT INTO foo VALUES ($1, $2)", parameters = 2) {
22-
bindLong(0, 5 + it.toLong())
23-
bindString(1, "bar $it")
24-
}
18+
val result = driver.execute(null, "INSERT INTO baz VALUES ($it)", parameters = 0)
19+
assertEquals(1, result)
2520
}
2621

27-
assertEquals(1, driver.execute(null, "SELECT * from foo limit 1;", parameters = 0))
28-
val s = driver.executeQuery(
29-
null,
30-
sql = "SELECT * FROM foo;",
22+
val result = driver.execute(null, "INSERT INTO baz VALUES ($1, $2, $3), ($4, $5, $6)", parameters = 6) {
23+
bindLong(0, 5)
24+
bindString(1, "bar 0")
25+
bindBytes(2, byteArrayOf(1.toByte(), 2.toByte()))
26+
27+
bindLong(3, 6)
28+
bindString(4, "bar 1")
29+
bindBytes(5, null)
30+
}
31+
assertEquals(2, result)
32+
val notPrepared = driver.executeQuery(null, "SELECT * from baz limit 1;", parameters = 0, mapper = {
33+
assertTrue(it.next())
34+
Simple(
35+
index = it.getLong(0)!!.toInt(),
36+
name = it.getString(1),
37+
byteArray = it.getBytes(2)
38+
)
39+
})
40+
assertEquals(Simple(0, null, null), notPrepared)
41+
val preparedStatement = driver.executeQuery(
42+
42,
43+
sql = "SELECT * FROM baz;",
3144
parameters = 0, binders = null,
3245
mapper = {
3346
buildList {
3447
while (it.next()) {
3548
add(
36-
S(
37-
index = it.getLong(0)!!,
38-
name = it.getString(1)!!
49+
Simple(
50+
index = it.getLong(0)!!.toInt(),
51+
name = it.getString(1),
52+
byteArray = it.getBytes(2)
3953
)
4054
)
4155
}
4256
}
4357
})
44-
assertEquals(10, s.size)
58+
assertEquals(
59+
List(5) {
60+
Simple(it, null, null)
61+
} + listOf(
62+
Simple(5, "bar 0", byteArrayOf(1.toByte(), 2.toByte())),
63+
Simple(6, "bar 1", null),
64+
),
65+
preparedStatement
66+
)
67+
}
68+
69+
data class Simple(val index: Int, val name: String?, val byteArray: ByteArray?) {
70+
override fun equals(other: Any?): Boolean {
71+
if (this === other) return true
72+
73+
other as Simple
74+
75+
if (index != other.index) return false
76+
if (name != other.name) return false
77+
if (byteArray != null) {
78+
if (other.byteArray == null) return false
79+
if (!byteArray.contentEquals(other.byteArray)) return false
80+
} else if (other.byteArray != null) return false
81+
82+
return true
83+
}
84+
85+
override fun hashCode(): Int {
86+
var result = index.hashCode()
87+
result = 31 * result + (name?.hashCode() ?: 0)
88+
result = 31 * result + (byteArray?.contentHashCode() ?: 0)
89+
return result
90+
}
4591
}
4692

47-
data class S(val index: Long, val name: String)
93+
@Test
94+
fun wrongCredentials() {
95+
assertFailsWith<IllegalArgumentException> {
96+
PostgresNativeDriver(
97+
host = "wrongHost",
98+
user = "postgres",
99+
database = "postgres",
100+
password = "password"
101+
)
102+
}
103+
assertFailsWith<IllegalArgumentException> {
104+
PostgresNativeDriver(
105+
host = "localhost",
106+
user = "postgres",
107+
database = "postgres",
108+
password = "wrongPassword"
109+
)
110+
}
111+
assertFailsWith<IllegalArgumentException> {
112+
PostgresNativeDriver(
113+
host = "localhost",
114+
user = "wrongUser",
115+
database = "postgres",
116+
password = "password"
117+
)
118+
}
119+
}
48120
}

0 commit comments

Comments
 (0)