From 2b982f094cd46f6b73e39c6340bc2433aae44d1c Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Wed, 29 Jan 2025 09:50:11 +0100 Subject: [PATCH] Prevent intersection/union of dt ranges with inconsistent tzinfo --- tests/test_ranges.py | 52 ++++++++++++++++++++++++++++++++++++++++---- xocto/ranges.py | 16 ++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/tests/test_ranges.py b/tests/test_ranges.py index b226cc9..e084899 100644 --- a/tests/test_ranges.py +++ b/tests/test_ranges.py @@ -13,6 +13,10 @@ from xocto import localtime, ranges +TZ_LONDON = zoneinfo.ZoneInfo("Europe/London") +TZ_BERLIN = zoneinfo.ZoneInfo("Europe/Berlin") + + @composite def valid_integer_range(draw): boundaries = draw(sampled_from(ranges.RangeBoundaries)) @@ -1021,11 +1025,9 @@ def test_finite_range(self): class TestFiniteDatetimeRange: def test_cannot_construct_with_inconsistent_tzinfo(self): with pytest.raises(ranges.InconsistentTzInfo): - tz_london = zoneinfo.ZoneInfo("Europe/London") - tz_berlin = zoneinfo.ZoneInfo("Europe/Berlin") ranges.FiniteDatetimeRange( - datetime.datetime(2020, 1, 1, tzinfo=tz_london), - datetime.datetime(2021, 1, 1, tzinfo=tz_berlin), + datetime.datetime(2020, 1, 1, tzinfo=TZ_LONDON), + datetime.datetime(2021, 1, 1, tzinfo=TZ_BERLIN), ) @pytest.mark.parametrize( @@ -1133,6 +1135,27 @@ def test_union_of_overlapping_ranges(self): ) ) + @pytest.mark.parametrize( + "other_range_model", + [ + ranges.FiniteDatetimeRange, + ranges.DatetimeRange, + ], + ) + def test_cannot_take_union_of_ranges_with_inconsistent_tzinfo( + self, other_range_model + ): + range = ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1, tzinfo=TZ_LONDON), + end=datetime.datetime(2000, 1, 3, tzinfo=TZ_LONDON), + ) + other = other_range_model( + start=datetime.datetime(2000, 1, 2, tzinfo=TZ_BERLIN), + end=datetime.datetime(2000, 1, 4, tzinfo=TZ_BERLIN), + ) + with pytest.raises(ranges.InconsistentTzInfo): + _ = range | other + class TestIntersection: def test_intersection_of_touching_ranges(self): range = ranges.FiniteDatetimeRange( @@ -1177,6 +1200,27 @@ def test_intersection_of_overlapping_ranges(self): ) ) + @pytest.mark.parametrize( + "other_range_model", + [ + ranges.FiniteDatetimeRange, + ranges.DatetimeRange, + ], + ) + def test_cannot_take_intersection_of_ranges_with_inconsistent_tzinfo( + self, other_range_model + ): + range = ranges.FiniteDatetimeRange( + start=datetime.datetime(2000, 1, 1, tzinfo=TZ_LONDON), + end=datetime.datetime(2000, 1, 3, tzinfo=TZ_LONDON), + ) + other = other_range_model( + start=datetime.datetime(2000, 1, 2, tzinfo=TZ_BERLIN), + end=datetime.datetime(2000, 1, 4, tzinfo=TZ_BERLIN), + ) + with pytest.raises(ranges.InconsistentTzInfo): + _ = range & other + class TestLocalize: def test_converts_timezone(self): # Create a datetime range in Sydney, which is diff --git a/xocto/ranges.py b/xocto/ranges.py index 46b4b3d..cdeb41d 100644 --- a/xocto/ranges.py +++ b/xocto/ranges.py @@ -867,12 +867,21 @@ def intersection( # We're deliberately overriding the base class here for better performance. # We can simplify the implementation since we know we're dealing with finite # ranges with INCLUSIVE_EXCLUSIVE bounds. + if self.tzinfo != other.tzinfo: + raise InconsistentTzInfo( + "inconsistent tzinfo for datetime range intersection" + ) left, right = (self, other) if self.start < other.start else (other, self) if left.end <= right.start: return None else: return FiniteDatetimeRange(right.start, min(left.end, right.end)) + if self.tzinfo != get_tzinfo(other): + raise InconsistentTzInfo( + "inconsistent tzinfo for datetime range intersection" + ) + base_intersection = super().intersection(other) if base_intersection is None: return None @@ -888,12 +897,19 @@ def union(self, other: Range[datetime.datetime]) -> Optional["FiniteDatetimeRang # We're deliberately overriding the base class here for better performance. # We can simplify the implementation since we know we're dealing with finite # ranges with INCLUSIVE_EXCLUSIVE bounds. + if self.tzinfo != other.tzinfo: + raise InconsistentTzInfo( + "inconsistent tzinfo for datetime range intersection" + ) left, right = (self, other) if self.start < other.start else (other, self) if left.end < right.start: return None else: return FiniteDatetimeRange(left.start, max(left.end, right.end)) + if self.tzinfo != get_tzinfo(other): + raise InconsistentTzInfo("inconsistent tzinfo for datetime range union") + try: base_union = super().union(other) except ValueError: