|
| 1 | +package app.softwork.sqldelight.postgresdialect |
| 2 | + |
| 3 | +import app.cash.sqldelight.dialect.api.* |
| 4 | +import app.cash.sqldelight.dialects.postgresql.* |
| 5 | +import app.cash.sqldelight.dialects.postgresql.grammar.psi.* |
| 6 | +import com.alecstrong.sql.psi.core.psi.* |
| 7 | +import com.squareup.kotlinpoet.* |
| 8 | +import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy |
| 9 | + |
| 10 | +class PostgresNativeDialect : PostgreSqlDialect() { |
| 11 | + override val runtimeTypes = RuntimeTypes( |
| 12 | + driverType = ClassName("app.softwork.sqldelight.postgresdriver", "PostgresNativeDriver"), |
| 13 | + cursorType = ClassName("app.softwork.sqldelight.postgresdriver", "PostgresCursor"), |
| 14 | + preparedStatementType = ClassName("app.softwork.sqldelight.postgresdriver", "PostgresPreparedStatement") |
| 15 | + ) |
| 16 | + |
| 17 | + override fun typeResolver(parentResolver: TypeResolver): TypeResolver = PostgresNativeTypeResolver(parentResolver) |
| 18 | + |
| 19 | + class PostgresNativeTypeResolver(parentResolver: TypeResolver) : PostgreSqlTypeResolver(parentResolver) { |
| 20 | + override fun definitionType(typeName: SqlTypeName): IntermediateType = with(typeName) { |
| 21 | + check(this is PostgreSqlTypeName) |
| 22 | + val type = IntermediateType( |
| 23 | + when { |
| 24 | + smallIntDataType != null -> PostgreSqlType.SMALL_INT |
| 25 | + intDataType != null -> PostgreSqlType.INTEGER |
| 26 | + bigIntDataType != null -> PostgreSqlType.BIG_INT |
| 27 | + approximateNumericDataType != null -> PrimitiveType.REAL |
| 28 | + stringDataType != null -> PrimitiveType.TEXT |
| 29 | + uuidDataType != null -> PostgreSqlType.UUID |
| 30 | + smallSerialDataType != null -> PostgreSqlType.SMALL_INT |
| 31 | + serialDataType != null -> PostgreSqlType.INTEGER |
| 32 | + bigSerialDataType != null -> PostgreSqlType.BIG_INT |
| 33 | + dateDataType != null -> { |
| 34 | + when (dateDataType!!.firstChild.text) { |
| 35 | + "DATE" -> PostgreSqlType.DATE |
| 36 | + //"TIME" -> PostgreSqlType.TIME |
| 37 | + "TIMESTAMP" -> if (dateDataType!!.node.getChildren(null) |
| 38 | + .any { it.text == "WITH" } |
| 39 | + ) PostgreSqlType.TIMESTAMP_TIMEZONE else PostgreSqlType.TIMESTAMP |
| 40 | + "TIMESTAMPTZ" -> PostgreSqlType.TIMESTAMP_TIMEZONE |
| 41 | + "INTERVAL" -> PostgreSqlType.INTERVAL |
| 42 | + else -> throw IllegalArgumentException("Unknown date type ${dateDataType!!.text}") |
| 43 | + } |
| 44 | + } |
| 45 | + jsonDataType != null -> PrimitiveType.TEXT |
| 46 | + booleanDataType != null -> PrimitiveType.BOOLEAN |
| 47 | + blobDataType != null -> PrimitiveType.BLOB |
| 48 | + else -> throw IllegalArgumentException("Unknown kotlin type for sql type ${this.text}") |
| 49 | + } |
| 50 | + ) |
| 51 | + if (node.getChildren(null).map { it.text }.takeLast(2) == listOf("[", "]")) { |
| 52 | + return IntermediateType(object : DialectType { |
| 53 | + override val javaType = Array::class.asTypeName().parameterizedBy(type.javaType) |
| 54 | + |
| 55 | + override fun prepareStatementBinder(columnIndex: String, value: CodeBlock) = |
| 56 | + CodeBlock.of("bindArray($columnIndex, %L)\n", value) |
| 57 | + |
| 58 | + override fun cursorGetter(columnIndex: Int, cursorName: String) = |
| 59 | + CodeBlock.of("$cursorName.getArray($columnIndex)") |
| 60 | + }) |
| 61 | + } |
| 62 | + return type |
| 63 | + } |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +internal enum class PostgreSqlType(override val javaType: TypeName): DialectType { |
| 68 | + SMALL_INT(SHORT) { |
| 69 | + override fun decode(value: CodeBlock) = CodeBlock.of("%L.toShort()", value) |
| 70 | + |
| 71 | + override fun encode(value: CodeBlock) = CodeBlock.of("%L.toLong()", value) |
| 72 | + }, |
| 73 | + INTEGER(INT) { |
| 74 | + override fun decode(value: CodeBlock) = CodeBlock.of("%L.toInt()", value) |
| 75 | + |
| 76 | + override fun encode(value: CodeBlock) = CodeBlock.of("%L.toLong()", value) |
| 77 | + }, |
| 78 | + BIG_INT(LONG), |
| 79 | + DATE(ClassName("kotlinx.datetime", "LocalDate")), |
| 80 | + //TIME(kotlinx.datetime.LocalTime::class.asTypeName()), |
| 81 | + TIMESTAMP(ClassName("kotlinx.datetime", "LocalDateTime")), |
| 82 | + TIMESTAMP_TIMEZONE(ClassName("kotlinx.datetime", "Instant")), |
| 83 | + INTERVAL(ClassName("kotlin.time", "Duration")), |
| 84 | + UUID(ClassName("kotlinx.uuid", "UUID")); |
| 85 | + |
| 86 | + override fun prepareStatementBinder(columnIndex: String, value: CodeBlock): CodeBlock { |
| 87 | + return CodeBlock.builder() |
| 88 | + .add( |
| 89 | + when (this) { |
| 90 | + SMALL_INT, INTEGER, BIG_INT -> "bindLong" |
| 91 | + DATE -> "bindDate" |
| 92 | + //TIME -> "bindTime" |
| 93 | + TIMESTAMP -> "bindLocalTimestamp" |
| 94 | + TIMESTAMP_TIMEZONE -> "bindTimestamp" |
| 95 | + INTERVAL -> "bindInterval" |
| 96 | + UUID -> "bindUUID" |
| 97 | + } |
| 98 | + ) |
| 99 | + .add("($columnIndex, %L)\n", value) |
| 100 | + .build() |
| 101 | + } |
| 102 | + |
| 103 | + override fun cursorGetter(columnIndex: Int, cursorName: String): CodeBlock { |
| 104 | + return CodeBlock.of( |
| 105 | + when (this) { |
| 106 | + SMALL_INT, INTEGER, BIG_INT -> "$cursorName.getLong($columnIndex)" |
| 107 | + DATE -> "$cursorName.getDate($columnIndex)" |
| 108 | + //TIME -> "$cursorName.getTime($columnIndex)" |
| 109 | + TIMESTAMP -> "$cursorName.getLocalTimestamp($columnIndex)" |
| 110 | + TIMESTAMP_TIMEZONE -> "$cursorName.getTimestamp($columnIndex)" |
| 111 | + INTERVAL -> "$cursorName.getInterval($columnIndex)" |
| 112 | + UUID -> "$cursorName.getUUID($columnIndex)" |
| 113 | + } |
| 114 | + ) |
| 115 | + } |
| 116 | +} |
0 commit comments