diff --git a/plotly/shapeannotation.py b/plotly/shapeannotation.py index a2323ed02d4..41173054a11 100644 --- a/plotly/shapeannotation.py +++ b/plotly/shapeannotation.py @@ -1,10 +1,52 @@ # some functions defined here to avoid numpy import +import datetime + + +def _is_date_string(val): + """Check if a value is a date/datetime string.""" + if not isinstance(val, str): + return False + try: + datetime.datetime.fromisoformat(val.replace("Z", "+00:00")) + return True + except (ValueError, AttributeError): + return False + + +def _datetime_str_to_ms(val): + """Convert a datetime string to milliseconds since epoch.""" + dt = datetime.datetime.fromisoformat(val.replace("Z", "+00:00")) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=datetime.timezone.utc) + return dt.timestamp() * 1000 + + +def _ms_to_datetime_str(ms): + """Convert milliseconds since epoch back to a datetime string.""" + dt = datetime.datetime.fromtimestamp(ms / 1000, tz=datetime.timezone.utc) + return dt.strftime("%Y-%m-%d %H:%M:%S") + def _mean(x): if len(x) == 0: raise ValueError("x must have positive length") - return float(sum(x)) / len(x) + try: + return float(sum(x)) / len(x) + except TypeError: + # Handle non-numeric types like datetime strings or datetime objects + if all(_is_date_string(v) for v in x): + ms_values = [_datetime_str_to_ms(v) for v in x] + mean_ms = sum(ms_values) / len(ms_values) + return _ms_to_datetime_str(mean_ms) + # Handle datetime.datetime, pd.Timestamp, or similar objects + if all(hasattr(v, "timestamp") for v in x): + ts_values = [v.timestamp() * 1000 for v in x] + mean_ms = sum(ts_values) / len(ts_values) + return datetime.datetime.fromtimestamp( + mean_ms / 1000, tz=datetime.timezone.utc + ).isoformat() + raise def _argmin(x): diff --git a/tests/test_optional/test_autoshapes/test_annotated_shapes.py b/tests/test_optional/test_autoshapes/test_annotated_shapes.py index a008e3bda12..14f7b31792a 100644 --- a/tests/test_optional/test_autoshapes/test_annotated_shapes.py +++ b/tests/test_optional/test_autoshapes/test_annotated_shapes.py @@ -425,5 +425,84 @@ def test_all_annotation_positions(): draw_all_annotation_positions(testing=True) + if __name__ == "__main__": draw_all_annotation_positions() + + +# Tests for datetime axis annotation support (issue #3065) +import datetime + + +def test_vline_datetime_string_annotation(): + """add_vline with annotation_text on datetime x-axis should not crash.""" + fig = go.Figure() + fig.add_trace( + go.Scatter(x=["2018-01-01", "2018-06-01", "2018-12-31"], y=[1, 2, 3]) + ) + fig.add_vline(x="2018-09-24", annotation_text="test") + assert len(fig.layout.annotations) == 1 + assert fig.layout.annotations[0].text == "test" + + +def test_hline_with_datetime_vline(): + """add_hline should still work alongside datetime vline usage.""" + fig = go.Figure() + fig.add_trace( + go.Scatter(x=["2018-01-01", "2018-06-01", "2018-12-31"], y=[1, 2, 3]) + ) + fig.add_hline(y=2, annotation_text="hline test") + assert len(fig.layout.annotations) == 1 + assert fig.layout.annotations[0].text == "hline test" + + +def test_vrect_datetime_string_annotation(): + """add_vrect with annotation_text on datetime x-axis should not crash.""" + fig = go.Figure() + fig.add_trace( + go.Scatter(x=["2018-01-01", "2018-06-01", "2018-12-31"], y=[1, 2, 3]) + ) + fig.add_vrect(x0="2018-03-01", x1="2018-09-01", annotation_text="rect test") + assert len(fig.layout.annotations) == 1 + assert fig.layout.annotations[0].text == "rect test" + + +def test_vline_datetime_object_annotation(): + """add_vline with datetime.datetime object should not crash.""" + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=[ + datetime.datetime(2018, 1, 1), + datetime.datetime(2018, 6, 1), + datetime.datetime(2018, 12, 31), + ], + y=[1, 2, 3], + ) + ) + fig.add_vline(x=datetime.datetime(2018, 9, 24), annotation_text="dt test") + assert len(fig.layout.annotations) == 1 + assert fig.layout.annotations[0].text == "dt test" + + +def test_vrect_datetime_object_annotation(): + """add_vrect with datetime.datetime objects should compute correct mean.""" + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=[ + datetime.datetime(2018, 1, 1), + datetime.datetime(2018, 6, 1), + datetime.datetime(2018, 12, 31), + ], + y=[1, 2, 3], + ) + ) + fig.add_vrect( + x0=datetime.datetime(2018, 3, 1), + x1=datetime.datetime(2018, 9, 1), + annotation_text="rect dt test", + ) + assert len(fig.layout.annotations) == 1 + assert fig.layout.annotations[0].text == "rect dt test" +