diff --git a/mizani/bounds.py b/mizani/bounds.py index a93f529..4b81040 100644 --- a/mizani/bounds.py +++ b/mizani/bounds.py @@ -25,20 +25,19 @@ import sys import typing from copy import copy -from typing import overload import numpy as np import pandas as pd -from .utils import get_null_value, is_vector +from .utils import get_null_value if typing.TYPE_CHECKING: - from typing import Any, Optional, Sequence + from typing import Any, Optional from mizani.typing import ( FloatArrayLike, - FloatSeries, NDArrayFloat, + TFloatVector, TupleFloat2, TupleFloat4, ) @@ -129,6 +128,19 @@ def rescale_mid( array([0.5 , 0.75, 1. ]) >>> rescale_mid([1, 2, 3], mid=2) array([0. , 0.5, 1. ]) + + `rescale_mid` does have the same signature as `rescale` and + `rescale_max`. In cases where we need a compatible function with + the same signature, we use a closure around the extra `mid` argument. + + >>> def rescale_mid_compat(mid): + ... def _rescale(x, to=(0, 1), _from=None): + ... return rescale_mid(x, to, _from, mid=mid) + ... return _rescale + + >>> rescale_mid2 = rescale_mid_compat(mid=2) + >>> rescale_mid2([1, 2, 3]) + array([0. , 0.5, 1. ]) """ __from: NDArrayFloat = np.array( (np.min(x), np.max(x)) if _from is None else _from @@ -280,25 +292,11 @@ def squish( return _x -@overload def censor( - x: NDArrayFloat | Sequence[float], + x: TFloatVector, range: TupleFloat2 = (0, 1), only_finite: bool = True, -) -> NDArrayFloat: ... - - -@overload -def censor( - x: FloatSeries, range: TupleFloat2 = (0, 1), only_finite: bool = True -) -> FloatSeries: ... - - -def censor( - x: NDArrayFloat | Sequence[float] | FloatSeries, - range: TupleFloat2 = (0, 1), - only_finite: bool = True, -) -> NDArrayFloat | FloatSeries: +) -> TFloatVector: """ Convert any values outside of range to a **NULL** type object. @@ -340,11 +338,9 @@ def censor( - :class:`datetime.timedelta` : :py:`np.timedelta64(NaT)` """ + res = copy(x) if not len(x): - return np.array([]) - - if not is_vector(x): - x = np.asarray(x) + return res null = get_null_value(x) @@ -360,10 +356,9 @@ def censor( with np.errstate(invalid="ignore"): outside = (x < range[0]) | (x > range[1]) bool_idx = finite & outside - res = copy(x) if bool_idx.any(): - if res.dtype.kind == "i": - res = np.asarray(res, dtype=float) + if res.dtype == int: + res = res.astype(float) res[bool_idx] = null return res diff --git a/mizani/palettes.py b/mizani/palettes.py index 4f09667..95902b8 100644 --- a/mizani/palettes.py +++ b/mizani/palettes.py @@ -36,7 +36,7 @@ from .utils import identity if TYPE_CHECKING: - from typing import Any, Literal, Optional, Sequence + from typing import Any, Literal, Optional, Sequence, TypeVar from mizani.typing import ( Callable, @@ -49,6 +49,8 @@ TupleFloat3, ) + T = TypeVar("T") + __all__ = [ "hls_palette", @@ -67,6 +69,8 @@ "xkcd_palette", "crayon_palette", "cubehelix_pal", + "identity_pal", + "none_pal", ] @@ -831,7 +835,7 @@ def __call__(self, n: int) -> Sequence[RGBHexColor]: return self._chmap.discrete_palette(n) -def identity_pal() -> Callable[[], Any]: +def identity_pal() -> Callable[[T], T]: """ Create palette that maps values onto themselves @@ -850,3 +854,19 @@ def identity_pal() -> Callable[[], Any]: [2, 4, 6] """ return identity + + +@dataclass +class none_pal(_discrete_pal): + """ + Discrete palette that returns only None values + + Example + ------- + >>> palette = none_pal() + >>> palette(5) + [None, None, None, None, None] + """ + + def __call__(self, n: int) -> Sequence[None]: + return [None] * n diff --git a/mizani/transforms.py b/mizani/transforms.py index 6e9b91e..d24bf0e 100644 --- a/mizani/transforms.py +++ b/mizani/transforms.py @@ -60,7 +60,6 @@ FormatFunction, InverseFunction, MinorBreaksFunction, - NDArrayAny, NDArrayDatetime, NDArrayFloat, NDArrayTimedelta, @@ -70,6 +69,7 @@ TupleFloat2, ) + __all__ = [ "asn_trans", "atanh_trans", @@ -187,7 +187,7 @@ def inverse(x: TFloatArrayLike) -> TFloatArrayLike: """ ... - def breaks(self, limits: tuple[Any, Any]) -> NDArrayAny: + def breaks(self, limits: DomainType) -> NDArrayFloat: """ Calculate breaks in data space and return them in transformed space. @@ -898,14 +898,17 @@ def inverse(x: FloatArrayLike) -> NDArrayFloat: return np.sign(x) * (np.exp(np.abs(x)) - 1) # type: ignore -def gettrans(t: str | Callable[[], Type[trans]] | Type[trans] | trans): +def gettrans( + t: str | Callable[[], Type[trans]] | Type[trans] | trans | None = None, +): """ Return a trans object Parameters ---------- t : str | callable | type | trans - name of transformation function + Name of transformation function. If None, returns an + identity transform. Returns ------- @@ -913,6 +916,9 @@ def gettrans(t: str | Callable[[], Type[trans]] | Type[trans] | trans): """ obj = t # Make sure trans object is instantiated + if t is None: + return identity_trans() + if isinstance(obj, str): name = "{}_trans".format(obj) obj = globals()[name]() diff --git a/mizani/typing.py b/mizani/typing.py index d9d308b..92decb8 100644 --- a/mizani/typing.py +++ b/mizani/typing.py @@ -75,6 +75,11 @@ # Type variable TFloatLike = TypeVar("TFloatLike", bound=NDArrayFloat | float) TFloatArrayLike = TypeVar("TFloatArrayLike", bound=FloatArrayLike) + TFloatVector = TypeVar("TFloatVector", bound=NDArrayFloat | FloatSeries) + TConstrained = TypeVar( + "TConstrained", int, float, bool, str, complex, datetime, timedelta + ) + NumericUFunction: TypeAlias = Callable[[TFloatLike], TFloatLike] # Nulls for different types @@ -178,6 +183,25 @@ class SegmentFunctionColorMapData(TypedDict): [FloatArrayLike, Optional[TupleFloat2], Optional[int]], NDArrayFloat ] + # Rescale functions + # This Protocol does not apply to rescale_mid + class PRescale(Protocol): + def __call__( + self, + x: FloatArrayLike, + to: TupleFloat2 = (0, 1), + _from: TupleFloat2 | None = None, + ) -> NDArrayFloat: ... + + # Censor functions + class PCensor(Protocol): + def __call__( + self, + x: NDArrayFloat, + range: TupleFloat2 = (0, 1), + only_finite: bool = True, + ) -> NDArrayFloat: ... + # Any type that has comparison operators can be used to define # the domain of a transformation. And implicitly the type of the # dataspace. diff --git a/mizani/utils.py b/mizani/utils.py index 9537798..a606ef4 100644 --- a/mizani/utils.py +++ b/mizani/utils.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from datetime import tzinfo - from typing import Any, Optional, Sequence, TypeGuard + from typing import Any, Optional, Sequence, TypeGuard, TypeVar from mizani.typing import ( AnyArrayLike, @@ -26,6 +26,7 @@ TupleFloat2, ) + T = TypeVar("T") __all__ = [ "round_any", @@ -269,11 +270,11 @@ def same_log10_order_of_magnitude(x, delta=0.1): return np.floor(dmin) == np.floor(dmax) -def identity(*args): +def identity(param: T) -> T: """ Return whatever is passed in """ - return args[0] if len(args) == 1 else args + return param def get_categories(x):