Commit 1fcab73

Merge pull request #222 from abstractqqq/knn_regression
Knn regression
2 parents 49eca96 + 56843a9 commit 1fcab73

File tree

10 files changed: +599 −332 lines


examples/basics.ipynb (+196 −196)

Large diffs are not rendered by default.

python/polars_ds/knn_queries.py (+85 −14)
@@ -4,12 +4,13 @@
 
 from __future__ import annotations
 import polars as pl
-from typing import Iterable
+from typing import Iterable, List
 from .type_alias import StrOrExpr, str_to_expr, Distance
 from ._utils import pl_plugin
 
 __all__ = [
     "query_knn_ptwise",
+    "query_knn_avg",
     "is_knn_from",
     "within_dist_from",
     "query_radius_ptwise",
@@ -34,14 +35,14 @@ def query_knn_ptwise(
     to each row. By default, this will return k + 1 neighbors, because the point (the row) itself
     is a neighbor to itself and this returns k additional neighbors. The only exception to this
     is when data_mask excludes the point from being a neighbor, in which case, k + 1 distinct neighbors will
-    be returned.
+    be returned. Any row with a null/NaN will never be a neighbor and will have null as its neighbor.
 
     Note that the index column must be convertible to u32. If you do not have a u32 column,
     you can generate one using pl.int_range(..), which should be a step before this. The index column
     must not contain nulls.
 
     Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find
-    k-neighbors within `max_bound`, then there will be < k neighbors returned.
+    k neighbors within `max_bound`, then there will be < k neighbors returned.
 
     Also note that this internally builds a kd-tree for fast querying and deallocates it once we
     are done. If you need to repeatedly run the same query on the same data, then it is not
@@ -81,25 +82,25 @@ def query_knn_ptwise(
 
     idx = str_to_expr(index).cast(pl.UInt32).rechunk()
     cols = [idx]
-    if eval_mask is None:
-        skip_eval = False
-    else:
-        skip_eval = True
-        cols.append(str_to_expr(eval_mask))
+    feats: List[pl.Expr] = [str_to_expr(e) for e in features]
 
-    if data_mask is None:
-        skip_data = False
+    skip_data = data_mask is not None
+    if skip_data:
+        keep_mask = pl.all_horizontal(str_to_expr(data_mask), *(f.is_not_null() for f in feats))
     else:
-        skip_data = True
-        cols.append(str_to_expr(data_mask))
+        keep_mask = pl.all_horizontal(f.is_not_null() for f in feats)
 
-    cols.extend(str_to_expr(x) for x in features)
+    cols.append(keep_mask)
+    skip_eval = eval_mask is not None
+    if skip_eval:
+        cols.append(str_to_expr(eval_mask))
+
+    cols.extend(feats)
     kwargs = {
         "k": k,
         "metric": str(dist).lower(),
         "parallel": parallel,
         "skip_eval": skip_eval,
-        "skip_data": skip_data,
         "max_bound": max_bound,
         "epsilon": abs(epsilon),
     }
@@ -119,6 +120,76 @@ def query_knn_ptwise(
     )
 
 
+def query_knn_avg(
+    *features: StrOrExpr,
+    target: StrOrExpr,
+    k: int,
+    dist: Distance = "sql2",
+    weighted: bool = False,
+    parallel: bool = False,
+    min_bound: float = 1e-9,
+    max_bound: float = 99999.0,
+) -> pl.Expr:
+    """
+    Takes the target column, and uses feature columns to determine the k nearest neighbors
+    to each row. By default, this will return k + 1 neighbors, because the point (the row) itself
+    is a neighbor to itself and this returns k additional neighbors. Any row with a null/NaN will
+    never be a neighbor and will get null as the average.
+
+    Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find
+    k neighbors within `max_bound`, then there will be < k neighbors returned.
+
+    This is also known as KNN regression, but really it is just the average of the k nearest neighbors.
+
+    Parameters
+    ----------
+    *features : str | pl.Expr
+        Other columns used as features
+    target : str | pl.Expr
+        Float, must be castable to f64. This should not contain nulls.
+    k : int
+        Number of neighbors to query
+    dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`]
+        Note `sql2` stands for squared l2.
+    weighted : bool
+        If weighted, it will use 1/distance as weights to compute the KNN average. If min_bound is
+        an extremely small value, this will default to 1/(1+distance) as weights to avoid division by 0.
+    parallel : bool
+        Whether to run the k-nearest-neighbor query in parallel. This is recommended when you
+        are running only this expression, and not in a group_by context.
+    min_bound
+        Min distance (>=) for a neighbor to be part of the average calculation. This prevents "identical"
+        points from being part of the average and prevents division by 0. Note that this filter is applied
+        after getting the k nearest neighbors.
+    max_bound
+        Max distance the neighbors must be within (<)
+    """
+    if k < 1:
+        raise ValueError("Input `k` must be >= 1.")
+
+    idx = str_to_expr(target).cast(pl.Float64).rechunk()
+    feats = [str_to_expr(f) for f in features]
+    keep_data = ~pl.any_horizontal(f.is_null() for f in feats)
+    cols = [idx, keep_data]
+    cols.extend(feats)
+
+    kwargs = {
+        "k": k,
+        "metric": str(dist).lower(),
+        "weighted": weighted,
+        "parallel": parallel,
+        "min_bound": min_bound,
+        "max_bound": max_bound,
+    }
+
+    return pl_plugin(
+        symbol="pl_knn_avg",
+        args=cols,
+        kwargs=kwargs,
+        is_elementwise=True,
+    )
+
+
 def within_dist_from(
     *features: StrOrExpr,
     pt: Iterable[float],
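
For reference, a minimal usage sketch of the API this diff adds. The frame and the column names x1, x2, y, id are hypothetical; the signatures follow the code above:

import polars as pl
from polars_ds.knn_queries import query_knn_avg, query_knn_ptwise

df = pl.DataFrame(
    {
        "x1": [0.1, 0.2, 0.3, 0.4, 0.5],
        "x2": [1.0, 0.9, 0.8, 0.7, 0.6],
        "y": [10.0, 12.0, 11.0, 14.0, 13.0],
    }
)

# KNN regression: for each row, average `y` over its k nearest neighbors
# in (x1, x2) space, weighted by 1/distance.
knn_avg = df.with_columns(
    query_knn_avg("x1", "x2", target="y", k=2, dist="l2", weighted=True).alias("y_knn_avg")
)

# The pointwise variant returns neighbor indices instead and needs a u32
# index column, e.g. generated with pl.int_range as the docstring suggests.
knn_idx = df.with_columns(
    pl.int_range(pl.len(), dtype=pl.UInt32).alias("id")
).with_columns(
    query_knn_ptwise("x1", "x2", index="id", k=2, dist="l2").alias("knn")
)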

src/arkadia/arkadia_any.rs (+9)
@@ -1,7 +1,11 @@
+use std::fmt::Debug;
+
 /// A Kdtree
 use crate::arkadia::{leaf::KdLeaf, suggest_capacity, Leaf, SplitMethod, KDTQ, NB};
 use num::Float;
 
+use super::KNNRegressor;
+
 #[derive(Clone, PartialEq, Eq)]
 pub enum DIST<T: Float + 'static> {
     L1,
@@ -446,6 +450,11 @@ impl<'a, T: Float + 'static + std::fmt::Debug, A: Copy> KDTQ<'a, T, A> for AnyKDT<'a, T, A>
     }
 }
 
+impl<'a, T: Float + 'static + std::fmt::Debug + Into<f64>, A: Float + Into<f64>>
+    KNNRegressor<'a, T, A> for AnyKDT<'a, T, A>
+{
+}
+
 #[cfg(test)]
 mod tests {
     use super::super::matrix_to_leaves;

src/arkadia/mod.rs (+75 −41)
@@ -20,6 +20,7 @@ pub mod utils;
 pub use arkadia_any::{AnyKDT, DIST};
 pub use leaf::{KdLeaf, Leaf};
 pub use neighbor::NB;
+use serde::Deserialize;
 pub use utils::{
     matrix_to_empty_leaves, matrix_to_leaves, matrix_to_leaves_w_row_num, suggest_capacity,
     SplitMethod,
@@ -28,19 +29,24 @@ pub use utils::{
 // ---------------------------------------------------------------------------------------------------------
 use num::Float;
 
-#[derive(Clone, Default)]
+#[derive(Clone, Copy, Default, Deserialize)]
 pub enum KNNMethod {
-    DInvW, // Distance Inversion Weighted. E.g. Use (1/(1+d)) to weight the regression / classification
+    P1Weighted, // Distance Inversion Weighted. E.g. Use (1/(1+d)) to weight the regression / classification
+    Weighted,   // Distance Inversion Weighted. E.g. Use (1/d) to weight the regression / classification
     #[default]
-    NoW, // No Weight
+    NotWeighted, // No Weight
 }
 
-impl From<bool> for KNNMethod {
-    fn from(weighted: bool) -> Self {
+impl KNNMethod {
+    pub fn new(weighted: bool, min_dist: f64) -> Self {
         if weighted {
-            KNNMethod::DInvW
+            if min_dist <= f64::epsilon() {
+                Self::P1Weighted
+            } else {
+                Self::Weighted
+            }
         } else {
-            KNNMethod::NoW
+            Self::NotWeighted
         }
     }
 }
@@ -94,7 +100,7 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
         if k == 0
             || (point.len() != self.dim())
             || (point.iter().any(|x| !x.is_finite()))
-            || max_dist_bound <= T::zero() + T::epsilon()
+            || max_dist_bound <= T::epsilon()
         {
             None
         } else {
@@ -161,31 +167,69 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
 pub trait KNNRegressor<'a, T: Float + Into<f64> + 'static, A: Float + Into<f64>>:
     KDTQ<'a, T, A>
 {
-    fn knn_regress(&self, k: usize, point: &[T], max_dist_bound: T, how: KNNMethod) -> Option<f64> {
-        let knn = self.knn_bounded(k, point, max_dist_bound, T::zero());
+    fn knn_regress(
+        &self,
+        k: usize,
+        point: &[T],
+        min_dist_bound: T,
+        max_dist_bound: T,
+        how: KNNMethod,
+    ) -> Option<f64> {
+        let knn = self
+            .knn_bounded(k, point, max_dist_bound, T::zero())
+            .map(|nn| {
+                nn.into_iter()
+                    .filter(|nb| nb.dist >= min_dist_bound)
+                    .collect::<Vec<_>>()
+            });
         match knn {
             Some(nn) => match how {
-                KNNMethod::DInvW => {
-                    let weights = nn
-                        .iter()
-                        .map(|nb| (nb.dist + T::one()).recip().into())
-                        .collect::<Vec<f64>>();
-                    let sum = weights.iter().copied().sum::<f64>();
-                    Some(
-                        nn.into_iter()
-                            .zip(weights.into_iter())
-                            .fold(0f64, |acc, (nb, w)| acc + w * nb.to_item().into())
-                            / sum,
-                    )
+                KNNMethod::P1Weighted => {
+                    if nn.is_empty() {
+                        None
+                    } else {
+                        let weights = nn
+                            .iter()
+                            .map(|nb| (T::one() + nb.dist).recip().into())
+                            .collect::<Vec<f64>>();
+                        let sum = weights.iter().copied().sum::<f64>();
+                        Some(
+                            nn.into_iter()
+                                .zip(weights.into_iter())
+                                .fold(0f64, |acc, (nb, w)| acc + w * nb.to_item().into())
+                                / sum,
+                        )
+                    }
+                }
+                KNNMethod::Weighted => {
+                    if nn.is_empty() {
+                        None
+                    } else {
+                        let weights = nn
+                            .iter()
+                            .map(|nb| nb.dist.recip().into())
+                            .collect::<Vec<f64>>();
+                        let sum = weights.iter().copied().sum::<f64>();
+                        Some(
+                            nn.into_iter()
+                                .zip(weights.into_iter())
+                                .fold(0f64, |acc, (nb, w)| acc + w * nb.to_item().into())
+                                / sum,
+                        )
+                    }
                 }
-                KNNMethod::NoW => {
-                    let n = nn.len() as f64;
-                    Some(
-                        nn.into_iter()
-                            .fold(A::zero(), |acc, nb| acc + nb.to_item())
-                            .into()
-                            / n,
-                    )
+                KNNMethod::NotWeighted => {
+                    if nn.is_empty() {
+                        None
+                    } else {
+                        let n = nn.len() as f64;
+                        Some(
+                            nn.into_iter()
+                                .fold(A::zero(), |acc, nb| acc + nb.to_item())
+                                .into()
+                                / n,
+                        )
+                    }
                 }
             },
             None => None,
@@ -196,17 +240,7 @@ pub trait KNNRegressor<'a, T: Float + Into<f64> + 'static, A: Float + Into<f64>>:
 pub trait KNNClassifier<'a, T: Float + 'static>: KDTQ<'a, T, u32> {
     fn knn_classif(&self, k: usize, point: &[T], max_dist_bound: T, how: KNNMethod) -> Option<u32> {
         let knn = self.knn_bounded(k, point, max_dist_bound, T::zero());
-        match knn {
-            Some(nn) => match how {
-                KNNMethod::DInvW => {
-                    todo!()
-                }
-                KNNMethod::NoW => {
-                    todo!()
-                }
-            },
-            None => None,
-        }
+        todo!()
     }
 }
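
In plain terms, knn_regress now averages only the neighbors whose distance d satisfies min_dist_bound <= d < max_dist_bound, with three weighting modes. A rough NumPy sketch of the arithmetic (illustrative only; the function name and arguments are hypothetical, not the plugin's API):

import numpy as np

def knn_average(dists: np.ndarray, targets: np.ndarray, method: str = "not_weighted"):
    # `dists` and `targets` hold the k neighbors that survived the
    # min_dist_bound <= d < max_dist_bound filter, as in knn_regress above.
    if dists.size == 0:
        return None  # mirrors the `nn.is_empty()` -> None branches
    if method == "not_weighted":  # NotWeighted: plain mean
        return float(targets.mean())
    if method == "weighted":  # Weighted: 1/d, safe because d >= min_bound > 0
        w = 1.0 / dists
    else:  # P1Weighted: 1/(1 + d), chosen when min_bound is ~0
        w = 1.0 / (1.0 + dists)
    return float((w * targets).sum() / w.sum())

The 1/(1+d) fallback exists because with min_bound near 0 an exact duplicate point would give d = 0, and a raw 1/d weight would divide by zero.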

src/linalg/lstsq.rs (+10 −12)
@@ -184,8 +184,8 @@ pub fn faer_rolling_lstsq(x: MatRef<f64>, y: MatRef<f64>, n: usize) -> Vec<Mat<f64>>
     let mut weights = &inv * x0t * y0;
     coefficients.push(weights.to_owned());
     for j in n..xn {
-        let remove_x = x.get(j-n..j-n+1, ..);
-        let remove_y = y.get(j-n..j-n+1, ..);
+        let remove_x = x.get(j - n..j - n + 1, ..);
+        let remove_y = y.get(j - n..j - n + 1, ..);
         woodbury_step(inv.as_mut(), weights.as_mut(), remove_x, remove_y, -1.0);
 
         let next_x = x.get(j..j + 1, ..); // 1 by m, m = # of columns
@@ -201,18 +201,17 @@
 /// https://en.wikipedia.org/wiki/Woodbury_matrix_identity
 #[inline(always)]
 fn woodbury_step(
-    inverse: MatMut<f64>,
-    weights: MatMut<f64>,
-    new_x: MatRef<f64>,
+    inverse: MatMut<f64>,
+    weights: MatMut<f64>,
+    new_x: MatRef<f64>,
     new_y: MatRef<f64>,
-    c: f64 // Should be +1 or -1, for a "update" and a "removal"
+    c: f64, // Should be +1 or -1, for an "update" and a "removal"
 ) {
-
-    // It is truly amazing that the C in the Woodbury identity essentially controls the update and
+    // It is truly amazing that the C in the Woodbury identity essentially controls the update and
     // and removal of a new record (rolling)... Linear regression seems to be designed by God to work so well
 
     let left = &inverse * new_x.transpose(); // corresponding to u in the reference
-    // right = left.transpose() by the fact that if A is symmetric, invertible, A-1 is also symmetric
+    // right = left.transpose() by the fact that if A is symmetric, invertible, A-1 is also symmetric
     let z = (c + (new_x * &left).read(0, 0)).recip();
     // Update the inverse
     faer::linalg::matmul::matmul(
@@ -225,7 +224,7 @@ fn woodbury_step(
     ); // inv is updated
 
     // Difference from estimate using prior weights vs. actual next y
-    let y_diff = new_y - (new_x * &weights);
+    let y_diff = new_y - (new_x * &weights);
     // Update weights
     faer::linalg::matmul::matmul(
         weights,
@@ -235,5 +234,4 @@ fn woodbury_step(
         z,
         faer::Parallelism::Rayon(0), //
     ); // weights are updated
-
-}
+}
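
woodbury_step applies the Sherman-Morrison rank-1 form of the Woodbury identity: with u = A^-1 x^T and z = 1/(c + x A^-1 x^T), the Gram-matrix inverse is corrected to A^-1 - z u u^T and the weights to w + z u (y - x w), where c = +1 folds a new row in and c = -1 removes an old one, which is how faer_rolling_lstsq slides its window without refitting. A NumPy sketch of the same step (names and shapes are assumptions for illustration, not the faer implementation):

import numpy as np

def woodbury_step(inv: np.ndarray, w: np.ndarray, x: np.ndarray, y: float, c: float):
    # inv: (m, m) current (X'X)^-1; w: (m, 1) weights; x: (1, m) row; y: scalar.
    # c = +1.0 adds the row (update), c = -1.0 removes it (removal).
    u = inv @ x.T                    # A^-1 x', the `left` vector in the Rust code
    z = 1.0 / (c + float(x @ u))     # scalar from the Sherman-Morrison identity
    inv -= z * (u @ u.T)             # rank-1 correction of the inverse
    w += z * u * (y - float(x @ w))  # correct weights by the prediction error
    return inv, w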
