@@ -477,6 +477,7 @@ def random(
477
477
lower : pl .Expr | float = 0.0 ,
478
478
upper : pl .Expr | float = 1.0 ,
479
479
seed : int | None = None ,
480
+ length : int | pl .Expr = pl .len (),
480
481
) -> pl .Expr :
481
482
"""
482
483
Generate random numbers in [lower, upper)
@@ -489,24 +490,27 @@ def random(
489
490
The upper bound, exclusive
490
491
seed
491
492
The random seed. None means no seed.
493
+ length
494
+ Custom length. Note that the length must match the length of the other columns in the context.
492
495
"""
493
496
lo = pl .lit (lower , pl .Float64 ) if isinstance (lower , float ) else lower
494
497
up = pl .lit (upper , pl .Float64 ) if isinstance (upper , float ) else upper
498
+ len_ = pl .lit (length , pl .UInt32 ) if isinstance (length , int ) else length
495
499
return pl_plugin (
496
500
symbol = "pl_random" ,
497
- args = [pl . len () , lo , up , pl .lit (seed , pl .UInt64 )],
501
+ args = [len_ , lo , up , pl .lit (seed , pl .UInt64 )],
498
502
is_elementwise = True ,
499
503
)
500
504
501
505
502
- def random_null (var : str | pl .Expr , pct : float , seed : int | None = None ) -> pl .Expr :
506
+ def random_null (x : str | pl .Expr , pct : float , seed : int | None = None ) -> pl .Expr :
503
507
"""
504
- Creates random null values in var . If var contains nulls originally, they
508
+ Creates random null values in the column. If the column contains nulls originally, they
505
509
will stay null.
506
510
507
511
Parameters
508
512
----------
509
- var
513
+ x
510
514
Either the name of the column or a Polars expression
511
515
pct
512
516
Percentage of nulls to randomly generate. This percentage is based on the
@@ -518,11 +522,15 @@ def random_null(var: str | pl.Expr, pct: float, seed: int | None = None) -> pl.E
518
522
if pct <= 0.0 or pct >= 1.0 :
519
523
raise ValueError ("Input `pct` must be > 0 and < 1" )
520
524
521
- to_null = random (0.0 , 1.0 , seed = seed ) < pct
522
- return pl .when (to_null ).then (None ).otherwise (str_to_expr (var ))
525
+ return pl .when (random (0.0 , 1.0 , seed = seed ) < pct ).then (None ).otherwise (str_to_expr (x ))
523
526
524
527
525
- def random_int (lower : int | pl .Expr , upper : int | pl .Expr , seed : int | None = None ) -> pl .Expr :
528
+ def random_int (
529
+ lower : int | pl .Expr ,
530
+ upper : int | pl .Expr ,
531
+ seed : int | None = None ,
532
+ length : int | pl .Expr = pl .len (),
533
+ ) -> pl .Expr :
526
534
"""
527
535
Generates random integer between lower and upper.
528
536
@@ -534,16 +542,19 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No
534
542
The upper bound, exclusive
535
543
seed
536
544
The random seed. None means no seed.
545
+ length
546
+ Custom length. Note that the length must match the length of the other columns in the context.
537
547
"""
538
548
if lower == upper :
539
549
raise ValueError ("Input `lower` must be smaller than `higher`" )
540
550
541
551
lo = pl .lit (lower , pl .Int32 ) if isinstance (lower , int ) else lower .cast (pl .Int32 )
542
552
hi = pl .lit (upper , pl .Int32 ) if isinstance (upper , int ) else upper .cast (pl .Int32 )
553
+ len_ = pl .lit (length , pl .UInt32 ) if isinstance (length , int ) else length
543
554
return pl_plugin (
544
555
symbol = "pl_rand_int" ,
545
556
args = [
546
- pl . len (). cast ( pl . UInt32 ) ,
557
+ len_ ,
547
558
lo ,
548
559
hi ,
549
560
pl .lit (seed , pl .UInt64 ),
@@ -552,7 +563,12 @@ def random_int(lower: int | pl.Expr, upper: int | pl.Expr, seed: int | None = No
552
563
)
553
564
554
565
555
- def random_str (min_size : int , max_size : int , seed : int | None = None ) -> pl .Expr :
566
+ def random_str (
567
+ min_size : int ,
568
+ max_size : int ,
569
+ seed : int | None = None ,
570
+ length : int | pl .Expr = pl .len (),
571
+ ) -> pl .Expr :
556
572
"""
557
573
Generates random strings of length between min_size and max_size. The characters are
558
574
uniformly distributed over ASCII letters and numbers: a-z, A-Z and 0-9.
@@ -565,6 +581,8 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
565
581
The max size of the string, inclusive
566
582
seed
567
583
The random seed. None means no seed.
584
+ length
585
+ Custom length. Note that the length must match the length of the other columns in the context.
568
586
"""
569
587
mi , ma = min_size , max_size
570
588
if min_size > max_size :
@@ -573,7 +591,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
573
591
return pl_plugin (
574
592
symbol = "pl_rand_str" ,
575
593
args = [
576
- pl .len (). cast ( pl .UInt32 ),
594
+ pl .lit ( length , pl .UInt32 ) if isinstance ( length , int ) else length ,
577
595
pl .lit (mi , pl .UInt32 ),
578
596
pl .lit (ma , pl .UInt32 ),
579
597
pl .lit (seed , pl .UInt64 ),
@@ -582,7 +600,7 @@ def random_str(min_size: int, max_size: int, seed: int | None = None) -> pl.Expr
582
600
)
583
601
584
602
585
- def random_binomial (n : int , p : int , seed : int | None = None ) -> pl .Expr :
603
+ def random_binomial (n : int , p : int , seed : int | None = None , length : int | pl . Expr = pl . len () ) -> pl .Expr :
586
604
"""
587
605
Generates random integer following a binomial distribution.
588
606
@@ -594,14 +612,16 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
594
612
The p in a binomial distribution
595
613
seed
596
614
The random seed. None means no seed.
615
+ length
616
+ Custom length. Note that the length must match the length of the other columns in the context.
597
617
"""
598
618
if n < 1 :
599
619
raise ValueError ("Input `n` must be > 1." )
600
620
601
621
return pl_plugin (
602
622
symbol = "pl_rand_binomial" ,
603
623
args = [
604
- pl .len (). cast ( pl .UInt32 ),
624
+ pl .lit ( length , pl .UInt32 ) if isinstance ( length , int ) else length ,
605
625
pl .lit (n , pl .Int32 ),
606
626
pl .lit (p , pl .Float64 ),
607
627
pl .lit (seed , pl .UInt64 ),
@@ -610,7 +630,7 @@ def random_binomial(n: int, p: int, seed: int | None = None) -> pl.Expr:
610
630
)
611
631
612
632
613
- def random_exp (lambda_ : float , seed : int | None = None ) -> pl .Expr :
633
+ def random_exp (lambda_ : float , seed : int | None = None , length : int | pl . Expr = pl . len () ) -> pl .Expr :
614
634
"""
615
635
Generates random numbers following an exponential distribution.
616
636
@@ -620,15 +640,26 @@ def random_exp(lambda_: float, seed: int | None = None) -> pl.Expr:
620
640
The lambda in an exponential distribution
621
641
seed
622
642
The random seed. None means no seed.
643
+ length
644
+ Custom length. Note that the length must match the length of the other columns in the context.
623
645
"""
624
646
return pl_plugin (
625
647
symbol = "pl_rand_exp" ,
626
- args = [pl .len ().cast (pl .UInt32 ), pl .lit (lambda_ , pl .Float64 ), pl .lit (seed , pl .UInt64 )],
648
+ args = [
649
+ pl .lit (length , pl .UInt32 ) if isinstance (length , int ) else length ,
650
+ pl .lit (lambda_ , pl .Float64 ),
651
+ pl .lit (seed , pl .UInt64 )
652
+ ],
627
653
is_elementwise = True ,
628
654
)
629
655
630
656
631
- def random_normal (mean : pl .Expr | float , std : pl .Expr | float , seed : int | None = None ) -> pl .Expr :
657
+ def random_normal (
658
+ mean : pl .Expr | float ,
659
+ std : pl .Expr | float ,
660
+ seed : int | None = None ,
661
+ length : int | pl .Expr = pl .len ()
662
+ ) -> pl .Expr :
632
663
"""
633
664
Generates random number following a normal distribution.
634
665
@@ -640,12 +671,17 @@ def random_normal(mean: pl.Expr | float, std: pl.Expr | float, seed: int | None
640
671
The std in a normal distribution
641
672
seed
642
673
The random seed. None means no seed.
674
+ length
675
+ Custom length. Note that the length must match the length of the other columns in the context.
643
676
"""
644
- m = pl .lit (mean , pl .Float64 ) if isinstance (mean , float ) else mean
645
- s = pl .lit (std , pl .Float64 ) if isinstance (std , float ) else std
646
677
return pl_plugin (
647
678
symbol = "pl_rand_normal" ,
648
- args = [pl .len ().cast (pl .UInt32 ), m , s , pl .lit (seed , pl .UInt64 )],
679
+ args = [
680
+ pl .lit (length , pl .UInt32 ) if isinstance (length , int ) else length ,
681
+ pl .lit (mean , pl .Float64 ) if isinstance (mean , float ) else mean ,
682
+ pl .lit (std , pl .Float64 ) if isinstance (std , float ) else std ,
683
+ pl .lit (seed , pl .UInt64 )
684
+ ],
649
685
is_elementwise = True ,
650
686
)
651
687
0 commit comments