Skip to content

Commit

Permalink
Improve type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Jul 26, 2024
1 parent 7ca7c65 commit 055865c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 36 deletions.
49 changes: 22 additions & 27 deletions mizani/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,19 @@
import sys
import typing
from copy import copy
from typing import overload

import numpy as np
import pandas as pd

from .utils import get_null_value, is_vector
from .utils import get_null_value

if typing.TYPE_CHECKING:
from typing import Any, Optional, Sequence
from typing import Any, Optional

from mizani.typing import (
FloatArrayLike,
FloatSeries,
NDArrayFloat,
TFloatVector,
TupleFloat2,
TupleFloat4,
)
Expand Down Expand Up @@ -129,6 +128,19 @@ def rescale_mid(
array([0.5 , 0.75, 1. ])
>>> rescale_mid([1, 2, 3], mid=2)
array([0. , 0.5, 1. ])
`rescale_mid` does have the same signature as `rescale` and
`rescale_max`. In cases where we need a compatible function with
the same signature, we use a closure around the extra `mid` argument.
>>> def rescale_mid_compat(mid):
... def _rescale(x, to=(0, 1), _from=None):
... return rescale_mid(x, to, _from, mid=mid)
... return _rescale
>>> rescale_mid2 = rescale_mid_compat(mid=2)
>>> rescale_mid2([1, 2, 3])
array([0. , 0.5, 1. ])
"""
__from: NDArrayFloat = np.array(
(np.min(x), np.max(x)) if _from is None else _from
Expand Down Expand Up @@ -280,25 +292,11 @@ def squish(
return _x


@overload
def censor(
x: NDArrayFloat | Sequence[float],
x: TFloatVector,
range: TupleFloat2 = (0, 1),
only_finite: bool = True,
) -> NDArrayFloat: ...


@overload
def censor(
x: FloatSeries, range: TupleFloat2 = (0, 1), only_finite: bool = True
) -> FloatSeries: ...


def censor(
x: NDArrayFloat | Sequence[float] | FloatSeries,
range: TupleFloat2 = (0, 1),
only_finite: bool = True,
) -> NDArrayFloat | FloatSeries:
) -> TFloatVector:
"""
Convert any values outside of range to a **NULL** type object.
Expand Down Expand Up @@ -340,11 +338,9 @@ def censor(
- :class:`datetime.timedelta` : :py:`np.timedelta64(NaT)`
"""
res = copy(x)
if not len(x):
return np.array([])

if not is_vector(x):
x = np.asarray(x)
return res

null = get_null_value(x)

Expand All @@ -360,10 +356,9 @@ def censor(
with np.errstate(invalid="ignore"):
outside = (x < range[0]) | (x > range[1])
bool_idx = finite & outside
res = copy(x)
if bool_idx.any():
if res.dtype.kind == "i":
res = np.asarray(res, dtype=float)
if res.dtype == int:
res = res.astype(float)
res[bool_idx] = null
return res

Expand Down
24 changes: 22 additions & 2 deletions mizani/palettes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .utils import identity

if TYPE_CHECKING:
from typing import Any, Literal, Optional, Sequence
from typing import Any, Literal, Optional, Sequence, TypeVar

from mizani.typing import (
Callable,
Expand All @@ -49,6 +49,8 @@
TupleFloat3,
)

T = TypeVar("T")


__all__ = [
"hls_palette",
Expand All @@ -67,6 +69,8 @@
"xkcd_palette",
"crayon_palette",
"cubehelix_pal",
"identity_pal",
"none_pal",
]


Expand Down Expand Up @@ -831,7 +835,7 @@ def __call__(self, n: int) -> Sequence[RGBHexColor]:
return self._chmap.discrete_palette(n)


def identity_pal() -> Callable[[], Any]:
def identity_pal() -> Callable[[T], T]:
"""
Create palette that maps values onto themselves
Expand All @@ -850,3 +854,19 @@ def identity_pal() -> Callable[[], Any]:
[2, 4, 6]
"""
return identity


@dataclass
class none_pal(_discrete_pal):
"""
Discrete palette that returns only None values
Example
-------
>>> palette = none_pal()
>>> palette(5)
[None, None, None, None, None]
"""

def __call__(self, n: int) -> Sequence[None]:
return [None] * n
14 changes: 10 additions & 4 deletions mizani/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
FormatFunction,
InverseFunction,
MinorBreaksFunction,
NDArrayAny,
NDArrayDatetime,
NDArrayFloat,
NDArrayTimedelta,
Expand All @@ -70,6 +69,7 @@
TupleFloat2,
)


__all__ = [
"asn_trans",
"atanh_trans",
Expand Down Expand Up @@ -187,7 +187,7 @@ def inverse(x: TFloatArrayLike) -> TFloatArrayLike:
"""
...

def breaks(self, limits: tuple[Any, Any]) -> NDArrayAny:
def breaks(self, limits: DomainType) -> NDArrayFloat:
"""
Calculate breaks in data space and return them
in transformed space.
Expand Down Expand Up @@ -898,21 +898,27 @@ def inverse(x: FloatArrayLike) -> NDArrayFloat:
return np.sign(x) * (np.exp(np.abs(x)) - 1) # type: ignore


def gettrans(t: str | Callable[[], Type[trans]] | Type[trans] | trans):
def gettrans(
t: str | Callable[[], Type[trans]] | Type[trans] | trans | None = None,
):
"""
Return a trans object
Parameters
----------
t : str | callable | type | trans
name of transformation function
Name of transformation function. If None, returns an
identity transform.
Returns
-------
out : trans
"""
obj = t
# Make sure trans object is instantiated
if t is None:
return identity_trans()

if isinstance(obj, str):
name = "{}_trans".format(obj)
obj = globals()[name]()
Expand Down
24 changes: 24 additions & 0 deletions mizani/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@
# Type variable
TFloatLike = TypeVar("TFloatLike", bound=NDArrayFloat | float)
TFloatArrayLike = TypeVar("TFloatArrayLike", bound=FloatArrayLike)
TFloatVector = TypeVar("TFloatVector", bound=NDArrayFloat | FloatSeries)
TConstrained = TypeVar(
"TConstrained", int, float, bool, str, complex, datetime, timedelta
)

NumericUFunction: TypeAlias = Callable[[TFloatLike], TFloatLike]

# Nulls for different types
Expand Down Expand Up @@ -178,6 +183,25 @@ class SegmentFunctionColorMapData(TypedDict):
[FloatArrayLike, Optional[TupleFloat2], Optional[int]], NDArrayFloat
]

# Rescale functions
# This Protocol does not apply to rescale_mid
class PRescale(Protocol):
def __call__(
self,
x: FloatArrayLike,
to: TupleFloat2 = (0, 1),
_from: TupleFloat2 | None = None,
) -> NDArrayFloat: ...

# Censor functions
class PCensor(Protocol):
def __call__(
self,
x: NDArrayFloat,
range: TupleFloat2 = (0, 1),
only_finite: bool = True,
) -> NDArrayFloat: ...

# Any type that has comparison operators can be used to define
# the domain of a transformation. And implicitly the type of the
# dataspace.
Expand Down
7 changes: 4 additions & 3 deletions mizani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
from datetime import tzinfo
from typing import Any, Optional, Sequence, TypeGuard
from typing import Any, Optional, Sequence, TypeGuard, TypeVar

from mizani.typing import (
AnyArrayLike,
Expand All @@ -26,6 +26,7 @@
TupleFloat2,
)

T = TypeVar("T")

__all__ = [
"round_any",
Expand Down Expand Up @@ -269,11 +270,11 @@ def same_log10_order_of_magnitude(x, delta=0.1):
return np.floor(dmin) == np.floor(dmax)


def identity(*args):
def identity(param: T) -> T:
"""
Return whatever is passed in
"""
return args[0] if len(args) == 1 else args
return param


def get_categories(x):
Expand Down

0 comments on commit 055865c

Please sign in to comment.