Skip to content

Commit 69d4a5a

Browse files
authored
feat: block partition (#79)
* feat: block partition * chore: update pyright * fix: export partition
1 parent 14b67ec commit 69d4a5a

File tree

3 files changed

+104
-54
lines changed

3 files changed

+104
-54
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ repos:
2828
language: node
2929
pass_filenames: false
3030
types: [python]
31-
additional_dependencies: ["[email protected].241"]
31+
additional_dependencies: ["[email protected].243"]
3232
repo: local
3333
- hooks:
3434
- id: jb-to-sphinx

expression/collections/block.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
Some,
4848
SupportsLessThan,
4949
SupportsSum,
50+
curry_flipped,
5051
pipe,
5152
)
5253
from expression.core.typing import GenericValidator, ModelField, SupportsValidation
@@ -364,6 +365,34 @@ def of_seq(xs: Iterable[_TSource]) -> Block[_TSource]:
364365
def of_option(option: Option[_TSource]) -> Block[_TSource]:
365366
return of_option(option)
366367

368+
def partition(
369+
self, predicate: Callable[[_TSource], bool]
370+
) -> Tuple[Block[_TSource], Block[_TSource]]:
371+
"""Partition block.
372+
373+
Splits the collection into two collections, containing the
374+
elements for which the given predicate returns True and False
375+
respectively. Element order is preserved in both of the created
376+
lists.
377+
378+
Args:
379+
predicate: The function to test the input elements.
380+
381+
Returns:
382+
A list containing the elements for which the predicate
383+
evaluated to true and a list containing the elements for
384+
which the predicate evaluated to false.
385+
"""
386+
list1: List[_TSource] = []
387+
list2: List[_TSource] = []
388+
389+
for item in self.value:
390+
if predicate(item):
391+
list1.append(item)
392+
else:
393+
list2.append(item)
394+
return (Block(list1), Block(list2))
395+
367396
@overload
368397
@staticmethod
369398
def range(stop: int) -> Block[int]:
@@ -580,9 +609,10 @@ def _choose(source: Block[_TSource]) -> Block[_TResult]:
580609
return _choose
581610

582611

612+
@curry_flipped(1)
583613
def collect(
584-
mapping: Callable[[_TSource], Block[_TResult]]
585-
) -> Callable[[Block[_TSource]], Block[_TResult]]:
614+
source: Block[_TSource], mapping: Callable[[_TSource], Block[_TResult]]
615+
) -> Block[_TResult]:
586616
"""For each element of the list, applies the given function.
587617
Concatenates all the results and return the combined list.
588618
@@ -595,19 +625,7 @@ def collect(
595625
list and returns the concatenation of the transformed sublists.
596626
"""
597627

598-
def _collect(source: Block[_TSource]) -> Block[_TResult]:
599-
"""For each element of the list, applies the given function.
600-
Concatenates all the results and return the combined list.
601-
602-
Args:
603-
source: The input list.
604-
605-
Returns:
606-
The concatenation of the transformed sublists.
607-
"""
608-
return source.collect(mapping)
609-
610-
return _collect
628+
return source.collect(mapping)
611629

612630

613631
def concat(sources: Iterable[Block[_TSource]]) -> Block[_TSource]:
@@ -865,6 +883,29 @@ def of_option(option: Option[_TSource]) -> Block[_TSource]:
865883
return empty
866884

867885

886+
@curry_flipped(1)
887+
def partition(
888+
source: Block[_TSource], predicate: Callable[[_TSource], bool]
889+
) -> Tuple[Block[_TSource], Block[_TSource]]:
890+
"""Partition block.
891+
892+
Splits the collection into two collections, containing the
893+
elements for which the given predicate returns True and False
894+
respectively. Element order is preserved in both of the created
895+
lists.
896+
897+
Args:
898+
source: The source block to partition (curried flipped)
899+
predicate: The function to test the input elements.
900+
901+
Returns:
902+
A list containing the elements for which the predicate
903+
evaluated to true and a list containing the elements for
904+
which the predicate evaluated to false.
905+
"""
906+
return source.partition(predicate)
907+
908+
868909
@overload
869910
def range(stop: int) -> Block[int]:
870911
...
@@ -1104,6 +1145,7 @@ def _zip(source: Block[_TSource]) -> Block[Tuple[_TSource, _TResult]]:
11041145
"choose",
11051146
"collect",
11061147
"concat",
1148+
"dict",
11071149
"empty",
11081150
"filter",
11091151
"fold",
@@ -1115,6 +1157,7 @@ def _zip(source: Block[_TSource]) -> Block[Tuple[_TSource, _TResult]]:
11151157
"mapi",
11161158
"of_seq",
11171159
"of_option",
1160+
"partition",
11181161
"singleton",
11191162
"skip",
11201163
"skip_last",
@@ -1123,7 +1166,6 @@ def _zip(source: Block[_TSource]) -> Block[Tuple[_TSource, _TResult]]:
11231166
"tail",
11241167
"take",
11251168
"take_last",
1126-
"dict",
11271169
"try_head",
11281170
"unfold",
11291171
"zip",

0 commit comments

Comments
 (0)