Skip to content

Commit

Permalink
Prevent intersection/union of dt ranges with inconsistent tzinfo
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter554 committed Jan 29, 2025
1 parent 3892375 commit 2b982f0
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
52 changes: 48 additions & 4 deletions tests/test_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from xocto import localtime, ranges


TZ_LONDON = zoneinfo.ZoneInfo("Europe/London")
TZ_BERLIN = zoneinfo.ZoneInfo("Europe/Berlin")


@composite
def valid_integer_range(draw):
boundaries = draw(sampled_from(ranges.RangeBoundaries))
Expand Down Expand Up @@ -1021,11 +1025,9 @@ def test_finite_range(self):
class TestFiniteDatetimeRange:
def test_cannot_construct_with_inconsistent_tzinfo(self):
with pytest.raises(ranges.InconsistentTzInfo):
tz_london = zoneinfo.ZoneInfo("Europe/London")
tz_berlin = zoneinfo.ZoneInfo("Europe/Berlin")
ranges.FiniteDatetimeRange(
datetime.datetime(2020, 1, 1, tzinfo=tz_london),
datetime.datetime(2021, 1, 1, tzinfo=tz_berlin),
datetime.datetime(2020, 1, 1, tzinfo=TZ_LONDON),
datetime.datetime(2021, 1, 1, tzinfo=TZ_BERLIN),
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -1133,6 +1135,27 @@ def test_union_of_overlapping_ranges(self):
)
)

@pytest.mark.parametrize(
"other_range_model",
[
ranges.FiniteDatetimeRange,
ranges.DatetimeRange,
],
)
def test_cannot_take_union_of_ranges_with_inconsistent_tzinfo(
self, other_range_model
):
range = ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1, tzinfo=TZ_LONDON),
end=datetime.datetime(2000, 1, 3, tzinfo=TZ_LONDON),
)
other = other_range_model(
start=datetime.datetime(2000, 1, 2, tzinfo=TZ_BERLIN),
end=datetime.datetime(2000, 1, 4, tzinfo=TZ_BERLIN),
)
with pytest.raises(ranges.InconsistentTzInfo):
_ = range | other

class TestIntersection:
def test_intersection_of_touching_ranges(self):
range = ranges.FiniteDatetimeRange(
Expand Down Expand Up @@ -1177,6 +1200,27 @@ def test_intersection_of_overlapping_ranges(self):
)
)

@pytest.mark.parametrize(
"other_range_model",
[
ranges.FiniteDatetimeRange,
ranges.DatetimeRange,
],
)
def test_cannot_take_intersection_of_ranges_with_inconsistent_tzinfo(
self, other_range_model
):
range = ranges.FiniteDatetimeRange(
start=datetime.datetime(2000, 1, 1, tzinfo=TZ_LONDON),
end=datetime.datetime(2000, 1, 3, tzinfo=TZ_LONDON),
)
other = other_range_model(
start=datetime.datetime(2000, 1, 2, tzinfo=TZ_BERLIN),
end=datetime.datetime(2000, 1, 4, tzinfo=TZ_BERLIN),
)
with pytest.raises(ranges.InconsistentTzInfo):
_ = range & other

class TestLocalize:
def test_converts_timezone(self):
# Create a datetime range in Sydney, which is
Expand Down
16 changes: 16 additions & 0 deletions xocto/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,12 +867,21 @@ def intersection(
# We're deliberately overriding the base class here for better performance.
# We can simplify the implementation since we know we're dealing with finite
# ranges with INCLUSIVE_EXCLUSIVE bounds.
if self.tzinfo != other.tzinfo:
raise InconsistentTzInfo(
"inconsistent tzinfo for datetime range intersection"
)
left, right = (self, other) if self.start < other.start else (other, self)
if left.end <= right.start:
return None
else:
return FiniteDatetimeRange(right.start, min(left.end, right.end))

if self.tzinfo != get_tzinfo(other):
raise InconsistentTzInfo(
"inconsistent tzinfo for datetime range intersection"
)

base_intersection = super().intersection(other)
if base_intersection is None:
return None
Expand All @@ -888,12 +897,19 @@ def union(self, other: Range[datetime.datetime]) -> Optional["FiniteDatetimeRang
# We're deliberately overriding the base class here for better performance.
# We can simplify the implementation since we know we're dealing with finite
# ranges with INCLUSIVE_EXCLUSIVE bounds.
if self.tzinfo != other.tzinfo:
raise InconsistentTzInfo(
"inconsistent tzinfo for datetime range intersection"
)
left, right = (self, other) if self.start < other.start else (other, self)
if left.end < right.start:
return None
else:
return FiniteDatetimeRange(left.start, max(left.end, right.end))

if self.tzinfo != get_tzinfo(other):
raise InconsistentTzInfo("inconsistent tzinfo for datetime range union")

try:
base_union = super().union(other)
except ValueError:
Expand Down

0 comments on commit 2b982f0

Please sign in to comment.