diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/CompositeColumn.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/CompositeColumn.kt index 96f07229ba..5390cf3d25 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/CompositeColumn.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/CompositeColumn.kt @@ -35,8 +35,8 @@ abstract class CompositeColumn : Expression() { abstract class BiCompositeColumn( protected val column1: Column, protected val column2: Column, - val transformFromValue: (T) -> Pair, - val transformToValue: (Any?, Any?) -> T + val transformToValue: (C1, C2) -> T, + val transformFromValue: (T) -> Pair ) : CompositeColumn() { override fun getRealColumns(): List> = listOf(column1, column2) @@ -50,12 +50,11 @@ abstract class BiCompositeColumn( } override fun restoreValueFromParts(parts: Map, Any?>): T { - val v1 = parts[column1] - val v2 = parts[column2] - val result = transformToValue(v1, v2) - check(result != null || nullable) { + @Suppress("UNCHECKED_CAST") + val result = transformToValue(parts[column1] as C1, parts[column2] as C2) + require(result != null || nullable) { "Null value received from DB for non-nullable ${this::class.simpleName} column" } - return result as T + return result } } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ResultRow.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ResultRow.kt index f1ad570acb..05b4ae7281 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ResultRow.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ResultRow.kt @@ -50,17 +50,21 @@ class ResultRow(val fieldIndex: Map, Int>) { return when { raw == null -> null raw == NotInitializedValue -> error("$c is not initialized yet") + c is CompositeColumn -> c.restoreValueFromParts( + (raw as Map, Any?>).mapValues { (column, rawValue) -> + rawToColumnValue(rawValue, column as Column) + }) c is ExpressionAlias && c.delegate is ExpressionWithColumnType -> c.delegate.columnType.valueFromDB(raw) c is ExpressionWithColumnType -> c.columnType.valueFromDB(raw) else -> raw } as T } - @Suppress("UNCHECKED_CAST") private fun getRaw(c: Expression): T? { if (c is CompositeColumn) { val rawParts = c.getRealColumns().associateWith { getRaw(it) } - return c.restoreValueFromParts(rawParts) + @Suppress("UNCHECKED_CAST") + return rawParts as T? } val index = fieldIndex[c] @@ -70,6 +74,7 @@ class ResultRow(val fieldIndex: Map, Int>) { }?.let { fieldIndex[it] } ?: error("$c is not in record set") + @Suppress("UNCHECKED_CAST") return data[index] as T? } diff --git a/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumn.kt b/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumn.kt index ccc7315533..67416d0910 100644 --- a/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumn.kt +++ b/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumn.kt @@ -20,22 +20,18 @@ class CompositeMoneyColumn - val amountValue = money?.number?.numberValue(BigDecimal::class.java) as? T1 - val currencyValue = money?.currency as? T2 + val amountValue = money?.number?.numberValue(BigDecimal::class.java) as T1 + val currencyValue = money?.currency as T2 amountValue to currencyValue }, transformToValue = { amountVal, currencyVal -> if (amountVal == null || currencyVal == null) { null as R } else { - val result = Monetary.getDefaultAmountFactory().setNumber(amountVal as Number) - - when (currencyVal) { - is CurrencyUnit -> result.setCurrency(currencyVal) - is String -> result.setCurrency(currencyVal) - } - - result.create() as R + Monetary.getDefaultAmountFactory().run { + setNumber(amountVal) + setCurrency(currencyVal) + }.create() as R } } ) diff --git a/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumnType.kt b/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumnType.kt index fef945653a..afa765e15e 100644 --- a/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumnType.kt +++ b/exposed-money/src/main/kotlin/org/jetbrains/exposed/sql/money/CompositeMoneyColumnType.kt @@ -1,6 +1,7 @@ package org.jetbrains.exposed.sql.money import org.jetbrains.exposed.sql.Column +import org.jetbrains.exposed.sql.CompositeColumn import org.jetbrains.exposed.sql.Table import java.math.BigDecimal import javax.money.CurrencyUnit @@ -23,5 +24,7 @@ fun Table.compositeMoney(amountColumn: Column, currencyColumn: Colu if (amountColumn !in columns && currencyColumn !in columns) { registerCompositeColumn(it) } + // Set CompositeColumn.nullable = true + (it as CompositeColumn).nullable() } } diff --git a/exposed-money/src/test/kotlin/org/jetbrains/exposed/sql/money/MoneyDefaultsTest.kt b/exposed-money/src/test/kotlin/org/jetbrains/exposed/sql/money/MoneyDefaultsTest.kt index 2682e5f1c8..47c1d225a6 100644 --- a/exposed-money/src/test/kotlin/org/jetbrains/exposed/sql/money/MoneyDefaultsTest.kt +++ b/exposed-money/src/test/kotlin/org/jetbrains/exposed/sql/money/MoneyDefaultsTest.kt @@ -22,6 +22,11 @@ class MoneyDefaultsTest : DatabaseTestsBase() { val field = varchar("field", 100) val t1 = compositeMoney(10, 0, "t1").default(defaultValue) val t2 = compositeMoney(10, 0, "t2").nullable() + + val price_amount = decimal("price_amount", 10, 0).nullable() + val price_currency = currency("price_currency").nullable() + val price = compositeMoney(price_amount, price_currency) /* it is implicitly nullable since price_amount is nullable */ + val clientDefault = integer("clientDefault").clientDefault { cIndex++ } } @@ -29,6 +34,7 @@ class MoneyDefaultsTest : DatabaseTestsBase() { var field by TableWithDBDefault.field var t1 by TableWithDBDefault.t1 var t2 by TableWithDBDefault.t2 + var price by TableWithDBDefault.price val clientDefault by TableWithDBDefault.clientDefault override fun hashCode(): Int = id.value.hashCode() @@ -36,11 +42,7 @@ class MoneyDefaultsTest : DatabaseTestsBase() { override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is DBDefault) return false - if (other.t1 != other.t1) return false - if (other.t2 != other.t2) return false - if (other.clientDefault != other.clientDefault) return false - - return true + return id.value == other.id.value } companion object : IntEntityClass(TableWithDBDefault) @@ -94,4 +96,19 @@ class MoneyDefaultsTest : DatabaseTestsBase() { assertEquals(TableWithDBDefault.defaultValue, db1.t1) } } + + @Test + fun testImplicitlyNullableCompositeColumnType() { + withTables(TableWithDBDefault) { + TableWithDBDefault.cIndex = 0 + val db1 = DBDefault.new { field = "1" } + flushCache() + assertNull(db1.price, "db1.price should be null since it was not set when calling new") + val money = Money.of(BigDecimal.ONE, "USD") + db1.price = money + db1.refresh(flush = true) + assertEquals(money, db1.t1) + assertEquals(TableWithDBDefault.defaultValue, db1.t1) + } + } }