Skip to content

Commit 4250890

Browse files
committed
Optimize type inference in joinWith
1 parent 457e9de commit 4250890

File tree

2 files changed

+42
-2
lines changed
  • core/src
    • main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api
    • test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person

2 files changed

+42
-2
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/joinWith.kt

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.impl.api
33
import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.DataFrame
55
import org.jetbrains.kotlinx.dataframe.DataRow
6+
import org.jetbrains.kotlinx.dataframe.api.Infer
67
import org.jetbrains.kotlinx.dataframe.api.JoinExpression
78
import org.jetbrains.kotlinx.dataframe.api.JoinType
89
import org.jetbrains.kotlinx.dataframe.api.JoinedDataRow
@@ -12,6 +13,7 @@ import org.jetbrains.kotlinx.dataframe.api.cast
1213
import org.jetbrains.kotlinx.dataframe.api.count
1314
import org.jetbrains.kotlinx.dataframe.api.indices
1415
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
16+
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
1517
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator
1618
import org.jetbrains.kotlinx.dataframe.impl.DataRowImpl
1719

@@ -91,8 +93,26 @@ internal fun <A, B> DataFrame<A>.joinWithImpl(
9193
}
9294
}
9395

96+
val leftColumns = columns()
97+
val rightColumns = if (addNewColumns) right.columns() else emptyList()
9498
val df: DataFrame<*> = outputData.mapIndexed { index, values ->
95-
DataColumn.createByInference(generator.names[index], values)
99+
val srcColumn = if (index < leftColumns.size) {
100+
leftColumns[index]
101+
} else {
102+
rightColumns[index - leftColumns.size]
103+
}
104+
// let's optimize an easy case.
105+
// handling introduction of nulls into ColumnGroup and FrameColumn is not straightforward
106+
when (srcColumn.kind()) {
107+
ColumnKind.Value -> DataColumn.createByType(
108+
name = generator.names[index],
109+
values = values,
110+
type = srcColumn.type(),
111+
infer = Infer.Nulls,
112+
)
113+
114+
else -> DataColumn.createByInference(generator.names[index], values)
115+
}
96116
}.toDataFrame()
97117

98118
return df.cast()

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/JoinWithTests.kt

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dataframe.api.remove
2222
import org.jetbrains.kotlinx.dataframe.api.rightJoinWith
2323
import org.jetbrains.kotlinx.dataframe.api.select
2424
import org.junit.Test
25+
import kotlin.reflect.typeOf
2526

2627
@Suppress("ktlint:standard:argument-list-wrapping")
2728
class JoinWithTests : BaseJoinTest() {
@@ -34,9 +35,12 @@ class JoinWithTests : BaseJoinTest() {
3435
res.columnsCount() shouldBe 8
3536
res.rowsCount() shouldBe 7
3637
res["age1"].hasNulls() shouldBe false
38+
res["age1"].type() shouldBe typeOf<String>()
39+
res["age1"].values().all { it != null } shouldBe true
3740
res.count { name == "Charlie" && city == "Moscow" } shouldBe 4
3841
res.select { city and name }.distinct().rowsCount() shouldBe 3
3942
res[Person2::grade].hasNulls() shouldBe false
43+
res.age.type() shouldBe typeOf<Int>()
4044
}
4145

4246
@Test
@@ -49,6 +53,8 @@ class JoinWithTests : BaseJoinTest() {
4953
res.select { city and name }.distinct().rowsCount() shouldBe 6
5054
res.count { it["grade"] == null } shouldBe 3
5155
res.age.hasNulls() shouldBe false
56+
res.age.type() shouldBe typeOf<Int>()
57+
res["age1"].type() shouldBe typeOf<String?>()
5258
}
5359

5460
@Test
@@ -62,6 +68,8 @@ class JoinWithTests : BaseJoinTest() {
6268
res.select { city and name }.distinct().rowsCount() shouldBe 4
6369
res[Person2::grade].hasNulls() shouldBe false
6470
res.age.hasNulls() shouldBe true
71+
res.age.type() shouldBe typeOf<Int?>()
72+
res["age1"].type() shouldBe typeOf<String>()
6573
val newEntries = res.filter { it["age"] == null }
6674
newEntries.rowsCount() shouldBe 2
6775
newEntries.all { it["name1"] == "Bob" && it["origin"] == "Paris" && weight == null } shouldBe true
@@ -78,6 +86,8 @@ class JoinWithTests : BaseJoinTest() {
7886
val distinct = res.select { name and age and city and weight }.distinct()
7987
val expected = typed.append(null, null, null, null)
8088
distinct shouldBe expected
89+
res.age.type() shouldBe typeOf<Int?>()
90+
res["age1"].type() shouldBe typeOf<String?>()
8191
}
8292

8393
@Test
@@ -103,7 +113,7 @@ class JoinWithTests : BaseJoinTest() {
103113
}
104114

105115
@Test
106-
fun rightJoin() {
116+
fun `exclude join`() {
107117
val df = dataFrameOf("a", "b")(
108118
1, "a",
109119
2, "b",
@@ -116,7 +126,17 @@ class JoinWithTests : BaseJoinTest() {
116126
2, "II",
117127
3, "III",
118128
)
129+
119130
df.append(4, "e").excludeJoin(df1).print()
131+
132+
val res = df.append(4, "e").excludeJoin(df1)
133+
134+
res.rowsCount() shouldBe 1
135+
res["a"].values() shouldBe listOf(4)
136+
res["b"].values() shouldBe listOf("e")
137+
res.columnsCount() shouldBe 2
138+
res["a"].type() shouldBe typeOf<Int>()
139+
res["b"].type() shouldBe typeOf<String>()
120140
}
121141

122142
@Test

0 commit comments

Comments
 (0)