Skip to content

Commit 10a65b0

Browse files
authored
Add squared distance function for arrays (kuzudb#5008)
1 parent 4a46184 commit 10a65b0

File tree

11 files changed

+120
-50
lines changed

11 files changed

+120
-50
lines changed

src/function/array/array_functions.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "function/array/functions/array_cross_product.h"
55
#include "function/array/functions/array_distance.h"
66
#include "function/array/functions/array_inner_product.h"
7+
#include "function/array/functions/array_squared_distance.h"
78
#include "function/array/vector_array_functions.h"
89
#include "function/scalar_function.h"
910

@@ -188,6 +189,10 @@ function_set ArrayDistanceFunction::getFunctionSet() {
188189
return templateGetFunctionSet<ArrayDistance>(name);
189190
}
190191

192+
function_set ArraySquaredDistanceFunction::getFunctionSet() {
193+
return templateGetFunctionSet<ArraySquaredDistance>(name);
194+
}
195+
191196
function_set ArrayInnerProductFunction::getFunctionSet() {
192197
return templateGetFunctionSet<ArrayInnerProduct>(name);
193198
}

src/function/function_collection.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ FunctionCollection* FunctionCollection::getFunctions() {
101101
// Array Functions
102102
SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction),
103103
SCALAR_FUNCTION(ArrayCosineSimilarityFunction), SCALAR_FUNCTION(ArrayDistanceFunction),
104-
SCALAR_FUNCTION(ArrayInnerProductFunction), SCALAR_FUNCTION(ArrayDotProductFunction),
104+
SCALAR_FUNCTION(ArraySquaredDistanceFunction), SCALAR_FUNCTION(ArrayInnerProductFunction),
105+
SCALAR_FUNCTION(ArrayDotProductFunction),
105106

106107
// List functions
107108
SCALAR_FUNCTION(ListCreationFunction), SCALAR_FUNCTION(ListRangeFunction),

src/include/function/array/functions/array_cosine_similarity.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,27 @@
33
#include "math.h"
44

55
#include "common/vector/value_vector.h"
6+
#include <simsimd.h>
67

78
namespace kuzu {
89
namespace function {
910

1011
struct ArrayCosineSimilarity {
11-
template<typename T>
12+
template<std::floating_point T>
1213
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
1314
common::ValueVector& leftVector, common::ValueVector& rightVector,
1415
common::ValueVector& /*resultVector*/) {
1516
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
1617
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
17-
T distance = 0;
18-
T normLeft = 0;
19-
T normRight = 0;
20-
for (auto i = 0u; i < left.size; i++) {
21-
auto x = leftElements[i];
22-
auto y = rightElements[i];
23-
distance += x * y;
24-
normLeft += x * x;
25-
normRight += y * y;
18+
KU_ASSERT(left.size == right.size);
19+
simsimd_distance_t tmpResult = 0.0;
20+
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
21+
if constexpr (std::is_same_v<T, float>) {
22+
simsimd_cos_f32(leftElements, rightElements, left.size, &tmpResult);
23+
} else {
24+
simsimd_cos_f64(leftElements, rightElements, left.size, &tmpResult);
2625
}
27-
auto similarity = distance / (std::sqrt(normLeft) * std::sqrt(normRight));
28-
result = std::max(static_cast<T>(-1), std::min(similarity, static_cast<T>(1)));
26+
result = 1.0 - tmpResult;
2927
}
3028
};
3129

src/include/function/array/functions/array_distance.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,18 @@
33
#include "math.h"
44

55
#include "common/vector/value_vector.h"
6+
#include "function/array/functions/array_squared_distance.h"
67

78
namespace kuzu {
89
namespace function {
910

1011
// Euclidean distance between two arrays.
1112
struct ArrayDistance {
12-
template<typename T>
13+
template<std::floating_point T>
1314
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
1415
common::ValueVector& leftVector, common::ValueVector& rightVector,
15-
common::ValueVector& /*resultVector*/) {
16-
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
17-
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
18-
KU_ASSERT(left.size == right.size);
19-
result = 0;
20-
for (auto i = 0u; i < left.size; i++) {
21-
auto diff = leftElements[i] - rightElements[i];
22-
result += diff * diff;
23-
}
16+
common::ValueVector& resultVector) {
17+
ArraySquaredDistance::operation(left, right, result, leftVector, rightVector, resultVector);
2418
result = std::sqrt(result);
2519
}
2620
};

src/include/function/array/functions/array_inner_product.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
#pragma once
22

33
#include "common/vector/value_vector.h"
4+
#include <simsimd.h>
45

56
namespace kuzu {
67
namespace function {
78

89
struct ArrayInnerProduct {
9-
template<typename T>
10+
template<std::floating_point T>
1011
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
1112
common::ValueVector& leftVector, common::ValueVector& rightVector,
1213
common::ValueVector& /*resultVector*/) {
1314
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
1415
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
15-
result = 0;
16-
for (auto i = 0u; i < left.size; i++) {
17-
result += leftElements[i] * rightElements[i];
16+
KU_ASSERT(left.size == right.size);
17+
simsimd_distance_t tmpResult = 0.0;
18+
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
19+
if constexpr (std::is_same_v<T, float>) {
20+
simsimd_dot_f32(leftElements, rightElements, left.size, &tmpResult);
21+
} else {
22+
simsimd_dot_f64(leftElements, rightElements, left.size, &tmpResult);
1823
}
24+
result = tmpResult;
1925
}
2026
};
2127

src/include/function/array/functions/array_squared_distance.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include "common/vector/value_vector.h"
4+
#include <simsimd.h>
5+
6+
namespace kuzu {
7+
namespace function {
8+
9+
struct ArraySquaredDistance {
10+
template<std::floating_point T>
11+
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
12+
common::ValueVector& leftVector, common::ValueVector& rightVector,
13+
common::ValueVector& /*resultVector*/) {
14+
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
15+
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
16+
KU_ASSERT(left.size == right.size);
17+
simsimd_distance_t tmpResult = 0.0;
18+
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>);
19+
if constexpr (std::is_same_v<T, float>) {
20+
simsimd_l2sq_f32(leftElements, rightElements, left.size, &tmpResult);
21+
} else {
22+
simsimd_l2sq_f64(leftElements, rightElements, left.size, &tmpResult);
23+
}
24+
result = tmpResult;
25+
}
26+
};
27+
28+
} // namespace function
29+
} // namespace kuzu

src/include/function/array/vector_array_functions.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ struct ArrayDistanceFunction {
3030
static function_set getFunctionSet();
3131
};
3232

33+
struct ArraySquaredDistanceFunction {
34+
static constexpr const char* name = "ARRAY_SQUARED_DISTANCE";
35+
36+
static function_set getFunctionSet();
37+
};
38+
3339
struct ArrayInnerProductFunction {
3440
static constexpr const char* name = "ARRAY_INNER_PRODUCT";
3541

test/test_files/function/array.test

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,21 @@ Binder exception: ARRAY_COSINE_SIMILARITY requires argument type to be FLOAT[] o
108108
---- 1
109109
5.000000
110110

111+
-LOG ArraySquaredDistance
112+
-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return round(ARRAY_SQUARED_DISTANCE(e.location, array_value(to_float(3.4), to_float(2.7))),2)
113+
---- 7
114+
1.820000
115+
2.570000
116+
2.620000
117+
20.240000
118+
33.010000
119+
41.130000
120+
6.410000
121+
122+
-STATEMENT RETURN ARRAY_SQUARED_DISTANCE([-1, -2, -3.0], [2, -6.0, -3])
123+
---- 1
124+
25.000000
125+
111126
-LOG ArrayInnerProduct
112127
-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return round(ARRAY_INNER_PRODUCT(e.location, array_value(to_float(3.4), to_float(2.7))),2)
113128
---- 7
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
-DATASET CSV empty
2+
3+
--
4+
5+
-CASE 8DimL2Naive
6+
-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));
7+
---- ok
8+
-STATEMENT COPY embeddings FROM "${KUZU_ROOT_DIRECTORY}/dataset/embeddings/embeddings-8-1k.csv" (deLim=',');
9+
---- ok
10+
-STATEMENT match (a:embeddings) with array_distance(CAST([0.1521,0.3021,0.5366,0.2774,0.5593,0.5589,0.1365,0.8557],'FLOAT[8]'), a.vec) as distance, a return a.id order by distance limit 3
11+
-CHECK_ORDER
12+
---- 3
13+
333
14+
444
15+
133
16+
17+
-CASE 8DimCosNaive
18+
-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));
19+
---- ok
20+
-STATEMENT COPY embeddings FROM "${KUZU_ROOT_DIRECTORY}/dataset/embeddings/embeddings-8-1k.csv" (deLim=',');
21+
---- ok
22+
-STATEMENT match (a:embeddings) with 1 - array_cosine_similarity(CAST([0.1521,0.3021,0.5366,0.2774,0.5593,0.5589,0.1365,0.8557],'FLOAT[8]'), a.vec) as distance, a return a.id order by distance limit 3
23+
-CHECK_ORDER
24+
---- 3
25+
333
26+
444
27+
146
28+
29+
-CASE 8DimDPNaive
30+
-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));
31+
---- ok
32+
-STATEMENT COPY embeddings FROM "${KUZU_ROOT_DIRECTORY}/dataset/embeddings/embeddings-8-1k.csv" (deLim=',');
33+
---- ok
34+
-STATEMENT match (a:embeddings) with array_dot_product(CAST([0.1521,0.3021,0.5366,0.2774,0.5593,0.5589,0.1365,0.8557],'FLOAT[8]'), a.vec) as distance, a return a.id order by distance limit 3
35+
-CHECK_ORDER
36+
---- 3
37+
499
38+
581
39+
58

