Skip to content

Commit

Permalink
✅ Improved automated test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
daquintero committed Sep 28, 2024
1 parent 069b2e4 commit 94c39e8
Show file tree
Hide file tree
Showing 22 changed files with 8,948 additions and 12 deletions.
10 changes: 7 additions & 3 deletions piel/analysis/signals/dc/transfer/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ def calculate_power_signal_from_collection(
)
break

logger.debug(f"Voltage values: {voltage}")
logger.debug(f"Current values: {current}")
logger.debug(f"Power values: {power_values}")
try:
logger.debug(f"Voltage values: {voltage}")
logger.debug(f"Current values: {current}")
logger.debug(f"Power values: {power_values}")
except Exception:
pass

if power_values is None or len(power_values) == 0:
raise ValueError("Power trace not found or empty in the collection.")

Expand Down
4 changes: 3 additions & 1 deletion piel/analysis/signals/time_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@
extract_statistical_metrics_collection,
)
from .offset import offset_to_first_rising_edge
from .remove import remove_before_first_rising_edge
from .remove import (
remove_before_first_rising_edge,
)
6 changes: 3 additions & 3 deletions piel/types/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class Unit(PielBaseModel):


A = Unit(name="ampere", datum="ampere", base=1, label=r"Current $A$")
dB = Unit(name="Decibel", datum="dB", base=1, label=r"Ratio $dB$")
GHz = Unit(name="Gigahertz", datum="Hertz", base=1e9, label=r"Frequency $GHz$")
Hz = Unit(name="Hertz", datum="Hertz", base=1, label=r"Frequency $Hz$")
dB = Unit(name="decibel", datum="dB", base=1, label=r"Ratio $dB$")
GHz = Unit(name="gigahertz", datum="hertz", base=1e9, label=r"Frequency $GHz$")
Hz = Unit(name="hertz", datum="hertz", base=1, label=r"Frequency $Hz$")
nm = Unit(name="nanometer", datum="meter", base=1e-9, label=r"Length $nm$")
ns = Unit(name="nanosecond", datum="second", base=1e-9, label=r"Time $ns$")
mm2 = Unit(
Expand Down
13 changes: 9 additions & 4 deletions piel/units/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ def get_unit_by_datum(datum: str) -> Optional[Unit]:
"""
import piel.types.units as units

for unit in dir(units):
if unit.datum.lower() == datum.lower():
return unit
return None
exact_match = None
for attr_name in dir(units):
attr = getattr(units, attr_name)
if isinstance(attr, Unit) and attr.datum.lower() == datum.lower():
if attr.base == 1: # Prioritize units with base 1 (e.g., 's' for second)
return attr
if exact_match is None:
exact_match = attr
return exact_match
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"pydantic>=2.0,<3.0",
"scipy>=1.11.4,<2.0.0",
"setuptools",
"xarray>=2024.9.0,<2024.10.0"
"xarray>=2024.1.0,<2024.10.0"
]

[project.urls]
Expand Down
Empty file.
190 changes: 190 additions & 0 deletions tests/analysis/dc_data/test_dc_traces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import pytest
import numpy as np

# Import the functions to be tested
from piel.analysis.signals.dc import (
get_trace_values_by_datum,
get_trace_values_by_unit,
)

# Import necessary classes and units
from piel.types import (
SignalDC,
SignalTraceDC,
Unit,
V,
A,
ratio,
)

# Sample Units
W = Unit(name="watt", datum="watt", base=1, label="Power W")
dB = Unit(name="Decibel", datum="dB", base=1, label="Ratio dB")

# Sample Data for Testing
VOLTAGE_VALUES = [0.0, 1.1, 2.2, 3.3, 4.4]
CURRENT_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0]
RATIO_VALUES = [1, 2, 3, 4, 5]


def create_signal_dc(name: str, values: list, unit: Unit) -> SignalDC:
"""
Helper function to create a SignalDC instance with a single trace.
"""
trace = SignalTraceDC(name=name, values=values, unit=unit)
return SignalDC(trace_list=[trace])


def test_get_trace_values_by_datum_voltage():
"""
Test retrieving voltage trace values by datum.
"""
signal_dc = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
retrieved_values = get_trace_values_by_datum(signal_dc, "voltage")
assert retrieved_values is not None, "Should retrieve voltage values."
np.testing.assert_array_equal(retrieved_values, np.array(VOLTAGE_VALUES))


def test_get_trace_values_by_datum_current():
"""
Test retrieving current trace values by datum.
"""
signal_dc = create_signal_dc("Current Trace", CURRENT_VALUES, A)
retrieved_values = get_trace_values_by_datum(signal_dc, "ampere")
assert retrieved_values is not None, "Should retrieve current values."
np.testing.assert_array_equal(retrieved_values, np.array(CURRENT_VALUES))


def test_get_trace_values_by_datum_ratio():
"""
Test retrieving ratio trace values by datum.
"""
signal_dc = create_signal_dc("Ratio Trace", RATIO_VALUES, ratio)
retrieved_values = get_trace_values_by_datum(signal_dc, "1")
assert retrieved_values is not None, "Should retrieve ratio values."
np.testing.assert_array_equal(retrieved_values, np.array(RATIO_VALUES))


def test_get_trace_values_by_datum_case_insensitive():
"""
Test that datum matching is case-insensitive.
"""
signal_dc = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
retrieved_values = get_trace_values_by_datum(signal_dc, "Voltage")
assert (
retrieved_values is not None
), "Should retrieve voltage values with case-insensitive datum."
np.testing.assert_array_equal(retrieved_values, np.array(VOLTAGE_VALUES))


def test_get_trace_values_by_datum_not_found():
"""
Test retrieving values with a datum that does not exist.
"""
signal_dc = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
retrieved_values = get_trace_values_by_datum(signal_dc, "current")
assert retrieved_values is None, "Should return None when datum is not found."


def test_get_trace_values_by_unit_voltage():
"""
Test retrieving voltage trace values by exact unit.
"""
signal_dc = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
retrieved_values = get_trace_values_by_unit(signal_dc, V)
assert retrieved_values is not None, "Should retrieve voltage values by unit."
np.testing.assert_array_equal(retrieved_values, np.array(VOLTAGE_VALUES))


def test_get_trace_values_by_unit_current():
"""
Test retrieving current trace values by exact unit.
"""
signal_dc = create_signal_dc("Current Trace", CURRENT_VALUES, A)
retrieved_values = get_trace_values_by_unit(signal_dc, A)
assert retrieved_values is not None, "Should retrieve current values by unit."
np.testing.assert_array_equal(retrieved_values, np.array(CURRENT_VALUES))


def test_get_trace_values_by_unit_not_found():
"""
Test retrieving values with a unit that does not exist.
"""
signal_dc = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
retrieved_values = get_trace_values_by_unit(
signal_dc, W
) # Looking for Watt in Voltage Trace
assert retrieved_values is None, "Should return None when unit is not found."


def test_get_trace_values_by_unit_case_insensitive():
"""
Test that unit matching is case-insensitive.
"""
signal_dc = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
V_upper = Unit(name="Volt", datum="voltage", base=1, label="V")
retrieved_values = get_trace_values_by_unit(signal_dc, V_upper)
assert (
retrieved_values is not None
), "Should retrieve voltage values with case-insensitive unit."
np.testing.assert_array_equal(retrieved_values, np.array(VOLTAGE_VALUES))


def test_get_trace_values_by_unit_multiple_traces():
"""
Test retrieving values when multiple traces exist.
"""
voltage_signal = create_signal_dc("Voltage Trace", VOLTAGE_VALUES, V)
current_signal = create_signal_dc("Current Trace", CURRENT_VALUES, A)
combined_signal_dc = SignalDC(
trace_list=voltage_signal.trace_list + current_signal.trace_list
)

retrieved_voltage = get_trace_values_by_unit(combined_signal_dc, V)
assert (
retrieved_voltage is not None
), "Should retrieve voltage values from combined traces."
np.testing.assert_array_equal(retrieved_voltage, np.array(VOLTAGE_VALUES))

retrieved_current = get_trace_values_by_unit(combined_signal_dc, A)
assert (
retrieved_current is not None
), "Should retrieve current values from combined traces."
np.testing.assert_array_equal(retrieved_current, np.array(CURRENT_VALUES))


def test_get_trace_values_by_unit_duplicate_units():
"""
Test retrieving values when multiple traces have the same unit.
"""
voltage_signal1 = create_signal_dc("Voltage Trace 1", VOLTAGE_VALUES, V)
voltage_signal2 = create_signal_dc("Voltage Trace 2", VOLTAGE_VALUES, V)
combined_signal_dc = SignalDC(
trace_list=voltage_signal1.trace_list + voltage_signal2.trace_list
)

retrieved_values = get_trace_values_by_unit(combined_signal_dc, V)
assert (
retrieved_values is not None
), "Should retrieve the first matching voltage trace."
np.testing.assert_array_equal(
retrieved_values, np.array(VOLTAGE_VALUES)
) # Assuming first trace is returned


def test_get_trace_values_by_unit_empty_trace_list():
"""
Test retrieving values from a SignalDC with no traces.
"""
empty_signal_dc = SignalDC(trace_list=[])
retrieved_values = get_trace_values_by_unit(empty_signal_dc, V)
assert retrieved_values is None, "Should return None when trace list is empty."


def test_get_trace_values_by_datum_empty_trace_list():
"""
Test retrieving values by datum from a SignalDC with no traces.
"""
empty_signal_dc = SignalDC(trace_list=[])
retrieved_values = get_trace_values_by_datum(empty_signal_dc, "voltage")
assert retrieved_values is None, "Should return None when trace list is empty."
Loading

0 comments on commit 94c39e8

Please sign in to comment.