diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt index f3c35ac1f4..ac473dfdd1 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ColumnType.kt @@ -46,7 +46,7 @@ interface IColumnType { */ fun valueToString(value: Any?): String = when (value) { null -> { - check(nullable) { "NULL in non-nullable column" } + check(nullable) { "NULL in non-nullable column with type ${sqlType()}" } "NULL" } DefaultValueMarker -> "DEFAULT" diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt index ba715c178f..5c53e3f3df 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Expression.kt @@ -55,30 +55,59 @@ class QueryBuilder( when (argument) { is Expression<*> -> append(argument) DefaultValueMarker -> append(TransactionManager.current().db.dialect.dataTypeProvider.processForDefaultValue(column.dbDefaultValue!!)) - else -> registerArgument(column.columnType, argument) + else -> { + require(argument != null || column.columnType.nullable) { + "Column ${column.table.nameInDatabaseCase()}.${column.nameInDatabaseCase()} does not support NULLs" + } + @Suppress("DEPRECATION") + registerArgument(column.columnType, argument) + } + } + } + + /** Adds the specified [argument] as a value of the specified [expression]. */ + fun registerArgument(expression: ExpressionWithColumnType, argument: T) { + val sqlType = expression.columnType + require(argument != null || sqlType.nullable) { + "Can't register NULL value since expression has non-nullable type ${sqlType.sqlType()}, expression: $expression" } + @Suppress("DEPRECATION") + registerArgument(sqlType, argument) } /** Adds the specified [argument] as a value of the specified [sqlType]. */ - fun registerArgument(sqlType: IColumnType, argument: T): Unit = registerArguments(sqlType, listOf(argument)) + @Suppress("DeprecatedCallableAddReplaceWith") + @Deprecated( + level = DeprecationLevel.WARNING, + message = "Prefer registerArgument(Column, ...) and registerArgument(ExpressionWithColumnType, ...) since they have better error reporting" + ) + fun registerArgument(sqlType: IColumnType, argument: T) { + if (!prepared) { + +sqlType.valueToString(argument) + return + } + require(argument != null || sqlType.nullable) { + "Can't register NULL value for non-nullable type ${sqlType.sqlType()}" + } + +"?" + _args += sqlType to argument + } /** Adds the specified sequence of [arguments] as values of the specified [sqlType]. */ + @Deprecated( + message = "Replace with [SingleValueInListOp]", + level = DeprecationLevel.ERROR, + replaceWith = ReplaceWith("org.jetbrains.exposed.sql.ops.SingleValueInListOp") + ) fun registerArguments(sqlType: IColumnType, arguments: Iterable) { - fun toString(value: T) = when { - prepared && value is String -> value - else -> sqlType.valueToString(value) + if (!prepared) { + arguments.appendTo { +sqlType.valueToString(it) } + return + } + arguments.appendTo { + +"?" + _args += sqlType to it } - - arguments.map { it to toString(it) } - .sortedBy { it.second } - .appendTo { - if (prepared) { - _args.add(sqlType to it.first) - append("?") - } else { - append(it.second) - } - } } override fun toString(): String = internalBuilder.toString() diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt index 0972fc4c61..3279dbb428 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Op.kt @@ -1,6 +1,7 @@ package org.jetbrains.exposed.sql import org.jetbrains.exposed.dao.id.EntityID +import org.jetbrains.exposed.sql.ops.SingleValueInListOp import org.jetbrains.exposed.sql.vendors.OracleDialect import org.jetbrains.exposed.sql.vendors.SQLServerDialect import org.jetbrains.exposed.sql.vendors.currentDialect @@ -408,35 +409,10 @@ class InListOrNotInListOp( /** Returns `true` if the check is inverted, `false` otherwise. */ val isInList: Boolean = true ) : Op(), ComplexExpression { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { - list.iterator().let { i -> - if (!i.hasNext()) { - if (isInList) { - +FALSE - } else { - +TRUE - } - } else { - val first = i.next() - if (!i.hasNext()) { - append(expr) - when { - isInList -> append(" = ") - else -> append(" != ") - } - registerArgument(expr.columnType, first) - } else { - append(expr) - when { - isInList -> append(" IN (") - else -> append(" NOT IN (") - } - registerArguments(expr.columnType, list) - append(")") - } - } - } - } + private val impl = SingleValueInListOp(expr, list, isInList) + + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = + impl.toQueryBuilder(queryBuilder) } // Literals @@ -449,6 +425,11 @@ class LiteralOp( /** Returns the value being used as a literal. */ val value: T ) : ExpressionWithColumnType() { + init { + require(value != null || columnType.nullable) { + "Can't create NULL literal for non-nullable type $columnType" + } + } override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { +columnType.valueToString(value) } } @@ -506,7 +487,15 @@ class QueryParameter( /** Returns the column type of this expression. */ val sqlType: IColumnType ) : Expression() { - override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { registerArgument(sqlType, value) } + init { + require(value != null || sqlType.nullable) { + "Can't create NULL query parameter for non-nullable type $sqlType" + } + } + override fun toQueryBuilder(queryBuilder: QueryBuilder): Unit = queryBuilder { + @Suppress("DEPRECATION") + registerArgument(sqlType, value) + } } /** Returns the specified [value] as a query parameter with the same type as [column]. */ diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/InListOps.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/InListOps.kt index dd2e17c0e5..32af7cf623 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/InListOps.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/ops/InListOps.kt @@ -59,12 +59,12 @@ abstract class InListOrNotInListBaseOp ( private fun QueryBuilder.registerValues(values: List) { val singleColumn = columnTypes.singleOrNull() if (singleColumn != null) - registerArgument(singleColumn.columnType, values.single()) + registerArgument(singleColumn as ExpressionWithColumnType, values.single()) else { append("(") columnTypes.forEachIndexed { index, columnExpression -> if (index != 0) append(", ") - registerArgument(columnExpression.columnType, values[index]) + registerArgument(columnExpression as ExpressionWithColumnType, values[index]) } append(")") } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt index 8a1e34b2a3..bf09034e30 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateBuilder.kt @@ -3,7 +3,6 @@ package org.jetbrains.exposed.sql.statements import org.jetbrains.exposed.dao.id.EntityID import org.jetbrains.exposed.sql.* -import java.util.* /** * @author max @@ -13,28 +12,28 @@ abstract class UpdateBuilder(type: StatementType, targets: List) : protected val values: MutableMap, Any?> = LinkedHashMap() open operator fun set(column: Column, value: S) { - when { - values.containsKey(column) -> error("$column is already initialized") - !column.columnType.nullable && value == null -> error("Trying to set null to not nullable column $column") - else -> { - column.columnType.validateValueBeforeUpdate(value) - values[column] = value - } + require(!values.containsKey(column)) { "$column is already initialized" } + column.validateValueBeforeUpdate(value) + values[column] = value + } + + private fun Column<*>.validateValueBeforeUpdate(value: Any?) { + require(columnType.nullable || value != null && !(value is LiteralOp<*> && value.value == null)) { + "Can't set NULL into non-nullable column ${table.tableName}.$name, column type is $columnType" } + columnType.validateValueBeforeUpdate(value) } @JvmName("setWithEntityIdExpression") operator fun > set(column: Column?>, value: Expression) { - require(!values.containsKey(column)) { "$column is already initialized" } - column.columnType.validateValueBeforeUpdate(value) - values[column] = value + @Suppress("UNCHECKED_CAST") + set(column as Column, value as Any?) } @JvmName("setWithEntityIdValue") operator fun > set(column: Column?>, value: S?) { - require(!values.containsKey(column)) { "$column is already initialized" } - column.columnType.validateValueBeforeUpdate(value) - values[column] = value + @Suppress("UNCHECKED_CAST") + set(column as Column, value as Any?) } /** @@ -47,19 +46,22 @@ abstract class UpdateBuilder(type: StatementType, targets: List
) : open operator fun set(column: Column, value: Expression) = update(column, value) open operator fun set(column: CompositeColumn, value: S) { - column.getRealColumnsWithValues(value).forEach { (realColumn, itsValue) -> set(realColumn as Column, itsValue) } + column.getRealColumnsWithValues(value).forEach { (realColumn, itsValue) -> + @Suppress("UNCHECKED_CAST") + set(realColumn as Column, itsValue) + } } open fun update(column: Column, value: Expression) { - require(!values.containsKey(column)) { "$column is already initialized" } - column.columnType.validateValueBeforeUpdate(value) - values[column] = value + @Suppress("UNCHECKED_CAST") + set(column as Column, value as Any?) } open fun update(column: Column, value: SqlExpressionBuilder.() -> Expression) { - require(!values.containsKey(column)) { "$column is already initialized" } - val expression = SqlExpressionBuilder.value() - column.columnType.validateValueBeforeUpdate(expression) - values[column] = expression + // Note: the implementation builds value before it verifies if the column is already initialized + // however it makes the implementation easier and the optimization is not that important since + // the exceptional case should be rare. + @Suppress("UNCHECKED_CAST") + set(column as Column, SqlExpressionBuilder.value() as Any?) } } diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/Mysql.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/Mysql.kt index 4d211be92c..9d5c0d94f9 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/Mysql.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/vendors/Mysql.kt @@ -62,7 +62,7 @@ internal open class MysqlFunctionProvider : FunctionProvider() { override fun replace(table: Table, data: List, Any?>>, transaction: Transaction): String { val builder = QueryBuilder(true) val columns = data.joinToString { transaction.identity(it.first) } - val values = builder.apply { data.appendTo { registerArgument(it.first.columnType, it.second) } }.toString() + val values = builder.apply { data.appendTo { registerArgument(it.first, it.second) } }.toString() return "REPLACE INTO ${transaction.identity(table)} ($columns) VALUES ($values)" } diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/InsertTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/InsertTests.kt index 32c57839ab..5c158e7712 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/InsertTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/InsertTests.kt @@ -17,7 +17,9 @@ import org.jetbrains.exposed.sql.tests.shared.expectException import org.jetbrains.exposed.sql.vendors.MysqlDialect import org.junit.Test import java.math.BigDecimal +import kotlin.test.assertContains import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotNull class InsertTests : DatabaseTestsBase() { @@ -201,6 +203,63 @@ class InsertTests : DatabaseTestsBase() { } } + @Test + fun testInsertNullIntoNonNullableColumn() { + val cities = object : IntIdTable("cities") { + } + val users = object : IntIdTable("users") { + val cityId = reference("city_id", cities) + } + + withTables(users, cities) { + // This is needed so valid inserts to users to succeed + cities.insert { + it[id] = 42 + } + users.insert { + // The assertion would try inserting null, and it ensures the insert would fail before the statement is even generated + it.assertInsertNullFails(cityId) + // This is needed for insert statement to succeed + it[cityId] = 42 + } + } + } + + private fun > UpdateBuilder.assertInsertNullFails(column: Column>) { + fun assertInsertNullFails(column: Column>, block: () -> Unit) { + val e = assertFailsWith( + """ + Unfortunately, type system can't protect from inserting null here + since the setter is declared as set(column: Column?>, value: S?), + and there's no way to tell that nullness of both arguments should match, so expecting it[${column.name}] = null + to fail at runtime + """.trimIndent() + ) { + block() + } + val message = e.toString() + assertContains( + message, + "${column.table.tableName}.${column.name}", ignoreCase = true, + "Exception message should contain table and column name" + ) + assertContains(message, column.columnType.toString(), ignoreCase = true, "Exception message should contain column type") + } + + require(!column.columnType.nullable) { + "Assertion works for non-nullable columns only. Given column ${column.table.tableName}.${column.name} is nullable ${column.columnType}" + } + assertInsertNullFails(column) { + // This is written explicitly to demonstrate that the code compiles, yet it fails in the runtime + // This call resolves to set(column: Column?>, value: S?) + this[column] = null + } + val nullableType = EntityIDColumnType(column).apply { nullable = true } + assertInsertNullFails(column) { + this[column] = LiteralOp(nullableType, null) + } + } + @Test fun testInsertWithPredefinedId() { val stringTable = object : IdTable("stringTable") { override val id = varchar("id", 15).entityId()