From f089ba9ff222b3ceed7138858c8e6595a00d6ae2 Mon Sep 17 00:00:00 2001 From: Swastik Baranwal Date: Mon, 23 Dec 2024 19:49:27 +0530 Subject: [PATCH] cgen, checker: allow using smartcasted sumtype variant values in the ORM queries (fix #23239) (#23241) --- cmd/tools/vtest-self.v | 4 ++++ vlib/orm/orm_sum_type_insert_test.v | 26 ++++++++++++++++++++++++++ vlib/v/checker/orm.v | 3 ++- vlib/v/gen/c/orm.v | 7 +++++++ 4 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 vlib/orm/orm_sum_type_insert_test.v diff --git a/cmd/tools/vtest-self.v b/cmd/tools/vtest-self.v index 45fdd36b9bdcce..3696c16afe6d71 100644 --- a/cmd/tools/vtest-self.v +++ b/cmd/tools/vtest-self.v @@ -152,6 +152,7 @@ const skip_with_fsanitize_memory = [ 'vlib/orm/orm_create_and_drop_test.v', 'vlib/orm/orm_insert_test.v', 'vlib/orm/orm_insert_reserved_name_test.v', + 'vlib/orm/orm_sum_type_insert_test.v', 'vlib/orm/orm_fn_calls_test.v', 'vlib/orm/orm_last_id_test.v', 'vlib/orm/orm_string_interpolation_in_where_test.v', @@ -199,6 +200,7 @@ const skip_with_fsanitize_address = [ 'vlib/orm/orm_create_and_drop_test.v', 'vlib/orm/orm_insert_test.v', 'vlib/orm/orm_insert_reserved_name_test.v', + 'vlib/orm/orm_sum_type_insert_test.v', 'vlib/orm/orm_references_test.v', 'vlib/v/tests/websocket_logger_interface_should_compile_test.v', 'vlib/v/tests/orm_enum_test.v', @@ -212,6 +214,7 @@ const skip_with_fsanitize_undefined = [ 'vlib/orm/orm_create_and_drop_test.v', 'vlib/orm/orm_insert_test.v', 'vlib/orm/orm_insert_reserved_name_test.v', + 'vlib/orm/orm_sum_type_insert_test.v', 'vlib/orm/orm_references_test.v', 'vlib/v/tests/orm_enum_test.v', 'vlib/v/tests/orm_sub_array_struct_test.v', @@ -254,6 +257,7 @@ const skip_on_ubuntu_musl = [ 'vlib/orm/orm_create_and_drop_test.v', 'vlib/orm/orm_insert_test.v', 'vlib/orm/orm_insert_reserved_name_test.v', + 'vlib/orm/orm_sum_type_insert_test.v', 'vlib/orm/orm_fn_calls_test.v', 'vlib/orm/orm_null_test.v', 'vlib/orm/orm_last_id_test.v', diff --git a/vlib/orm/orm_sum_type_insert_test.v b/vlib/orm/orm_sum_type_insert_test.v new file mode 100644 index 00000000000000..2ef0add8da6ba4 --- /dev/null +++ b/vlib/orm/orm_sum_type_insert_test.v @@ -0,0 +1,26 @@ +import db.sqlite + +struct SomeStruct { + foo int + bar string +} + +struct OtherStruct { + baz f64 +} + +type SomeSum = SomeStruct | OtherStruct + +fn test_sum_type_insert() { + db := sqlite.connect(':memory:')! + sql db { + create table SomeStruct + }! + + some := SomeSum(SomeStruct{}) + if some is SomeStruct { + sql db { + insert some into SomeStruct + }! + } +} diff --git a/vlib/v/checker/orm.v b/vlib/v/checker/orm.v index 358a854ca73540..a4f157f1715fce 100644 --- a/vlib/v/checker/orm.v +++ b/vlib/v/checker/orm.v @@ -253,7 +253,8 @@ fn (mut c Checker) sql_stmt_line(mut node ast.SqlStmtLine) ast.Type { inserting_object_type = inserting_object.typ.deref() } - if inserting_object_type != node.table_expr.typ { + if inserting_object_type != node.table_expr.typ + && !c.table.sumtype_has_variant(inserting_object_type, node.table_expr.typ, false) { table_name := table_sym.name inserting_type_name := c.table.sym(inserting_object_type).name diff --git a/vlib/v/gen/c/orm.v b/vlib/v/gen/c/orm.v index a1f3b69e2c7338..25827d9c242f13 100644 --- a/vlib/v/gen/c/orm.v +++ b/vlib/v/gen/c/orm.v @@ -334,6 +334,7 @@ fn (mut g Gen) write_orm_insert_with_last_ids(node ast.SqlStmtLine, connection_v is_serial := primary_field.attrs.contains_arg('sql', 'serial') && primary_field.typ == ast.int_type + mut inserting_object_type := ast.void_type mut member_access_type := '.' if node.scope != unsafe { nil } { inserting_object := node.scope.find(node.object_var) or { @@ -342,8 +343,10 @@ fn (mut g Gen) write_orm_insert_with_last_ids(node ast.SqlStmtLine, connection_v if inserting_object.typ.is_ptr() { member_access_type = '->' } + inserting_object_type = inserting_object.typ } + inserting_object_sym := g.table.sym(inserting_object_type) for i, mut sub in subs { if subs_unwrapped_c_typ[i].len > 0 { var := '${node.object_var}${member_access_type}${sub.object_var}' @@ -418,6 +421,10 @@ fn (mut g Gen) write_orm_insert_with_last_ids(node ast.SqlStmtLine, connection_v var := '${node.object_var}${member_access_type}${c_name(field.name)}' if field.typ.has_flag(.option) { g.writeln('${var}.state == 2? _const_orm__null_primitive : orm__${typ}_to_primitive(*(${ctyp}*)(${var}.data)),') + } else if inserting_object_sym.kind == .sum_type { + table_sym := g.table.sym(node.table_expr.typ) + sum_type_var := '(*${node.object_var}._${table_sym.cname})${member_access_type}${c_name(field.name)}' + g.writeln('orm__${typ}_to_primitive(${sum_type_var}),') } else { g.writeln('orm__${typ}_to_primitive(${var}),') }