Skip to content

Add query util function in kotlin. #1492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,21 @@
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import org.springframework.data.annotation.Transient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.beans.FeatureDescriptor;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.reactivestreams.Publisher;
import org.springframework.beans.BeansException;
Expand Down Expand Up @@ -161,6 +165,26 @@ public R2dbcEntityTemplate(DatabaseClient databaseClient, ReactiveDataAccessStra
this.projectionFactory = new SpelAwareProxyProjectionFactory();
}

public <T> String getColumnName(Class<T> tableType, Field field) {
Transient annotation = field.getDeclaredAnnotation(Transient.class);
if (annotation != null) {
throw new MappingException("This Column has Transient annotation");
}

org.springframework.data.relational.core.mapping.Column columnAnnotation =
field.getDeclaredAnnotation(org.springframework.data.relational.core.mapping.Column.class);
if (columnAnnotation != null) {
return columnAnnotation.value();
}

return StreamSupport.stream(Objects.requireNonNull(mappingContext.getPersistentEntity(tableType)).spliterator(), false)
.filter(column -> field.equals(column.getField()))
.findFirst()
.orElseThrow(() -> new MappingException("This Column name is not matched"))
.getColumnName()
.toString();
}

@Override
public DatabaseClient getDatabaseClient() {
return this.databaseClient;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.springframework.data.r2dbc.core

import org.springframework.data.mapping.MappingException
import kotlin.reflect.KProperty
import kotlin.reflect.jvm.javaField

inline fun <reified T : Any, C : Any> R2dbcEntityTemplate.column(property: KProperty<C>): String = property.javaField?.let {
this.getColumnName(T::class.java, it)
} ?: throw MappingException("property is not valid")
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.springframework.data.r2dbc.core.query

import org.springframework.data.relational.core.query.Criteria
import org.springframework.data.relational.core.query.Query

fun query(vararg args: Criteria?): Query = Query.query(
args.fold(Criteria.empty()) { acc, arg ->
arg?.let { acc.and(it) } ?: acc
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,12 @@
*/
package org.springframework.data.r2dbc.core

data class Person(val id: String)
import org.springframework.data.annotation.Id
import org.springframework.data.relational.core.mapping.Table

@Table("person")
data class Person(
@Id
val id: String,
val name: String,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package org.springframework.data.r2dbc.core.query

import io.r2dbc.spi.R2dbcType
import io.r2dbc.spi.test.MockColumnMetadata
import io.r2dbc.spi.test.MockResult
import io.r2dbc.spi.test.MockRow
import io.r2dbc.spi.test.MockRowMetadata
import org.assertj.core.api.Assertions
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.EnumSource
import org.springframework.data.r2dbc.core.*
import org.springframework.data.r2dbc.dialect.PostgresDialect
import org.springframework.data.r2dbc.testing.StatementRecorder
import org.springframework.data.relational.core.query.Criteria.where
import org.springframework.data.relational.core.query.isEqual
import org.springframework.r2dbc.core.DatabaseClient
import reactor.test.StepVerifier

/**
* Unit tests for [QueryExtensions].
*
* @author kihwankim
*/
class QueryExtensionsKtTest {
private lateinit var client: DatabaseClient
private lateinit var entityTemplate: R2dbcEntityTemplate
private lateinit var recorder: StatementRecorder

@BeforeEach
fun before() {
recorder = StatementRecorder.newInstance()
client = DatabaseClient.builder().connectionFactory(recorder)
.bindMarkers(PostgresDialect.INSTANCE.bindMarkersFactory).build()
entityTemplate = R2dbcEntityTemplate(client, DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE))
}

@ParameterizedTest // gh-1491
@EnumSource(QueryConditionType::class)
fun shouldSelectAll(queryCondition: QueryConditionType) {
val metadata = MockRowMetadata.builder()
.columnMetadata(MockColumnMetadata.builder().name("id").type(R2dbcType.INTEGER).build())
.columnMetadata(MockColumnMetadata.builder().name("name").type(R2dbcType.VARCHAR).build())
.build()
val result = queryCondition.nameValue?.let {
listOf(
MockResult.builder()
.row(
MockRow.builder()
.identified("id", Any::class.java, 1L)
.identified("name", Any::class.java, queryCondition.nameValue)
.metadata(metadata)
.build()
).build()
)
} ?: emptyList()

recorder.addStubbing({ s: String -> s.startsWith("SELECT") }, result)

entityTemplate.select<Person>()
.matching(
query(
queryCondition.nameValue?.run { where("name") isEqual queryCondition.nameValue }
)
)
.all()
.`as`(StepVerifier::create)
.expectNextCount(queryCondition.count)
.verifyComplete()
val statement = recorder.getCreatedStatement { s: String -> s.startsWith("SELECT") }
Assertions.assertThat(statement.sql)
.isEqualTo(queryCondition.resultSql)
}

enum class QueryConditionType(val nameValue: String?, val count: Long, val resultSql: String) {
FILTER("testName", 1L, "SELECT person.* FROM person WHERE (person.name = $1)"),
NON_FILTER(null, 0L, "SELECT person.* FROM person")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ package org.springframework.data.relational.core.query
infix fun Criteria.CriteriaStep.isEqual(value: Any): Criteria =
`is`(value)


/**
* Extension for [Criteria.CriteriaStep.not] providing a
* `isNotEqual(value)` variant.
*/
infix fun Criteria.CriteriaStep.isNotEqual(value: Any): Criteria =
not(value)

/**
* Extension for [Criteria.CriteriaStep.in] providing a
* `isIn(value)` variant.
Expand All @@ -44,3 +52,46 @@ fun Criteria.CriteriaStep.isIn(vararg value: Any): Criteria =
*/
fun Criteria.CriteriaStep.isIn(values: Collection<Any>): Criteria =
`in`(values)


/**
* Extension for [Criteria.CriteriaStep.lessThan] providing a
* `le(value)` variant.
*/
infix fun Criteria.CriteriaStep.le(value: Any): Criteria =
lessThan(value)

/**
* Extension for [Criteria.CriteriaStep.lessThanOrEquals] providing a
* `loe(value)` variant.
*/
infix fun Criteria.CriteriaStep.loe(value: Any): Criteria =
lessThanOrEquals(value)

/**
* Extension for [Criteria.CriteriaStep.greaterThan] providing a
* `ge(value)` variant.
*/
infix fun Criteria.CriteriaStep.ge(value: Any): Criteria =
greaterThan(value)

/**
* Extension for [Criteria.CriteriaStep.greaterThanOrEquals] providing a
* `goe(value)` variant.
*/
infix fun Criteria.CriteriaStep.goe(value: Any): Criteria =
greaterThanOrEquals(value)

/**
* Extension for [Criteria.CriteriaStep.like] providing a
* `like(value)` variant.
*/
infix fun Criteria.CriteriaStep.like(value: Any): Criteria =
like(value)

/**
* Extension for [Criteria.CriteriaStep.notLike] providing a
* `notLike(value)` variant.
*/
infix fun Criteria.CriteriaStep.notLike(value: Any): Criteria =
notLike(value)
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ class CriteriaStepExtensionsTests {
}
}

@Test // gh-1491
fun notEqIsCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.not("test") } returns criteria

assertThat(spec isNotEqual "test").isEqualTo(criteria)

verify {
spec.not("test")
}
}

@Test // DATAJDBC-522
fun inVarargCriteriaStep() {

Expand Down Expand Up @@ -72,4 +87,94 @@ class CriteriaStepExtensionsTests {
spec.`in`(listOf("test"))
}
}

@Test // gh-1491
fun leCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.lessThan(any()) } returns criteria

assertThat(spec le 10).isEqualTo(criteria)

verify {
spec.lessThan(10)
}
}

@Test // gh-1491
fun loeCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.lessThanOrEquals(10) } returns criteria

assertThat(spec loe 10).isEqualTo(criteria)

verify {
spec.lessThanOrEquals(10)
}
}

@Test
fun geCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.greaterThan(10) } returns criteria

assertThat(spec ge 10).isEqualTo(criteria)

verify {
spec.greaterThan(10)
}
}

@Test
fun goeCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.greaterThanOrEquals(10) } returns criteria

assertThat(spec goe 10).isEqualTo(criteria)

verify {
spec.greaterThanOrEquals(10)
}
}

@Test
fun likeCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.like("abc%") } returns criteria

assertThat(spec like "abc%").isEqualTo(criteria)

verify {
spec.like("abc%")
}
}

@Test
fun notLikeCriteriaStep() {

val spec = mockk<Criteria.CriteriaStep>()
val criteria = mockk<Criteria>()

every { spec.notLike("abc%") } returns criteria

assertThat(spec notLike "abc%").isEqualTo(criteria)

verify {
spec.notLike("abc%")
}
}
}