Skip to content

Commit 80efce7

Browse files
authored
Merge pull request #302 from abstractqqq/add_len_in_random
added len in random
2 parents 8a10f2a + e95ab8d commit 80efce7

File tree

1 file changed

+54
-18
lines changed

1 file changed

+54
-18
lines changed

python/polars_ds/stats.py

+54-18
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def random(
477477
lower: pl.Expr | float = 0.0,
478478
upper: pl.Expr | float = 1.0,
479479
seed: int | None = None,
480+
length: int | pl.Expr = pl.len(),
480481
) -> pl.Expr:
481482
"""
482483
Generate random numbers in [lower, upper)
@@ -489,24 +490,27 @@ def random(
489490
The upper bound, exclusive
490491
seed
491492
The random seed. None means no seed.
493+
length
494+
Custom length. Note length needs to match with other columns in the context.
492495
"""
493496
lo = pl.lit(lower, pl.Float64) if isinstance(lower, float) else lower
494497
up = pl.lit(upper, pl.Float64) if isinstance(upper, float) else upper
498+
len_ = pl.lit(length, pl.UInt32) if isinstance(length, int) else length
495499
return pl_plugin(
496500
symbol="pl_random",
497-
args=[pl.len(), lo, up, pl.lit(seed, pl.UInt64)],
501+
args=[len_, lo, up, pl.lit(seed, pl.UInt64)],
498502
is_elementwise=True,
499503
)
500504

501505