test/test_files/function/hnsw/small.test

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
133
1818

1919
-CASE 8DimCos
20-
# When using dynamic dispatch simsimd uses its own approximate inverse square root
21-
# See third_party/simsimd/lib.c:10
22-
# In wasm (assumingly because it is 32-bit) that function will return slightly different distance values
23-
-SKIP_WASM
2420
-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));
2521
---- ok
2622
-STATEMENT COPY embeddings FROM "${KUZU_ROOT_DIRECTORY}/dataset/embeddings/embeddings-8-1k.csv" (deLim=',');
@@ -34,19 +30,6 @@
3430
444
3531
146
3632

37-
-CASE 8DimCosIgnoreOrder
38-
-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));
39-
---- ok
40-
-STATEMENT COPY embeddings FROM "${KUZU_ROOT_DIRECTORY}/dataset/embeddings/embeddings-8-1k.csv" (deLim=',');
41-
---- ok
42-
-STATEMENT CALL CREATE_HNSW_INDEX('e_hnsw_index', 'embeddings', 'vec');
43-
---- ok
44-
-STATEMENT CALL QUERY_HNSW_INDEX('e_hnsw_index', 'embeddings', CAST([0.1521,0.3021,0.5366,0.2774,0.5593,0.5589,0.1365,0.8557],'FLOAT[8]'), 3) RETURN nn.id ORDER BY _distance;
45-
---- 3
46-
333
47-
444
48-
146
49-
5033
# DP: DotProduct
5134
-CASE 8DimDP
5235
-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id));

0 commit comments

Comments
 (0)