diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 43798e364..e37137049 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -29,10 +29,10 @@ PlSelectExpr = _selector_proxy_ PlExpr = pl.Expr - PdSeries = pd.Series + PdSeries = pd.Series[Any] PlSeries = pl.Series - PyArrowArray = pa.Array - PyArrowChunkedArray = pa.ChunkedArray + PyArrowArray = pa.Array[Any] + PyArrowChunkedArray = pa.ChunkedArray[Any] PdNA = pd.NA PlNull = pl.Null @@ -763,7 +763,7 @@ def _(df: PyArrowTable, x: Any) -> bool: import pyarrow as pa arr = pa.array([x]) - return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0] + return arr.is_null(nan_is_null=True).to_pylist()[0] @singledispatch diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..a42ada266 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from importlib.util import find_spec + +import pytest + +from great_tables._tbl_data import DataFrameLike, _re_version +from tests.utils import DataFrameConstructor, DataLike + + +def pandas_constructor(obj: DataLike) -> DataFrameLike: + import pandas as pd + + return pd.DataFrame(obj) # type: ignore[no-any-return] + + +def pandas_nullable_constructor(obj: DataLike) -> DataFrameLike: + import pandas as pd + + return pd.DataFrame(obj).convert_dtypes(dtype_backend="numpy_nullable") # type: ignore[no-any-return] + + +def pandas_pyarrow_constructor(obj: DataLike) -> DataFrameLike: + import pandas as pd + + return pd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return] + + +def polars_constructor(obj: DataLike) -> DataFrameLike: + import polars as pl + + return pl.DataFrame(obj) + + +def pyarrow_table_constructor(obj: DataLike) -> DataFrameLike: + import pyarrow as pa + + return pa.table(obj) # type: ignore[no-any-return] + + +frame_constructors: list[DataFrameConstructor] = [] + +is_pandas_installed = find_spec("pandas") is not None +is_polars_installed = find_spec("polars") is not None +is_pyarrow_installed = find_spec("pyarrow") is not None + +if is_pandas_installed: + import pandas as pd + + frame_constructors.append(pandas_constructor) + + pandas_ge_v2 = _re_version(pd.__version__) >= (2, 0, 0) + + if pandas_ge_v2: + frame_constructors.append(pandas_nullable_constructor) + + if pandas_ge_v2 and is_pyarrow_installed: + # pandas 2.0+ supports pyarrow dtype backend + # https://pandas.pydata.org/docs/whatsnew/v2.0.0.html#new-dtype-backends + frame_constructors.append(pandas_pyarrow_constructor) + +if is_polars_installed: + frame_constructors.append(polars_constructor) + +if is_pyarrow_installed: + frame_constructors.append(pyarrow_table_constructor) + + +@pytest.fixture(params=frame_constructors) +def frame_constructor(request: pytest.FixtureRequest) -> DataFrameConstructor: + return request.param # type: ignore[no-any-return] diff --git a/tests/data_color/test_data_color_utils.py b/tests/data_color/test_data_color_utils.py index 836a20bdf..6fb400032 100644 --- a/tests/data_color/test_data_color_utils.py +++ b/tests/data_color/test_data_color_utils.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import math +from contextlib import nullcontext +from typing import Any import numpy as np -import pandas as pd import pytest from great_tables._data_color.base import ( _add_alpha, @@ -23,315 +26,224 @@ _srgb, ) from great_tables._data_color.palettes import GradientPalette -from great_tables._tbl_data import is_na, Agnostic +from great_tables._tbl_data import is_na, DataFrameLike + +from tests.utils import DataFrameConstructor -def assert_equal_with_na(x: list, y: list): +def assert_equal_with_na(df: DataFrameLike, x: list[Any], y: list[Any]) -> None: """Assert two lists are equal, evaluating all NAs as equivalent Note that some cases like [np.nan] == [np.nan] will be True (since it checks id), but this function handles cases that trigger equality checks (since np.nan == np.nan is False). """ assert len(x) == len(y) - for ii in range(len(x)): - assert (is_na(Agnostic(), x[ii]) and is_na(Agnostic(), y[ii])) or (x[ii] == y[ii]) - - -def test_ideal_fgnd_color_dark_contrast(): - bgnd_color = "#FFFFFF" # White background color - fgnd_color = _ideal_fgnd_color(bgnd_color) - assert fgnd_color == "#000000" # Expected dark foreground color - - -def test_ideal_fgnd_color_light_contrast(): - bgnd_color = "#000000" # Black background color - fgnd_color = _ideal_fgnd_color(bgnd_color) - assert fgnd_color == "#FFFFFF" # Expected light foreground color - - -def test_ideal_fgnd_color_custom_contrast(): - bgnd_color = "#FF0000" # Red background color - light_color = "#00FF00" # Green light color - dark_color = "#0000FF" # Blue dark color - fgnd_color = _ideal_fgnd_color(bgnd_color, light=light_color, dark=dark_color) - assert fgnd_color == "#00FF00" # Expected custom light foreground color - - -def test_ideal_fgnd_color_custom_contrast_with_alpha(): - bgnd_color = "#FF0000FF" # Red background color with alpha - light_color = "#00FF00" # Green light color - dark_color = "#0000FF" # Blue dark color - fgnd_color = _ideal_fgnd_color(bgnd_color, light=light_color, dark=dark_color) - assert fgnd_color == "#00FF00" # Expected custom light foreground color - - -def test_ideal_fgnd_color_custom_contrast_with_custom_colors(): - bgnd_color = "#FF0000" # Red background color - light_color = "#00FF00" # Green light color - dark_color = "#0000FF" # Blue dark color - fgnd_color = _ideal_fgnd_color(bgnd_color, light=light_color, dark=dark_color) - assert fgnd_color == "#00FF00" # Expected custom light foreground color - - -def test_ideal_fgnd_color_custom_contrast_with_custom_colors_and_alpha(): - bgnd_color = "#FF0000FF" # Red background color with alpha - light_color = "#00FF00" # Green light color - dark_color = "#0000FF" # Blue dark color - fgnd_color = _ideal_fgnd_color(bgnd_color, light=light_color, dark=dark_color) - assert fgnd_color == "#00FF00" # Expected custom light foreground color - - -def test_get_wcag_contrast_ratio(): - color_1 = "#FFFFFF" # White color - color_2 = "#000000" # Black color - contrast_ratio = _get_wcag_contrast_ratio(color_1, color_2) - assert contrast_ratio == 21.0 # Expected high contrast ratio - - -def test_get_wcag_contrast_ratio_custom_colors(): - color_1 = "#FF0000" # Red color - color_2 = "#00FF00" # Green color - contrast_ratio = _get_wcag_contrast_ratio(color_1, color_2) - assert contrast_ratio == 2.9139375476009137 # Expected low contrast ratio - - -def test_get_wcag_contrast_ratio_custom_colors_with_alpha(): - color_1 = "#FF0000FF" # Red color with alpha - color_2 = "#00FF00" # Green color - contrast_ratio = _get_wcag_contrast_ratio(color_1, color_2) - assert contrast_ratio == 2.9139375476009137 # Contrast ratio unchanged with alpha - - -def test_get_wcag_contrast_ratio_same_color(): - color_1 = "#FF0000" # Red color - color_2 = "#FF0000" # Red color - contrast_ratio = _get_wcag_contrast_ratio(color_1, color_2) - assert contrast_ratio == 1.0 # Contrast ratio always 1.0 for same color - - -def test_get_wcag_contrast_ratio_same_color_with_alpha(): - color_1 = "#FF0000FF" # Red color with alpha - color_2 = "#FF0000FF" # Red color with alpha - contrast_ratio = _get_wcag_contrast_ratio(color_1, color_2) - assert contrast_ratio == 1.0 # Contrast ratio unchanged with alpha - - -def test_hex_to_rgb(): - hex_color = "#FF0000" # Red color - rgb = _hex_to_rgb(hex_color) - assert rgb == (255, 0, 0) - - hex_color = "#00FF00" # Green color - rgb = _hex_to_rgb(hex_color) - assert rgb == (0, 255, 0) - - hex_color = "#0000FF" # Blue color - rgb = _hex_to_rgb(hex_color) - assert rgb == (0, 0, 255) - - hex_color = "#FFFFFF" # White color - rgb = _hex_to_rgb(hex_color) - assert rgb == (255, 255, 255) - - hex_color = "#000000" # Black color - rgb = _hex_to_rgb(hex_color) - assert rgb == (0, 0, 0) - - hex_color = "#FF0000FF" # Red color with alpha - rgb = _hex_to_rgb(hex_color) - assert rgb == (255, 0, 0) - - hex_color = "#00FF00FF" # Green color with alpha - rgb = _hex_to_rgb(hex_color) - assert rgb == (0, 255, 0) - - hex_color = "#0000FFFF" # Blue color with alpha - rgb = _hex_to_rgb(hex_color) - assert rgb == (0, 0, 255) - - hex_color = "#FFFFFFFF" # White color with alpha - rgb = _hex_to_rgb(hex_color) - assert rgb == (255, 255, 255) + for xi, yi in zip(x, y): + assert (is_na(df, xi) and is_na(df, yi)) or (xi == yi) - hex_color = "#000000FF" # Black color with alpha - rgb = _hex_to_rgb(hex_color) - assert rgb == (0, 0, 0) - -def test_relative_luminance(): - rgb = (255, 255, 255) # White color - luminance = _relative_luminance(rgb) - assert luminance == 1.0 - - rgb = (0, 0, 0) # Black color - luminance = _relative_luminance(rgb) - assert luminance == 0.0 - - rgb = (255, 0, 0) # Red color - luminance = _relative_luminance(rgb) - assert luminance == 0.2126 - - rgb = (0, 255, 0) # Green color - luminance = _relative_luminance(rgb) - assert luminance == 0.7152 - - rgb = (0, 0, 255) # Blue color - luminance = _relative_luminance(rgb) - assert luminance == 0.0722 - - -def test_srgb(): - x = 0 - result = _srgb(x) - assert result == 0.0 - - x = 255 - result = _srgb(x) - assert result == 1.0 - - x = 128 - result = _srgb(x) - assert result == 0.21586050011389926 - - x = 100 - result = _srgb(x) - assert result == 0.12743768043564743 - - x = 200 - result = _srgb(x) - assert result == 0.5775804404296506 - - -def test_html_color_hex_colors(): - colors = ["#FF0000", "#00FF00", "#0000FF"] - result = _html_color(colors) - assert result == ["#FF0000", "#00FF00", "#0000FF"] - - -def test_html_color_named_colors(): - colors = ["red", "green", "blue"] - result = _html_color(colors) - assert result == ["#FF0000", "#008000", "#0000FF"] - - -def test_html_color_mixed_colors(): - colors = ["#FF0000", "green", "#0000FF"] - result = _html_color(colors) - assert result == ["#FF0000", "#008000", "#0000FF"] +@pytest.mark.parametrize( + ("bgnd_color", "fgnd_color"), + [ + ("#FFFFFF", "#000000"), # White background color -> Expected dark foreground color + ("#000000", "#FFFFFF"), # Black background color -> Expected light foreground color + ], +) +def test_ideal_fgnd_color_contrast(bgnd_color: str, fgnd_color: str) -> None: + assert _ideal_fgnd_color(bgnd_color) == fgnd_color -def test_html_color_hex_colors_with_alpha(): - colors = ["#FF0000", "#00FF00", "#0000FF"] - alpha = 0.5 - result = _html_color(colors, alpha) - assert result == ["#FF00007F", "#00FF007F", "#0000FF7F"] +@pytest.mark.parametrize( + ("bgnd_color", "light_color", "dark_color", "fgnd_color"), + [ + ( + "#FF0000", # Red background color + "#00FF00", # Green light color + "#0000FF", # Blue dark color + "#00FF00", # Expected custom light foreground color + ), + ( + "#FF0000FF", # Red background color with alpha + "#00FF00", # Green light color + "#0000FF", # Blue dark color + "#00FF00", # Expected custom light foreground color + ), + ( + "#FF0000", # Red background color + "#00FF00", # Green light color + "#0000FF", # Blue dark color + "#00FF00", # Expected custom light foreground color + ), + ( + "#FF0000FF", # Red background color with alpha + "#00FF00", # Green light color + "#0000FF", # Blue dark color + "#00FF00", # Expected custom light foreground color + ), + ], +) +def test_ideal_fgnd_color_custom_contrast( + bgnd_color: str, fgnd_color: str, light_color: str, dark_color: str +) -> None: + assert _ideal_fgnd_color(bgnd_color, light=light_color, dark=dark_color) == fgnd_color -def test_html_color_named_colors_with_alpha(): - colors = ["red", "green", "blue"] - alpha = 0.5 - result = _html_color(colors, alpha) - assert result == ["#FF00007F", "#0080007F", "#0000FF7F"] +@pytest.mark.parametrize( + ("color_1", "color_2", "contrast_ratio"), + [ + ("#FFFFFF", "#000000", 21.0), # Colors: (White, Black) -> high contrast ratio + ("#FF0000", "#00FF00", 2.9139375476009137), # colors: (Red, Green) -> low contrast ratio + ( + "#FF0000FF", + "#00FF00", + 2.9139375476009137, + ), # colors: (Red with alpha, Green) -> Contrast ratio unchanged with alpha + ( + "#FF0000", + "#FF0000", + 1.0, + ), # colors: (Red, Red) -> Contrast ratio always 1.0 for same color + ( + "#FF0000FF", + "#FF0000FF", + 1.0, + ), # colors: (Red with alpha, Red with alpha) -> Contrast ratio unchanged with alpha + ], +) +def test_get_wcag_contrast_ratio(color_1: str, color_2: str, contrast_ratio: float) -> None: + assert _get_wcag_contrast_ratio(color_1, color_2) == contrast_ratio -def test_html_color_mixed_colors_with_alpha(): - colors = ["#FF0000", "green", "#0000FF"] - alpha = 0.5 - result = _html_color(colors, alpha) - assert result == ["#FF00007F", "#0080007F", "#0000FF7F"] +@pytest.mark.parametrize( + ("hex_color", "rgb"), + [ + ("#FF0000", (255, 0, 0)), # Red color + ("#00FF00", (0, 255, 0)), # Green color + ("#0000FF", (0, 0, 255)), # Blue color + ("#FFFFFF", (255, 255, 255)), # White color + ("#000000", (0, 0, 0)), # Black color + ("#FF0000FF", (255, 0, 0)), # Red color with alpha + ("#00FF00FF", (0, 255, 0)), # Green color with alpha + ("#0000FFFF", (0, 0, 255)), # Blue color with alpha + ("#FFFFFFFF", (255, 255, 255)), # White color with alpha + ("#000000FF", (0, 0, 0)), # Black color with alpha + ], +) +def test_hex_to_rgb(hex_color: str, rgb: tuple[int, int, int]) -> None: + assert _hex_to_rgb(hex_color) == rgb -def test_add_alpha_float_alpha(): - colors = ["#FF0000", "#00FF00", "#0000FF"] - alpha = 0.5 - result = _add_alpha(colors, alpha) - assert result == ["#FF00007F", "#00FF007F", "#0000FF7F"] +@pytest.mark.parametrize( + ("rgb", "luminance"), + [ + ((255, 255, 255), 1.0), # White color + ((0, 0, 0), 0.0), # Black color + ((255, 0, 0), 0.2126), # Red color + ((0, 255, 0), 0.7152), # Green color + ((0, 0, 255), 0.0722), # Blue color + ], +) +def test_relative_luminance(rgb: tuple[int, int, int], luminance: float) -> None: + assert _relative_luminance(rgb) == luminance -def test_add_alpha_invalid_alpha(): - colors = ["#FF0000", "#00FF00", "#0000FF"] - alpha = 1.5 - try: - _add_alpha(colors, alpha) - except ValueError as e: - assert ( - str(e) - == "Invalid alpha value provided (1.5). Please ensure that alpha is a value between 0 and 1." - ) +@pytest.mark.parametrize( + ("x", "srgb"), + [ + (0, 0.0), + (255, 1.0), + (128, 0.21586050011389926), + (100, 0.12743768043564743), + (200, 0.5775804404296506), + ], +) +def test_srgb(x: int, srgb: float) -> None: + assert _srgb(x) == srgb -def test_remove_alpha(): - colors = ["#FF0000FF", "#00FF00FF", "#0000FFFF"] - result = _remove_alpha(colors) - assert result == ["#FF0000", "#00FF00", "#0000FF"] +@pytest.mark.parametrize( + ("colors", "alpha", "result"), + [ + (["#FF0000", "#00FF00", "#0000FF"], None, ["#FF0000", "#00FF00", "#0000FF"]), + (["red", "green", "blue"], None, ["#FF0000", "#008000", "#0000FF"]), + (["#FF0000", "green", "#0000FF"], None, ["#FF0000", "#008000", "#0000FF"]), + (["#FF0000", "#00FF00", "#0000FF"], 0.5, ["#FF00007F", "#00FF007F", "#0000FF7F"]), + (["red", "green", "blue"], 0.5, ["#FF00007F", "#0080007F", "#0000FF7F"]), + (["#FF0000", "green", "#0000FF"], 0.5, ["#FF00007F", "#0080007F", "#0000FF7F"]), + ], +) +def test_html_color_hex_colors(colors: list[str], alpha: float | None, result: list[str]) -> None: + assert _html_color(colors, alpha=alpha) == result - colors = ["#FF000080", "#00FF0080", "#0000FF80"] - result = _remove_alpha(colors) - assert result == ["#FF0000", "#00FF00", "#0000FF"] +@pytest.mark.parametrize( + ("alpha", "context"), + [ + (0.5, nullcontext()), + ( + 1.5, + pytest.raises( + ValueError, + match=r"Invalid alpha value provided \(1.5\). Please ensure that alpha is a value between 0 and 1.", + ), + ), + ], +) +def test_add_alpha_float_alpha(alpha: float, context: Any) -> None: colors = ["#FF0000", "#00FF00", "#0000FF"] - result = _remove_alpha(colors) - assert result == ["#FF0000", "#00FF00", "#0000FF"] - - -def test_float_to_hex(): - # Test case 1: x = 0.0 - x = 0.0 - result = _float_to_hex(x) - assert result == "00" - - # Test case 2: x = 1.0 - x = 1.0 - result = _float_to_hex(x) - assert result == "FF" - - # Test case 3: x = 0.5 - x = 0.5 - result = _float_to_hex(x) - assert result == "7F" - - # Test case 4: x = 0.25 - x = 0.25 - result = _float_to_hex(x) - assert result == "3F" - # Test case 5: x = 0.75 - x = 0.75 - result = _float_to_hex(x) - assert result == "BF" + with context: + result = _add_alpha(colors, alpha) + assert result == ["#FF00007F", "#00FF007F", "#0000FF7F"] - # Test case 6: x = 0.125 - x = 0.125 - result = _float_to_hex(x) - assert result == "1F" +@pytest.mark.parametrize( + ("colors", "result"), + [ + (["#FF0000FF", "#00FF00FF", "#0000FFFF"], ["#FF0000", "#00FF00", "#0000FF"]), + (["#FF000080", "#00FF0080", "#0000FF80"], ["#FF0000", "#00FF00", "#0000FF"]), + (["#FF0000", "#00FF00", "#0000FF"], ["#FF0000", "#00FF00", "#0000FF"]), + ], +) +def test_remove_alpha(colors: list[str], result: list[str]) -> None: + assert _remove_alpha(colors) == result -def test_color_name_to_hex(): - # Test case 1: All colors are already in hexadecimal format - colors = ["#FF0000", "#00FF00", "#0000FF"] - result = _color_name_to_hex(colors) - assert result == ["#FF0000", "#00FF00", "#0000FF"] - # Test case 2: Some colors are in color name format - colors = ["red", "green", "blue"] - result = _color_name_to_hex(colors) - assert result == ["#FF0000", "#008000", "#0000FF"] +@pytest.mark.parametrize( + ("x", "hex"), + [ + (0.0, "00"), # Test case 1: x = 0.0 + (1.0, "FF"), # Test case 2: x = 1.0 + (0.5, "7F"), # Test case 3: x = 0.5 + (0.25, "3F"), # Test case 4: x = 0.25 + (0.75, "BF"), # Test case 5: x = 0.75 + (0.125, "1F"), # Test case 6: x = 0.125 + ], +) +def test_float_to_hex(x: float, hex: str) -> None: + assert _float_to_hex(x) == hex - # Test case 3: All colors are in color name format - colors = ["red", "green", "blue"] - result = _color_name_to_hex(colors) - assert result == ["#FF0000", "#008000", "#0000FF"] - # Test case 4: Empty list of colors - colors = [] - result = _color_name_to_hex(colors) - assert result == [] +@pytest.mark.parametrize( + ("colors", "result"), + [ + # Test case 1: All colors are already in hexadecimal format + (["#FF0000", "#00FF00", "#0000FF"], ["#FF0000", "#00FF00", "#0000FF"]), + # Test case 2: Some colors are in color name format + (["red", "green", "blue"], ["#FF0000", "#008000", "#0000FF"]), + # Test case 3: All colors are in color name format + (["red", "green", "blue"], ["#FF0000", "#008000", "#0000FF"]), + # Test case 4: Empty list of colors [] + ([], []), + # Test case 5: Colors with mixed formats + (["#FF0000", "green", "#0000FF"], ["#FF0000", "#008000", "#0000FF"]), + ], +) +def test_color_name_to_hex(colors: list[str], result: list[str]) -> None: + assert _color_name_to_hex(colors) == result - # Test case 5: Colors with mixed formats - colors = ["#FF0000", "green", "#0000FF"] - result = _color_name_to_hex(colors) - assert result == ["#FF0000", "#008000", "#0000FF"] +def test_color_name_to_hex_invalid() -> None: # Test case 6: Colors with invalid names colors = ["#FF0000", "green", "invalid"] with pytest.raises(ValueError) as e: @@ -340,207 +252,160 @@ def test_color_name_to_hex(): assert "Invalid color name provided (invalid)" in e.value.args[0] -def test_is_short_hex_valid_short_hex(): - color = "#F00" - result = _is_short_hex(color) - assert result is True - - color = "#0F0" - result = _is_short_hex(color) - assert result is True - - color = "#00F" - result = _is_short_hex(color) - assert result is True - - color = "#123" - result = _is_short_hex(color) - assert result is True - - -def test_is_short_hex_valid_long_hex(): - color = "#FF0000" - result = _is_short_hex(color) - assert result is False - - color = "#00FF00" - result = _is_short_hex(color) - assert result is False - - color = "#0000FF" - result = _is_short_hex(color) - assert result is False - - color = "#123456" - result = _is_short_hex(color) - assert result is False - - -def test_is_hex_col_valid_hex_colors(): - colors = ["#FF0000", "#00FF00", "#0000FF"] - result = _is_hex_col(colors) - assert result == [True, True, True] - - colors = ["#123456", "#ABCDEF", "#abcdef"] - result = _is_hex_col(colors) - assert result == [True, True, True] - - colors = ["#F00", "#0F0", "#00F"] - result = _is_hex_col(colors) - assert result == [False, False, False] +@pytest.mark.parametrize( + ("color", "is_short_hex"), + [ + ("#F00", True), + ("#0F0", True), + ("#00F", True), + ("#123", True), + ("#FF0000", False), + ("#00FF00", False), + ("#0000FF", False), + ("#123456", False), + ], +) +def test_is_short_hex(color: str, is_short_hex: bool) -> None: + assert _is_short_hex(color) is is_short_hex - colors = ["#FF0000FF", "#00FF00FF", "#0000FFFF"] - result = _is_hex_col(colors) - assert result == [True, True, True] +@pytest.mark.parametrize( + ("colors", "is_valid"), + [ + (["#FF0000", "#00FF00", "#0000FF"], [True, True, True]), + (["#123456", "#ABCDEF", "#abcdef"], [True, True, True]), + (["#F00", "#0F0", "#00F"], [False, False, False]), + (["#FF0000FF", "#00FF00FF", "#0000FFFF"], [True, True, True]), + (["#FF000", "#00FF00F", "#0000FFG"], [False, False, False]), + (["#12345", "#ABCDEF1", "#abcdefg"], [False, False, False]), + (["#F0", "#0F00", "#00FG"], [False, False, False]), + (["#FF0000F", "#00FF00F", "#0000FFG"], [False, False, False]), + ], +) +def test_is_hex_col_valid_hex_colors(colors: list[str], is_valid: list[bool]) -> None: + assert _is_hex_col(colors) == is_valid -def test_is_hex_col_invalid_hex_colors(): - colors = ["#FF000", "#00FF00F", "#0000FFG"] - result = _is_hex_col(colors) - assert result == [False, False, False] - colors = ["#12345", "#ABCDEF1", "#abcdefg"] - result = _is_hex_col(colors) - assert result == [False, False, False] +@pytest.mark.parametrize( + ("colors", "result"), + [ + (["#FF0000", "#00FF00", "#0000FF"], [True, True, True]), + (["#F00", "#0F0", "#00F"], [False, False, False]), + (["#123456", "#ABCDEF", "#abcdef"], [True, True, True]), + (["#123", "#abc", "#ABC"], [False, False, False]), + ( + [ + "#FF0000", + "#00FF00", + "#0000FF", + "#F00", + "#0F0", + "#00F", + "#123456", + "#ABCDEF", + "#abcdef", + "#123", + "#abc", + "#ABC", + ], + [True, True, True, False, False, False, True, True, True, False, False, False], + ), + ], +) +def test_is_standard_hex_col(colors: list[str], result: list[bool]) -> None: + assert _is_standard_hex_col(colors) == result - colors = ["#F0", "#0F00", "#00FG"] - result = _is_hex_col(colors) - assert result == [False, False, False] - colors = ["#FF0000F", "#00FF00F", "#0000FFG"] - result = _is_hex_col(colors) - assert result == [False, False, False] +@pytest.mark.parametrize( + ("hex_color", "expanded"), + [ + ("#F00", "#FF0000"), + ("#0F0", "#00FF00"), + ("#00F", "#0000FF"), + ("#123", "#112233"), + ], +) +def test_expand_short_hex_valid_short_hex(hex_color: str, expanded: str) -> None: + assert _expand_short_hex(hex_color) == expanded -def test_is_standard_hex_col(): - colors = ["#FF0000", "#00FF00", "#0000FF"] - result = _is_standard_hex_col(colors) - assert result == [True, True, True] - - colors = ["#F00", "#0F0", "#00F"] - result = _is_standard_hex_col(colors) - assert result == [False, False, False] - - colors = ["#123456", "#ABCDEF", "#abcdef"] - result = _is_standard_hex_col(colors) - assert result == [True, True, True] - - colors = ["#123", "#abc", "#ABC"] - result = _is_standard_hex_col(colors) - assert result == [False, False, False] - - colors = [ - "#FF0000", - "#00FF00", - "#0000FF", - "#F00", - "#0F0", - "#00F", - "#123456", - "#ABCDEF", - "#abcdef", - "#123", - "#abc", - "#ABC", - ] - result = _is_standard_hex_col(colors) - assert result == [True, True, True, False, False, False, True, True, True, False, False, False] - - -def test_expand_short_hex_valid_short_hex(): - hex_color = "#F00" - expanded = _expand_short_hex(hex_color) - assert expanded == "#FF0000" - - hex_color = "#0F0" - expanded = _expand_short_hex(hex_color) - assert expanded == "#00FF00" - - hex_color = "#00F" - expanded = _expand_short_hex(hex_color) - assert expanded == "#0000FF" - - hex_color = "#123" - expanded = _expand_short_hex(hex_color) - assert expanded == "#112233" - - -def test_rescale_numeric(): - # Test case 1: Rescale values within the domain range - df = pd.DataFrame({"col": [1, 2, 3, 4, 5]}) - vals = [2, 3, 4] - domain = [1, 5] - expected_result = [0.25, 0.5, 0.75] - result = _rescale_numeric(df, vals, domain) - assert result == expected_result - - # Test case 2: Rescale values outside the domain range - df = pd.DataFrame({"col": [1, 2, 3, 4, 5]}) - vals = [0, 6] - domain = [1, 5] - expected_result = [np.nan, np.nan] - result = _rescale_numeric(df, vals, domain) - assert_equal_with_na(result, expected_result) - - # Test case 3: Rescale values with NA values - df = pd.DataFrame({"col": [1, 2, np.nan, 4, 5]}) - vals = [2, np.nan, 4] - domain = [1, 5] - expected_result = [0.25, np.nan, 0.75] - result = _rescale_numeric(df, vals, domain) - assert_equal_with_na(result, expected_result) - - -def test_get_domain_numeric(): - df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": [6, 7, 8, 9, 10]}) - vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - domain = _get_domain_numeric(df, vals) - assert domain == [1, 10] +@pytest.mark.parametrize( + ("vals", "data", "expected"), + [ + ([2, 3, 4], [1, 2, 3, 4, 5], [0.25, 0.5, 0.75]), # Rescale values within the domain range + ( + [0, 6], + [1, 2, 3, 4, 5], + [float("nan"), float("nan")], + ), # Rescale values outside the domain range + ( + [2.0, float("nan"), 4.0], + [1.0, 2, float("nan"), 4, 5], + [0.25, float("nan"), 0.75], + ), # Rescale values with NA values + ], +) +def test_rescale_numeric( + frame_constructor: DataFrameConstructor, + vals: list[float], + data: list[float], + expected: list[float], +) -> None: + domain: list[float] = [1, 5] + df = frame_constructor({"col": data}) + result = _rescale_numeric(df=df, vals=vals, domain=domain) + assert_equal_with_na(df, result, expected) - df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": [6, 7, 8, 9, 10]}) - vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, np.nan] - domain = _get_domain_numeric(df, vals) - assert domain == [1, 10] - df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": [6, 7, 8, 9, 10]}) - vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, np.nan, np.nan] - domain = _get_domain_numeric(df, vals) +@pytest.mark.parametrize( + "vals", + [ + (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + (1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, float("nan")), + (1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, float("nan"), float("nan")), + ], +) +def test_get_domain_numeric(frame_constructor: DataFrameConstructor, vals: list[float]) -> None: + data = {"col1": [1, 2, 3, 4, 5], "col2": [6, 7, 8, 9, 10]} + df = frame_constructor(data) + domain = _get_domain_numeric(df=df, vals=vals) assert domain == [1, 10] -def test_get_domain_factor(): - # Test case 1: Empty DataFrame - df = pd.DataFrame() - vals = [] - result = _get_domain_factor(df, vals) - assert result == [] - - # Test case 2: DataFrame with factor values - df = pd.DataFrame({"col1": ["A", "B", "A", "C", "B"]}) - vals = ["A", "B", "C"] - result = _get_domain_factor(df, vals) - assert result == ["A", "B", "C"] +def test_get_domain_factor_empty_frame(frame_constructor: DataFrameConstructor) -> None: + df = frame_constructor({}) + domain = _get_domain_factor(df=df, vals=[]) + assert domain == [] - # Test case 3: DataFrame with factor values and NA values - df = pd.DataFrame({"col1": ["A", "B", np.nan, "C", "B"]}) - vals = ["A", "B", "C"] - result = _get_domain_factor(df, vals) - assert result == ["A", "B", "C"] - # Test case 4: DataFrame with factor values and NA values in `vals` - df = pd.DataFrame({"col1": ["A", "B", "C"]}) - vals = ["A", "B", np.nan, "C"] - result = _get_domain_factor(df, vals) - assert result == ["A", "B", "C"] - - # Test case 5: DataFrame with factor values and duplicate values in `vals` - df = pd.DataFrame({"col1": ["A", "B", "C"]}) - vals = ["A", "B", "B", "C"] - result = _get_domain_factor(df, vals) - assert result == ["A", "B", "C"] +@pytest.mark.parametrize( + ("data", "vals"), + [ + (["A", "B", "A", "C", "B"], ["A", "B", "C"]), # DataFrame with factor values + ( + ["A", "B", None, "C", "B"], + ["A", "B", "C"], + ), # DataFrame with factor values and NA values + ( + ["A", "B", "C"], + ["A", "B", None, "C"], + ), # DataFrame with factor values and NA values in `vals` + ( + ["A", "B", "C"], + ["A", "B", "B", "C"], + ), # DataFrame with factor values and duplicate values in `vals` + ], +) +def test_get_domain_factor( + frame_constructor: DataFrameConstructor, data: list[str], vals: list[str] +) -> None: + expected = ["A", "B", "C"] + df = frame_constructor({"col1": data}) + domain = _get_domain_factor(df=df, vals=vals) + assert domain == expected -def test_gradient_n_pal(): +def test_gradient_n_pal() -> None: palette = GradientPalette(["red", "blue"]) res = palette([0, 0.25, 0.5, 0.75, 1]) @@ -548,27 +413,31 @@ def test_gradient_n_pal(): @pytest.mark.parametrize( - "src,dst", [(0.001, "#ff0000"), (0.004, "#fe0001"), (0.999, "#0000ff"), (0.996, "#0100fe")] + ("src", "dst"), [(0.001, "#ff0000"), (0.004, "#fe0001"), (0.999, "#0000ff"), (0.996, "#0100fe")] ) -def test_gradient_n_pal_rounds(src, dst): +def test_gradient_n_pal_rounds(src: float, dst: str) -> None: palette = GradientPalette(["red", "blue"]) res = palette([src]) assert res == [dst] -def test_gradient_n_pal_inf(): +@pytest.mark.parametrize( + ("src", "dst"), + [ + ([-math.inf, 0, math.nan, 1, math.inf], [None, "#ff0000", None, "#0000ff", None]), + # same but with numpy + ([-np.inf, 0, np.nan, 1, np.inf], [None, "#ff0000", None, "#0000ff", None]), + ], +) +def test_gradient_n_pal_inf(src: list[float], dst: list[str]) -> None: palette = GradientPalette(["red", "blue"]) - res = palette([-math.inf, 0, math.nan, 1, math.inf]) - assert res == [None, "#ff0000", None, "#0000ff", None] - - # same but with numpy - res = palette([-np.inf, 0, np.nan, 1, np.inf]) - assert res == [None, "#ff0000", None, "#0000ff", None] + res = palette(src) + assert res == dst -def test_gradient_n_pal_symmetric(): +def test_gradient_n_pal_symmetric() -> None: # based on mizani unit tests palette = GradientPalette(["red", "blue", "red"], values=[0, 0.5, 1]) @@ -576,7 +445,7 @@ def test_gradient_n_pal_symmetric(): assert res == ["#990066", "#0000ff", "#990066"] -def test_gradient_n_pal_manual_values(): +def test_gradient_n_pal_manual_values() -> None: # note that green1 is #0000ff (and green is not!) palette = GradientPalette(["red", "blue", "green1"], values=[0, 0.8, 1]) @@ -584,45 +453,45 @@ def test_gradient_n_pal_manual_values(): assert res == ["#ff0000", "#0000ff", "#008080", "#00ff00"] -def test_gradient_n_pal_guard_raises(): - with pytest.raises(ValueError) as exc_info: - GradientPalette(["red"]) - - assert "only 1 provided" in exc_info.value.args[0] - - # values must start with 0 - with pytest.raises(ValueError) as exc_info: - GradientPalette(["red", "blue"], values=[0.1, 1]) - - assert "start with 0" in exc_info.value.args[0] - - # values must end with 1 - with pytest.raises(ValueError) as exc_info: - GradientPalette(["red", "blue"], values=[0, 0.1]) - - assert "end with 1" in exc_info.value.args[0] - - # len(color) != len(values) - with pytest.raises(ValueError) as exc_info: - GradientPalette(["red", "blue"], values=[0, 1.1, 1]) - - assert "Received 3 values and 2 colors" in exc_info.value.args[0] - - with pytest.raises(NotImplementedError) as exc_info: - GradientPalette([(255, 0, 0), (0, 255, 0)]) - - assert "Currently, rgb tuples can't be passed directly." in exc_info.value.args[0] +@pytest.mark.parametrize( + ("colors", "values", "context"), + [ + (["red"], None, pytest.raises(ValueError, match="only 1 provided")), + # values must start with 0 + (["red", "blue"], [0.1, 1], pytest.raises(ValueError, match="start with 0")), + # values must end with 1 + (["red", "blue"], [0, 0.1], pytest.raises(ValueError, match="end with 1")), + # len(color) != len(values) + ( + ["red", "blue"], + [0, 1.1, 1], + pytest.raises(ValueError, match="Received 3 values and 2 colors"), + ), + ( + [(255, 0, 0), (0, 255, 0)], + None, + pytest.raises( + NotImplementedError, match="Currently, rgb tuples can't be passed directly." + ), + ), + ], +) +def test_gradient_n_pal_guard_raises( + colors: list[str], values: list[float] | None, context: Any +) -> None: + with context: + GradientPalette(colors=colors, values=values) -def test_gradient_n_pal_out_of_bounds_raises(): +@pytest.mark.parametrize( + ("data", "msg"), + [ + ([0, 1.1], "Value: 1.1"), + ([0, -0.1], "Value: -0.1"), + ], +) +def test_gradient_n_pal_out_of_bounds_raises(data: list[float], msg: str) -> None: palette = GradientPalette(["red", "blue"]) - with pytest.raises(ValueError) as exc_info: - palette([0, 1.1]) - - assert "Value: 1.1" in exc_info.value.args[0] - - with pytest.raises(ValueError) as exc_info: - palette([0, -0.1]) - - assert "Value: -0.1" in exc_info.value.args[0] + with pytest.raises(ValueError, match=msg): + palette(data) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..79b33e816 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from great_tables._tbl_data import DataFrameLike + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + +DataLike: TypeAlias = dict[str, list[Any]] +DataFrameConstructor: TypeAlias = Callable[[DataLike], DataFrameLike]