502-
def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.Expr:
506+
def random_null(x: str | pl.Expr, pct: float, seed: int | None = None) -> pl.Expr:
503507
"""
504-
Creates random null values in var. If var contains nulls originally, they
508+
Creates random null values in the columns. If var contains nulls originally, they
505509
will stay null.
506510
507511
Parameters
508512
----------
509-
var
513+
x
510514
Either the name of the column or a Polars expression
511515
pct
512516
Percentage of nulls to randomly generate. This percentage is based on the
@@ -518,11 +522,15 @@ def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.E
518522
if pct <= 0.0 or pct >= 1.0:
519523
raise ValueError("Input `pct` must be > 0 and < 1")
520524

521-
to_null = random(0.0, 1.0, seed=seed) < pct
522-
return pl.when(to_null).then(None).otherwise(str_to_expr(var))
525+
return pl.when(random(0.0, 1.0, seed=seed) < pct).then(None).otherwise(str_to_expr(x))
523526

524527

525-
def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = None) -> pl.Expr:
528+
def random_int(
529+
lower: int | pl.Expr,
530+
upper: int | pl.Expr,
531+
seed: int | None = None,
532+
length: int | pl.Expr = pl.len(),
533+
) -> pl.Expr:
526534
"""
527535
Generates random integer between lower and upper.
528536
@@ -534,16 +542,19 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No
534542
The upper bound, exclusive
535543
seed
536544
The random seed. None means no seed.
545+
length
546+
Custom length. Note length needs to match with other columns in the context.
537547
"""
538548
if lower == upper:
539549
raise ValueError("Input `lower` must be smaller than `higher`")
540550

541551
lo = pl.lit(lower, pl.Int32) if isinstance(lower, int) else lower.cast(pl.Int32)
542552
hi = pl.lit(upper, pl.Int32) if isinstance(upper, int) else upper.cast(pl.Int32)
553+
len_ = pl.lit(length, pl.UInt32) if isinstance(length, int) else length
543554
return pl_plugin(
544555
symbol="pl_rand_int",
545556
args=[
546-
pl.len().cast(pl.UInt32),
557+
len_,
547558
lo,
548559
hi,
549560
pl.lit(seed, pl.UInt64),
@@ -552,7 +563,12 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No
552563
)
553564

554565

555-
def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr:
566+
def random_str(
567+
min_size: int,
568+
max_size: int,
569+
seed: int | None = None,
570+
length: int | pl.Expr = pl.len(),
571+
) -> pl.Expr:
556572
"""
557573
Generates random strings of length between min_size and max_size. The characters are
558574
uniformly distributed over ASCII letters and numbers: a-z, A-Z and 0-9.
@@ -565,6 +581,8 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
565581
The max size of the string, inclusive
566582
seed
567583
The random seed. None means no seed.
584+
length
585+
Custom length. Note length needs to match with other columns in the context.
568586
"""
569587
mi, ma = min_size, max_size
570588
if min_size > max_size:
@@ -573,7 +591,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
573591
return pl_plugin(
574592
symbol="pl_rand_str",
575593
args=[
576-
pl.len().cast(pl.UInt32),
594+
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
577595
pl.lit(mi, pl.UInt32),
578596
pl.lit(ma, pl.UInt32),
579597
pl.lit(seed, pl.UInt64),
@@ -582,7 +600,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
582600
)
583601

584602

585-
def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
603+
def random_binomial(n: int, p: int, seed: int | None = None, length: int | pl.Expr = pl.len()) -> pl.Expr:
586604
"""
587605
Generates random integer following a binomial distribution.
588606
@@ -594,14 +612,16 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
594612
The p in a binomial distribution
595613
seed
596614
The random seed. None means no seed.
615+
length
616+
Custom length. Note length needs to match with other columns in the context.
597617
"""
598618
if n < 1:
599619
raise ValueError("Input `n` must be > 1.")
600620

601621
return pl_plugin(
602622
symbol="pl_rand_binomial",
603623
args=[
604-
pl.len().cast(pl.UInt32),
624+
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
605625
pl.lit(n, pl.Int32),
606626
pl.lit(p, pl.Float64),
607627
pl.lit(seed, pl.UInt64),
@@ -610,7 +630,7 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
610630
)
611631

612632

613-
def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr:
633+
def random_exp(lambda_: float, seed: int | None = None, length: int | pl.Expr = pl.len()) -> pl.Expr:
614634
"""
615635
Generates random numbers following an exponential distribution.
616636
@@ -620,15 +640,26 @@ def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr:
620640
The lambda in an exponential distribution
621641
seed
622642
The random seed. None means no seed.
643+
length
644+
Custom length. Note length needs to match with other columns in the context.
623645
"""
624646
return pl_plugin(
625647
symbol="pl_rand_exp",
626-
args=[pl.len().cast(pl.UInt32), pl.lit(lambda_, pl.Float64), pl.lit(seed, pl.UInt64)],
648+
args=[
649+
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
650+
pl.lit(lambda_, pl.Float64),
651+
pl.lit(seed, pl.UInt64)
652+
],
627653
is_elementwise=True,
628654
)
629655

630656

631-
def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None = None) -> pl.Expr:
657+
def random_normal(
658+
mean: pl.Expr | float,
659+
std: pl.Expr | float,
660+
seed: int | None = None,
661+
length: int | pl.Expr = pl.len()
662+
) -> pl.Expr:
632663
"""
633664
Generates random number following a normal distribution.
634665
@@ -640,12 +671,17 @@ def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None
640671
The std in a normal distribution
641672
seed
642673
The random seed. None means no seed.
674+
length
675+
Custom length. Note length needs to match with other columns in the context.
643676
"""
644-
m = pl.lit(mean, pl.Float64) if isinstance(mean, float) else mean
645-
s = pl.lit(std, pl.Float64) if isinstance(std, float) else std
646677
return pl_plugin(
647678
symbol="pl_rand_normal",
648-
args=[pl.len().cast(pl.UInt32), m, s, pl.lit(seed, pl.UInt64)],
679+
args=[
680+
pl.lit(length, pl.UInt32) if isinstance(length, int) else length,
681+
pl.lit(mean, pl.Float64) if isinstance(mean, float) else mean,
682+
pl.lit(std, pl.Float64) if isinstance(std, float) else std,
683+
pl.lit(seed, pl.UInt64)
684+
],
649685
is_elementwise=True,
650686
)
651687

0 commit comments

Comments
 (0)