2323 from pointblank ._typing import AbsoluteTolBounds
2424
2525
26+ def _safe_modify_datetime_compare_val (data_frame : Any , column : str , compare_val : Any ) -> Any :
27+ """
28+ Safely modify datetime comparison values for LazyFrame compatibility.
29+
30+ This function handles the case where we can't directly slice LazyFrames
31+ to get column dtypes for datetime conversion.
32+ """
33+ try :
34+ # First try to get column dtype from schema for LazyFrames
35+ column_dtype = None
36+
37+ if hasattr (data_frame , "collect_schema" ):
38+ schema = data_frame .collect_schema ()
39+ column_dtype = schema .get (column )
40+ elif hasattr (data_frame , "schema" ):
41+ schema = data_frame .schema
42+ column_dtype = schema .get (column )
43+
44+ # If we got a dtype from schema, use it
45+ if column_dtype is not None :
46+ # Create a mock column object for _modify_datetime_compare_val
47+ class MockColumn :
48+ def __init__ (self , dtype ):
49+ self .dtype = dtype
50+
51+ mock_column = MockColumn (column_dtype )
52+ return _modify_datetime_compare_val (tgt_column = mock_column , compare_val = compare_val )
53+
54+ # Fallback: try collecting a small sample if possible
55+ try :
56+ sample = data_frame .head (1 ).collect ()
57+ if hasattr (sample , "dtypes" ) and column in sample .columns :
58+ # For pandas-like dtypes
59+ column_dtype = sample .dtypes [column ] if hasattr (sample , "dtypes" ) else None
60+ if column_dtype :
61+
62+ class MockColumn :
63+ def __init__ (self , dtype ):
64+ self .dtype = dtype
65+
66+ mock_column = MockColumn (column_dtype )
67+ return _modify_datetime_compare_val (
68+ tgt_column = mock_column , compare_val = compare_val
69+ )
70+ except Exception :
71+ pass
72+
73+ # Final fallback: try direct access (for eager DataFrames)
74+ try :
75+ if hasattr (data_frame , "dtypes" ) and column in data_frame .columns :
76+ column_dtype = data_frame .dtypes [column ]
77+
78+ class MockColumn :
79+ def __init__ (self , dtype ):
80+ self .dtype = dtype
81+
82+ mock_column = MockColumn (column_dtype )
83+ return _modify_datetime_compare_val (tgt_column = mock_column , compare_val = compare_val )
84+ except Exception :
85+ pass
86+
87+ except Exception :
88+ pass
89+
90+ # If all else fails, return the original compare_val
91+ return compare_val
92+
93+
2694@dataclass
2795class Interrogator :
2896 """
@@ -136,9 +204,7 @@ def gt(self) -> FrameT | Any:
136204
137205 compare_expr = _get_compare_expr_nw (compare = self .compare )
138206
139- compare_expr = _modify_datetime_compare_val (
140- tgt_column = self .x [self .column ], compare_val = compare_expr
141- )
207+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , compare_expr )
142208
143209 return (
144210 self .x .with_columns (
@@ -211,9 +277,7 @@ def lt(self) -> FrameT | Any:
211277
212278 compare_expr = _get_compare_expr_nw (compare = self .compare )
213279
214- compare_expr = _modify_datetime_compare_val (
215- tgt_column = self .x [self .column ], compare_val = compare_expr
216- )
280+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , compare_expr )
217281
218282 return (
219283 self .x .with_columns (
@@ -329,9 +393,7 @@ def eq(self) -> FrameT | Any:
329393 else :
330394 compare_expr = _get_compare_expr_nw (compare = self .compare )
331395
332- compare_expr = _modify_datetime_compare_val (
333- tgt_column = self .x [self .column ], compare_val = compare_expr
334- )
396+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , compare_expr )
335397
336398 tbl = self .x .with_columns (
337399 pb_is_good_1 = nw .col (self .column ).is_null () & self .na_pass ,
@@ -421,9 +483,7 @@ def ne(self) -> FrameT | Any:
421483 ).to_native ()
422484
423485 else :
424- compare_expr = _modify_datetime_compare_val (
425- tgt_column = self .x [self .column ], compare_val = self .compare
426- )
486+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , self .compare )
427487
428488 return self .x .with_columns (
429489 pb_is_good_ = nw .col (self .column ) != nw .lit (compare_expr ),
@@ -544,9 +604,7 @@ def ne(self) -> FrameT | Any:
544604 if ref_col_has_null_vals :
545605 # Create individual cases for Pandas and Polars
546606
547- compare_expr = _modify_datetime_compare_val (
548- tgt_column = self .x [self .column ], compare_val = self .compare
549- )
607+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , self .compare )
550608
551609 if is_pandas_dataframe (self .x .to_native ()):
552610 tbl = self .x .with_columns (
@@ -584,6 +642,25 @@ def ne(self) -> FrameT | Any:
584642
585643 return tbl
586644
645+ else :
646+ # Generic case for other DataFrame types (PySpark, etc.)
647+ # Use similar logic to Polars but handle potential differences
648+ tbl = self .x .with_columns (
649+ pb_is_good_1 = nw .col (self .column ).is_null (), # val is Null in Column
650+ pb_is_good_2 = nw .lit (self .na_pass ), # Pass if any Null in val or compare
651+ )
652+
653+ tbl = tbl .with_columns (pb_is_good_3 = nw .col (self .column ) != nw .lit (compare_expr ))
654+
655+ tbl = tbl .with_columns (
656+ pb_is_good_ = (
657+ (nw .col ("pb_is_good_1" ) & nw .col ("pb_is_good_2" ))
658+ | (nw .col ("pb_is_good_3" ) & ~ nw .col ("pb_is_good_1" ))
659+ )
660+ )
661+
662+ return tbl .drop ("pb_is_good_1" , "pb_is_good_2" , "pb_is_good_3" ).to_native ()
663+
587664 def ge (self ) -> FrameT | Any :
588665 # Ibis backends ---------------------------------------------
589666
@@ -629,9 +706,7 @@ def ge(self) -> FrameT | Any:
629706
630707 compare_expr = _get_compare_expr_nw (compare = self .compare )
631708
632- compare_expr = _modify_datetime_compare_val (
633- tgt_column = self .x [self .column ], compare_val = compare_expr
634- )
709+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , compare_expr )
635710
636711 tbl = (
637712 self .x .with_columns (
@@ -702,9 +777,7 @@ def le(self) -> FrameT | Any:
702777
703778 compare_expr = _get_compare_expr_nw (compare = self .compare )
704779
705- compare_expr = _modify_datetime_compare_val (
706- tgt_column = self .x [self .column ], compare_val = compare_expr
707- )
780+ compare_expr = _safe_modify_datetime_compare_val (self .x , self .column , compare_expr )
708781
709782 return (
710783 self .x .with_columns (
@@ -834,10 +907,8 @@ def between(self) -> FrameT | Any:
834907 low_val = _get_compare_expr_nw (compare = self .low )
835908 high_val = _get_compare_expr_nw (compare = self .high )
836909
837- low_val = _modify_datetime_compare_val (tgt_column = self .x [self .column ], compare_val = low_val )
838- high_val = _modify_datetime_compare_val (
839- tgt_column = self .x [self .column ], compare_val = high_val
840- )
910+ low_val = _safe_modify_datetime_compare_val (self .x , self .column , low_val )
911+ high_val = _safe_modify_datetime_compare_val (self .x , self .column , high_val )
841912
842913 tbl = self .x .with_columns (
843914 pb_is_good_1 = nw .col (self .column ).is_null (), # val is Null in Column
@@ -1026,10 +1097,8 @@ def outside(self) -> FrameT | Any:
10261097 low_val = _get_compare_expr_nw (compare = self .low )
10271098 high_val = _get_compare_expr_nw (compare = self .high )
10281099
1029- low_val = _modify_datetime_compare_val (tgt_column = self .x [self .column ], compare_val = low_val )
1030- high_val = _modify_datetime_compare_val (
1031- tgt_column = self .x [self .column ], compare_val = high_val
1032- )
1100+ low_val = _safe_modify_datetime_compare_val (self .x , self .column , low_val )
1101+ high_val = _safe_modify_datetime_compare_val (self .x , self .column , high_val )
10331102
10341103 tbl = self .x .with_columns (
10351104 pb_is_good_1 = nw .col (self .column ).is_null (), # val is Null in Column
@@ -1209,14 +1278,15 @@ def rows_distinct(self) -> FrameT | Any:
12091278 else :
12101279 columns_subset = self .columns_subset
12111280
1212- # Create a subset of the table with only the columns of interest
1213- subset_tbl = tbl .select (columns_subset )
1281+ # Create a count of duplicates using group_by approach like Ibis backend
1282+ # Group by the columns of interest and count occurrences
1283+ count_tbl = tbl .group_by (columns_subset ).agg (nw .len ().alias ("pb_count_" ))
12141284
1215- # Check for duplicates in the subset table, creating a series of booleans
1216- pb_is_good_series = subset_tbl . is_duplicated ( )
1285+ # Join back to original table to get count for each row
1286+ tbl = tbl . join ( count_tbl , on = columns_subset , how = "left" )
12171287
1218- # Add the series to the input table
1219- tbl = tbl .with_columns (pb_is_good_ = ~ pb_is_good_series )
1288+ # Passing rows will have the value `1` (no duplicates, so True), otherwise False applies
1289+ tbl = tbl .with_columns (pb_is_good_ = nw . col ( "pb_count_" ) == 1 ). drop ( "pb_count_" )
12201290
12211291 return tbl .to_native ()
12221292
@@ -2088,6 +2158,8 @@ def get_test_results(self):
20882158 return self ._get_pandas_results ()
20892159 elif "duckdb" in self .tbl_type or "ibis" in self .tbl_type :
20902160 return self ._get_ibis_results ()
2161+ elif "pyspark" in self .tbl_type :
2162+ return self ._get_pyspark_results ()
20912163 else : # pragma: no cover
20922164 raise NotImplementedError (f"Support for { self .tbl_type } is not yet implemented" )
20932165
@@ -2247,6 +2319,53 @@ def _get_ibis_results(self):
22472319 results_tbl = self .data_tbl .mutate (pb_is_good_ = ibis .literal (True ))
22482320 return results_tbl
22492321
def _get_pyspark_results(self):
    """
    Evaluate every stored expression against the PySpark table and AND the results together.

    Each callable in `self.expressions` is first called with `self.data_tbl`; the result is
    accepted if it has a `_jc` attribute. If that call raises or the result is rejected, the
    callable is called again with `None` and the returned object is converted through its
    `to_pyspark_expr()` method. When both attempts fail, a message is printed and that
    expression is left out.

    Returns `self.data_tbl` with an added `pb_is_good_` column: the conjunction of all
    accepted expressions, or a literal `True` if none were accepted.
    """
    from pyspark.sql import functions as F

    accepted = []

    for build_expr in self.expressions:
        try:
            # Attempt 1: the callable yields a PySpark Column when given the table itself
            candidate = build_expr(self.data_tbl)

            # A PySpark Column is recognized by its `_jc` attribute
            if not hasattr(candidate, "_jc"):
                raise TypeError(
                    f"Expression returned {type(candidate)}, expected PySpark Column"
                )

            accepted.append(candidate)

        except Exception as direct_err:
            try:
                # Attempt 2: a ColumnExpression (pb.expr_col style) produced without a table
                deferred = build_expr(None)

                if not hasattr(deferred, "to_pyspark_expr"):
                    raise TypeError(f"Cannot convert {type(deferred)} to PySpark Column")

                accepted.append(deferred.to_pyspark_expr(self.data_tbl))

            except Exception as fallback_err:
                print(f"Error evaluating PySpark expression: {direct_err} -> {fallback_err}")

    # Nothing usable: every row is marked as passing
    if not accepted:
        return self.data_tbl.withColumn("pb_is_good_", F.lit(True))

    # Left-to-right AND over all accepted columns
    combined = functools.reduce(lambda acc, nxt: acc & nxt, accepted)

    return self.data_tbl.withColumn("pb_is_good_", combined)
2368+
22502369
22512370class SpeciallyValidation :
22522371 def __init__ (self , data_tbl , expression , threshold , tbl_type ):
@@ -2359,13 +2478,22 @@ class NumberOfTestUnits:
23592478 column : str
23602479
23612480 def get_test_units (self , tbl_type : str ) -> int :
2362- if tbl_type == "pandas" or tbl_type == "polars" :
2481+ if (
2482+ tbl_type == "pandas"
2483+ or tbl_type == "polars"
2484+ or tbl_type == "pyspark"
2485+ or tbl_type == "local"
2486+ ):
23632487 # Convert the DataFrame to a format that narwhals can work with and:
23642488 # - check if the column exists
23652489 dfn = _column_test_prep (
23662490 df = self .df , column = self .column , allowed_types = None , check_exists = False
23672491 )
23682492
2493+ # Handle LazyFrames which don't have len()
2494+ if hasattr (dfn , "collect" ):
2495+ dfn = dfn .collect ()
2496+
23692497 return len (dfn )
23702498
23712499 if tbl_type in IBIS_BACKENDS :
@@ -2383,7 +2511,22 @@ def _get_compare_expr_nw(compare: Any) -> Any:
23832511
23842512
23852513def _column_has_null_values (table : FrameT , column : str ) -> bool :
2386- null_count = (table .select (column ).null_count ())[column ][0 ]
2514+ try :
2515+ # Try the standard null_count() method
2516+ null_count = (table .select (column ).null_count ())[column ][0 ]
2517+ except AttributeError :
2518+ # For LazyFrames, collect first then get null count
2519+ try :
2520+ collected = table .select (column ).collect ()
2521+ null_count = (collected .null_count ())[column ][0 ]
2522+ except Exception :
2523+ # Fallback: check if any values are null
2524+ try :
2525+ result = table .select (nw .col (column ).is_null ().sum ().alias ("null_count" )).collect ()
2526+ null_count = result ["null_count" ][0 ]
2527+ except Exception :
2528+ # Last resort: return False (assume no nulls)
2529+ return False
23872530
23882531 if null_count is None or null_count == 0 :
23892532 return False
@@ -2414,7 +2557,7 @@ def _check_nulls_across_columns_nw(table, columns_subset):
24142557
24152558 # Build the expression by combining each column's `is_null()` with OR operations
24162559 null_expr = functools .reduce (
2417- lambda acc , col : acc | table [ col ] .is_null () if acc is not None else table [ col ] .is_null (),
2560+ lambda acc , col : acc | nw . col ( col ) .is_null () if acc is not None else nw . col ( col ) .is_null (),
24182561 column_names ,
24192562 None ,
24202563 )
0 commit comments