Skip to content

Commit

Permalink
Refine BiCompositeColumn generics
Browse files Browse the repository at this point in the history
  • Loading branch information
vlsi committed Jun 19, 2021
1 parent c9c87c3 commit 858c54b
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ abstract class CompositeColumn<T> : Expression<T>() {
abstract class BiCompositeColumn<C1, C2, T>(
protected val column1: Column<C1>,
protected val column2: Column<C2>,
val transformFromValue: (T) -> Pair<C1?, C2?>,
val transformToValue: (Any?, Any?) -> T
val transformToValue: (C1, C2) -> T,
val transformFromValue: (T) -> Pair<C1, C2>
) : CompositeColumn<T>() {

override fun getRealColumns(): List<Column<*>> = listOf(column1, column2)
Expand All @@ -50,12 +50,11 @@ abstract class BiCompositeColumn<C1, C2, T>(
}

override fun restoreValueFromParts(parts: Map<Column<*>, Any?>): T {
val v1 = parts[column1]
val v2 = parts[column2]
val result = transformToValue(v1, v2)
check(result != null || nullable) {
@Suppress("UNCHECKED_CAST")
val result = transformToValue(parts[column1] as C1, parts[column2] as C2)
require(result != null || nullable) {
"Null value received from DB for non-nullable ${this::class.simpleName} column"
}
return result as T
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,21 @@ class ResultRow(val fieldIndex: Map<Expression<*>, Int>) {
return when {
raw == null -> null
raw == NotInitializedValue -> error("$c is not initialized yet")
c is CompositeColumn<T> -> c.restoreValueFromParts(
(raw as Map<Column<*>, Any?>).mapValues { (column, rawValue) ->
rawToColumnValue(rawValue, column as Column<Any?>)
})
c is ExpressionAlias<T> && c.delegate is ExpressionWithColumnType<T> -> c.delegate.columnType.valueFromDB(raw)
c is ExpressionWithColumnType<T> -> c.columnType.valueFromDB(raw)
else -> raw
} as T
}

@Suppress("UNCHECKED_CAST")
private fun <T> getRaw(c: Expression<T>): T? {
if (c is CompositeColumn<T>) {
val rawParts = c.getRealColumns().associateWith { getRaw(it) }
return c.restoreValueFromParts(rawParts)
@Suppress("UNCHECKED_CAST")
return rawParts as T?
}

val index = fieldIndex[c]
Expand All @@ -70,6 +74,7 @@ class ResultRow(val fieldIndex: Map<Expression<*>, Int>) {
}?.let { fieldIndex[it] }
?: error("$c is not in record set")

@Suppress("UNCHECKED_CAST")
return data[index] as T?
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,18 @@ class CompositeMoneyColumn<T1 : BigDecimal?, T2 : CurrencyUnit?, R : MonetaryAmo
column1 = amount,
column2 = currency,
transformFromValue = { money ->
val amountValue = money?.number?.numberValue(BigDecimal::class.java) as? T1
val currencyValue = money?.currency as? T2
val amountValue = money?.number?.numberValue(BigDecimal::class.java) as T1
val currencyValue = money?.currency as T2
amountValue to currencyValue
},
transformToValue = { amountVal, currencyVal ->
if (amountVal == null || currencyVal == null) {
null as R
} else {
val result = Monetary.getDefaultAmountFactory().setNumber(amountVal as Number)

when (currencyVal) {
is CurrencyUnit -> result.setCurrency(currencyVal)
is String -> result.setCurrency(currencyVal)
}

result.create() as R
Monetary.getDefaultAmountFactory().run {
setNumber(amountVal)
setCurrency(currencyVal)
}.create() as R
}
}
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.exposed.sql.money

import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.CompositeColumn
import org.jetbrains.exposed.sql.Table
import java.math.BigDecimal
import javax.money.CurrencyUnit
Expand All @@ -23,5 +24,7 @@ fun Table.compositeMoney(amountColumn: Column<BigDecimal?>, currencyColumn: Colu
if (amountColumn !in columns && currencyColumn !in columns) {
registerCompositeColumn(it)
}
// Set CompositeColumn.nullable = true
(it as CompositeColumn<MonetaryAmount>).nullable()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,27 @@ class MoneyDefaultsTest : DatabaseTestsBase() {
val field = varchar("field", 100)
val t1 = compositeMoney(10, 0, "t1").default(defaultValue)
val t2 = compositeMoney(10, 0, "t2").nullable()

val price_amount = decimal("price_amount", 10, 0).nullable()
val price_currency = currency("price_currency").nullable()
val price = compositeMoney(price_amount, price_currency) /* it is implicitly nullable since price_amount is nullable */

val clientDefault = integer("clientDefault").clientDefault { cIndex++ }
}

class DBDefault(id: EntityID<Int>) : IntEntity(id) {
var field by TableWithDBDefault.field
var t1 by TableWithDBDefault.t1
var t2 by TableWithDBDefault.t2
var price by TableWithDBDefault.price
val clientDefault by TableWithDBDefault.clientDefault

override fun hashCode(): Int = id.value.hashCode()

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DBDefault) return false
if (other.t1 != other.t1) return false
if (other.t2 != other.t2) return false
if (other.clientDefault != other.clientDefault) return false

return true
return id.value == other.id.value
}

companion object : IntEntityClass<DBDefault>(TableWithDBDefault)
Expand Down Expand Up @@ -94,4 +96,19 @@ class MoneyDefaultsTest : DatabaseTestsBase() {
assertEquals(TableWithDBDefault.defaultValue, db1.t1)
}
}

@Test
fun testImplicitlyNullableCompositeColumnType() {
withTables(TableWithDBDefault) {
TableWithDBDefault.cIndex = 0
val db1 = DBDefault.new { field = "1" }
flushCache()
assertNull(db1.price, "db1.price should be null since it was not set when calling new")
val money = Money.of(BigDecimal.ONE, "USD")
db1.price = money
db1.refresh(flush = true)
assertEquals(money, db1.t1)
assertEquals(TableWithDBDefault.defaultValue, db1.t1)
}
}
}

0 comments on commit 858c54b

Please sign in to comment.