Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix target version inference. #3583

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

- Add trailing commas to collection literals even if there's a comment after the last
entry (#3393)
- Fix target version inference (#3583)

### Configuration

Expand Down
197 changes: 164 additions & 33 deletions src/black/files.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import io
import operator
import os
import re
import sys
from functools import lru_cache
from enum import Enum
from functools import lru_cache, reduce
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -168,46 +171,174 @@ def parse_req_python_version(requires_python: str) -> Optional[List[TargetVersio
return None


class Endpoint(Enum):
CLOSED = 1
OPEN = 2


CLOSED = Endpoint.CLOSED
OPEN = Endpoint.OPEN


class Interval:
def __init__(self, left: Endpoint, lower: Any, upper: Any, right: Endpoint):
if not (lower < upper or (lower == upper and left == right == CLOSED)):
raise ValueError("empty interval")
self.left = left
self.lower = lower
self.upper = upper
self.right = right


class IntervalSet:
"""Represents a union of intervals."""

def __init__(self, intervals: List[Any]):
self.intervals = intervals

def __and__(self, other: "IntervalSet") -> "IntervalSet":
new_intervals = []
for i1 in self.intervals:
for i2 in other.intervals:
if i1.lower < i2.lower:
lower = i2.lower
left = i2.left
elif i2.lower < i1.lower:
lower = i1.lower
left = i1.left
else:
lower = i1.lower
left = CLOSED if i1.left == i2.left == CLOSED else OPEN
if i1.upper < i2.upper:
upper = i1.upper
right = i1.right
elif i2.upper < i1.upper:
upper = i2.upper
right = i2.right
else:
upper = i1.upper
right = CLOSED if i1.right == i2.right == CLOSED else OPEN
try:
new_intervals.append(Interval(left, lower, upper, right))
except ValueError:
pass
return IntervalSet(new_intervals)

def __or__(self, other: "IntervalSet") -> "IntervalSet":
return IntervalSet(self.intervals + other.intervals)

@property
def empty(self) -> bool:
return len(self.intervals) == 0


def interval(left: Endpoint, lower: Any, upper: Any, right: Endpoint) -> IntervalSet:
try:
return IntervalSet([Interval(left, lower, upper, right)])
except ValueError:
return empty


def singleton(value: Any) -> IntervalSet:
return interval(CLOSED, value, value, CLOSED)


empty = IntervalSet([])
min_ver = Version(f"3.{tuple(TargetVersion)[0].value}")
above_max_ver = Version(f"3.{tuple(TargetVersion)[-1].value + 1}")


def get_interval_set(specifier: Specifier) -> IntervalSet:
if specifier.version.endswith(".*"):
assert specifier.operator in ("==", "!=")
wildcard = True
ver = Version(specifier.version[:-2])
else:
wildcard = False
if specifier.operator != "===":
ver = Version(specifier.version)

if specifier.operator == ">=":
return interval(CLOSED, ver, above_max_ver, OPEN)
if specifier.operator == ">":
return interval(OPEN, ver, above_max_ver, OPEN)
if specifier.operator == "<=":
return interval(CLOSED, min_ver, ver, CLOSED)
if specifier.operator == "<":
return interval(CLOSED, min_ver, ver, OPEN)
if specifier.operator == "==":
if wildcard:
return interval(
CLOSED,
ver,
Version(".".join(map(str, (*ver.release[:-1], ver.release[-1] + 1)))),
OPEN,
)
else:
return singleton(ver)
if specifier.operator == "!=":
if wildcard:
return interval(CLOSED, min_ver, ver, OPEN) | interval(
CLOSED,
Version(".".join(map(str, (*ver.release[:-1], ver.release[-1] + 1)))),
above_max_ver,
OPEN,
)
else:
return interval(CLOSED, min_ver, ver, OPEN) | interval(
OPEN, ver, above_max_ver, OPEN
)
if specifier.operator == "~=":
return interval(
CLOSED,
ver,
Version(".".join(map(str, (*ver.release[:-2], ver.release[-2] + 1)))),
OPEN,
)
if specifier.operator == "===":
# This operator should do a simple string equality test. Pip compares
# it with "X.Y.Z", so only if the version in the specifier is in this
# exact format, it has a chance to match.
if re.fullmatch(r"\d+\.\d+\.\d+", specifier.version):
return singleton(Version(specifier.version))
else:
return empty
raise AssertionError() # pragma: no cover


def parse_req_python_specifier(requires_python: str) -> Optional[List[TargetVersion]]:
"""Parse a specifier string (i.e. ``">=3.7,<3.10"``) to a list of TargetVersion.

If parsing fails, will raise a packaging.specifiers.InvalidSpecifier error.
If the parsed specifier cannot be mapped to a valid TargetVersion, returns None.
If the parsed specifier is empty or cannot be mapped to a valid TargetVersion,
returns None.
"""
specifier_set = strip_specifier_set(SpecifierSet(requires_python))
specifier_set = SpecifierSet(requires_python)
if not specifier_set:
# This means that the specifier has no version clauses. Technically,
# all Python versions are included in this specifier. But because the
# user didn't refer to any specific Python version, we fall back to
# per-file auto-detection.
return None

target_version_map = {f"3.{v.value}": v for v in TargetVersion}
compatible_versions: List[str] = list(specifier_set.filter(target_version_map))
if compatible_versions:
return [target_version_map[v] for v in compatible_versions]
return None


def strip_specifier_set(specifier_set: SpecifierSet) -> SpecifierSet:
"""Strip minor versions for some specifiers in the specifier set.

For background on version specifiers, see PEP 440:
https://peps.python.org/pep-0440/#version-specifiers
"""
specifiers = []
for s in specifier_set:
if "*" in str(s):
specifiers.append(s)
elif s.operator in ["~=", "==", ">=", "==="]:
version = Version(s.version)
stripped = Specifier(f"{s.operator}{version.major}.{version.minor}")
specifiers.append(stripped)
elif s.operator == ">":
version = Version(s.version)
if len(version.release) > 2:
s = Specifier(f">={version.major}.{version.minor}")
specifiers.append(s)
else:
specifiers.append(s)

return SpecifierSet(",".join(str(s) for s in specifiers))
# First, we determine the version interval set from the specifier set (the
# clauses in the specifier set are connected by the logical and operator).
# Then, for each supported Python (minor) version, we check whether the
# interval set intersects with the interval for this Python version.
spec_intervals = reduce(
operator.and_,
map(get_interval_set, specifier_set),
interval(CLOSED, min_ver, above_max_ver, OPEN),
)
target_versions = [
tv
for tv in TargetVersion
if not (spec_intervals & get_interval_set(Specifier(f"==3.{tv.value}.*"))).empty
]
if not target_versions:
return None
else:
return target_versions


@lru_cache()
Expand Down
Loading