Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 101 additions & 18 deletions src/fromager/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,28 @@ def _is_blocked_specifier(specifier: SpecifierSet) -> bool:
)


def _format_provenance(sources: set[str]) -> str:
"""Format provenance sources as a human-readable string.

Args:
sources: Set of source file paths or URLs.

Returns:
Comma-separated string of sources, e.g.
``"/path/to/base.txt, /path/to/override.txt"``.
"""
return ", ".join(sorted(sources))


class InvalidConstraintError(ValueError):
    """Raised when a constraint is refused or conflicts with an existing one."""


class Constraints:
def __init__(self) -> None:
    # mapping of canonical names to (requirement, provenance sources)
    # NOTE: Requirement.name is not normalized
    self._data: dict[NormalizedName, tuple[Requirement, set[str]]] = {}

def __iter__(self) -> Generator[NormalizedName, None, None]:
    """Iterate over the canonical names of all stored constraints."""
    yield from self._data
Expand All @@ -46,17 +59,26 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
    """Return the number of stored constraints (one per canonical name)."""
    return len(self._data)

def add_constraint(self, unparsed: str) -> None:
"""Add new constraint, must not conflict with any existing constraints
def add_constraint(self, unparsed: str, *, provenance: str | None = None) -> None:
"""Add new constraint, must not conflict with any existing constraints.

Args:
unparsed: Raw constraint string, e.g. ``"foo>=2.0"``.
provenance: Path or URL of the file that contains this
constraint. Used for provenance tracking in error messages
and merged output.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

.. versionchanged: 0.83.0
.. versionchanged:: 0.83.0
Non-conflicting constraints are now combined. Constraints with
conflicts now raise :exc:`InvalidConstraintError`. Inputs without a
version specifier or with extras or url are also refused.

.. versionchanged:: 0.84.0
Added *provenance* parameter for source file tracking.
"""
req = Requirement(unparsed)
canon_name = canonicalize_name(req.name)
previous = self._data.get(canon_name)
existing = self._data.get(canon_name)

# validator properties: must have a specifier, must not have extras or URL
if req.extras:
Expand All @@ -78,46 +100,106 @@ def add_constraint(self, unparsed: str) -> None:
logger.debug(f"Constraint {req} does not match environment")
return

if previous is not None:
if existing is not None:
previous, prev_sources = existing
prev_blocked = _is_blocked_specifier(previous.specifier)
prev_prov = _format_provenance(prev_sources)
existing_desc = (
f"{previous} from {prev_prov}" if prev_prov else str(previous)
)
new_desc = f"{req} from {provenance}" if provenance else str(req)
if blocked != prev_blocked:
raise InvalidConstraintError(
f"Cannot combine blocked and non-blocked constraints "
f"(existing: {previous}, new: {req})"
f"(existing: {existing_desc}, new: {new_desc})"
)
if not blocked:
logger.debug("combining constraints %s and %s", previous, req)
new_specifier = req.specifier & previous.specifier
if new_specifier.is_unsatisfiable():
raise InvalidConstraintError(
f"Combined specifier '{new_specifier}' is not satisfiable "
f"(existing: {previous}, new: {req})"
f"(existing: {existing_desc}, new: {new_desc})"
)
req.specifier = new_specifier
sources = prev_sources
else:
logger.debug(f"adding constraint {req}")
sources = set()

self._data[canon_name] = req
if provenance is not None:
sources.add(provenance)
self._data[canon_name] = (req, sources)

def load_constraints_file(self, constraints_file: str | pathlib.Path) -> None:
    """Load constraints from a constraints file or URL.

    Each parsed line is passed to :meth:`add_constraint` exactly once,
    with the string form of *constraints_file* recorded as its provenance.

    Args:
        constraints_file: Path or URL of the constraints file to load.
    """
    logger.info("loading constraints from %s", constraints_file)
    # Convert once so every constraint from this file shares the same
    # provenance string.
    file_provenance = str(constraints_file)
    content = requirements_file.parse_requirements_file(constraints_file)
    for line in content:
        self.add_constraint(line, provenance=file_provenance)

def dump_constraints(self, output: typing.TextIO) -> None:
    """Dump combined constraints to a text stream.

    Source files that contributed each constraint are listed as comment
    lines above the constraint line.

    Args:
        output: Writable text stream.

    .. versionchanged:: 0.84.0
        Output now includes provenance comments above each constraint.
    """
    # sort by normalized name, write requirement without markers.
    # They have been evaluated in add_constraint()
    for _name, (req, sources) in sorted(self._data.items()):
        # Sorted so the comment lines are stable across runs.
        for source in sorted(sources):
            output.write(f"# {source}\n")
        output.write(f"{req.name}{req.specifier}\n")

def get_constraint(self, name: str) -> Requirement | None:
    """Return the merged constraint for *name*, or ``None``."""
    constraint_entry = self._data.get(canonicalize_name(name))
    if constraint_entry is not None:
        # Entries are (requirement, sources); callers only get the requirement.
        return constraint_entry[0]
    return None

def get_constraint_with_provenance(
    self, name: str
) -> tuple[Requirement, set[str]] | tuple[None, None]:
    """Return the constraint and its provenance sources for *name*.

    Args:
        name: Package name; it is canonicalized before lookup.

    Returns:
        ``(requirement, source_files)`` if constrained, or
        ``(None, None)`` if the package has no constraints.
        The returned set is a copy.

    .. versionadded:: 0.84.0
    """
    constraint_entry = self._data.get(canonicalize_name(name))
    if constraint_entry is not None:
        req, sources = constraint_entry
        # Copy so callers cannot mutate the internally tracked sources.
        return req, set(sources)
    return None, None

def format_provenance(self, name: str) -> str:
    """Return a human-readable provenance string for *name*.

    Args:
        name: Package name; it is canonicalized before lookup.

    Returns:
        Comma-separated list of source files, e.g.
        ``"/path/to/base.txt, /path/to/override.txt"``,
        or an empty string if the package has no constraints or its
        constraint has no recorded sources.

    .. versionadded:: 0.84.0
    """
    constraint_entry = self._data.get(canonicalize_name(name))
    if constraint_entry is not None:
        return _format_provenance(constraint_entry[1])
    return ""

def allow_prerelease(self, pkg_name: str) -> bool:
"""Return ``True`` if the constraint for *pkg_name* allows prereleases."""
constraint = self.get_constraint(pkg_name)
if constraint:
return bool(constraint.specifier.prereleases)
Expand All @@ -131,6 +213,7 @@ def is_blocked(self, pkg_name: str) -> bool:
return False

def is_satisfied_by(self, pkg_name: str, version: Version) -> bool:
"""Return ``True`` if *version* satisfies the constraint for *pkg_name*."""
constraint = self.get_constraint(pkg_name)
if constraint:
return constraint.specifier.contains(version, prereleases=True)
Expand Down
10 changes: 8 additions & 2 deletions src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,13 @@ def find_all_matching_from_provider(
)
except resolvelib.resolvers.ResolverException as err:
constraint = provider.constraints.get_constraint(req.name)
prov_str = provider.constraints.format_provenance(req.name)
provider_desc = provider.get_provider_description()
original_msg = str(err)
prov_msg = f" (from {prov_str})" if prov_str else ""
raise resolvelib.resolvers.ResolverException(
f"Unable to resolve requirement specifier {req} with constraint {constraint} using {provider_desc}: {original_msg}"
f"Unable to resolve requirement specifier {req} with constraint "
f"{constraint}{prov_msg} using {provider_desc}: {original_msg}"
) from err

# Materialize candidates so we can iterate more than once if filtering
Expand Down Expand Up @@ -689,8 +692,11 @@ def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> boo
if not self.constraints.is_satisfied_by(requirement.name, candidate.version):
if DEBUG_RESOLVER:
c = self.constraints.get_constraint(requirement.name)
prov_str = self.constraints.format_provenance(requirement.name)
prov_msg = f" from {prov_str}" if prov_str else ""
logger.debug(
f"{requirement.name}: skipping {candidate.version} due to constraint {c}"
f"{requirement.name}: skipping {candidate.version} "
f"due to constraint {c}{prov_msg}"
)
return False

Expand Down
136 changes: 136 additions & 0 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,139 @@ def test_non_blocked_then_blocked_raises() -> None:
def test_is_blocked_unknown_package() -> None:
    """A package without any constraint is not reported as blocked."""
    c = Constraints()
    assert not c.is_blocked("unknown")


def test_provenance_single_source() -> None:
    """Provenance tracks the source for a directly added constraint."""
    c = Constraints()
    c.add_constraint("foo>=2.0", provenance="/path/to/base.txt")
    req, sources = c.get_constraint_with_provenance("foo")
    assert req == Requirement("foo>=2.0")
    assert sources == {"/path/to/base.txt"}


def test_provenance_multiple_sources() -> None:
    """Provenance records both files when two files constrain the same package."""
    c = Constraints()
    c.add_constraint("foo>=2.0", provenance="/path/to/base.txt")
    c.add_constraint("foo!=2.1.1", provenance="/path/to/override.txt")
    req, sources = c.get_constraint_with_provenance("foo")
    assert req == Requirement("foo!=2.1.1,>=2.0")
    assert sources == {"/path/to/base.txt", "/path/to/override.txt"}


def test_provenance_same_source_multiple_lines() -> None:
    """Multiple constraints from the same file appear once in the set."""
    c = Constraints()
    c.add_constraint("foo>=2.0", provenance="shared.txt")
    c.add_constraint("foo!=2.1.1", provenance="shared.txt")
    req, sources = c.get_constraint_with_provenance("foo")
    assert req == Requirement("foo!=2.1.1,>=2.0")
    assert sources == {"shared.txt"}


def test_provenance_unknown_package() -> None:
    """Provenance returns (None, None) for unconstrained packages."""
    c = Constraints()
    req, sources = c.get_constraint_with_provenance("nonexistent")
    assert req is None
    assert sources is None


def test_provenance_load_constraints_file(tmp_path: pathlib.Path) -> None:
    """Loading a file records the file path as the provenance source."""
    constraint_file = tmp_path / "constraints-base.txt"
    constraint_file.write_text("egg==1.0\ntorch>=2.0\n")
    c = Constraints()
    c.load_constraints_file(constraint_file)
    _, egg_sources = c.get_constraint_with_provenance("egg")
    _, torch_sources = c.get_constraint_with_provenance("torch")
    # Provenance is stored as the string form of the path that was loaded.
    assert egg_sources == {str(constraint_file)}
    assert torch_sources == {str(constraint_file)}


def test_provenance_load_multiple_files(tmp_path: pathlib.Path) -> None:
    """Loading two files with the same package tracks both sources."""
    base = tmp_path / "base.txt"
    base.write_text("foo>=2.0\nbar==1.0\n")
    override = tmp_path / "override.txt"
    override.write_text("foo!=2.1.1\n")

    c = Constraints()
    c.load_constraints_file(base)
    c.load_constraints_file(override)

    _, foo_sources = c.get_constraint_with_provenance("foo")
    _, bar_sources = c.get_constraint_with_provenance("bar")
    assert foo_sources == {str(base), str(override)}
    assert bar_sources == {str(base)}


def test_provenance_returns_copy() -> None:
    """get_constraint_with_provenance returns a copy of sources."""
    c = Constraints()
    c.add_constraint("foo>=1.0", provenance="a.txt")
    _, sources = c.get_constraint_with_provenance("foo")
    assert sources is not None
    # Mutating the returned set must not leak into the stored provenance.
    sources.add("injected.txt")
    _, sources2 = c.get_constraint_with_provenance("foo")
    assert sources2 is not None
    assert "injected.txt" not in sources2


def test_dump_constraints_multiple_sources() -> None:
    """dump_constraints lists source files as comments above each constraint."""
    c = Constraints()
    c.add_constraint("foo>=2.0", provenance="/path/to/base.txt")
    c.add_constraint("foo!=2.1.1", provenance="/path/to/override.txt")
    c.add_constraint("bar==1.0", provenance="/path/to/base.txt")

    out = io.StringIO()
    c.dump_constraints(out)
    result = out.getvalue()

    assert "# /path/to/base.txt\nbar==1.0\n" in result
    assert "# /path/to/base.txt\n# /path/to/override.txt\nfoo!=2.1.1,>=2.0\n" in result


def test_conflict_error_includes_provenance() -> None:
    """InvalidConstraintError message includes source file provenance."""
    c = Constraints()
    c.add_constraint("foo>=2.0", provenance="/constraints/base.txt")
    # Two lookaheads so the match is independent of the order in which the
    # existing and new sources appear in the message.
    with pytest.raises(
        InvalidConstraintError,
        match=r"(?=.*base\.txt)(?=.*override\.txt)",
    ):
        c.add_constraint("foo<1.0", provenance="/constraints/override.txt")


def test_conflict_error_without_provenance() -> None:
    """Error messages omit 'from' clause when provenance is None."""
    c = Constraints()
    c.add_constraint("foo>=2.0")
    with pytest.raises(InvalidConstraintError, match=r"existing: foo>=2\.0,") as exc:
        c.add_constraint("foo<1.0")
    # These checks sit outside the ``with`` block on purpose: statements
    # placed after the raising call inside the block would never execute.
    assert "from None" not in str(exc.value)
    assert "from ," not in str(exc.value)


def test_add_constraint_without_provenance() -> None:
    """Constraints added without provenance work and don't pollute tracking."""
    c = Constraints()
    c.add_constraint("foo>=1.0")
    req, sources = c.get_constraint_with_provenance("foo")
    assert req == Requirement("foo>=1.0")
    assert sources == set()
    assert c.format_provenance("foo") == ""


def test_format_provenance() -> None:
    """format_provenance returns a sorted comma-separated string of sources."""
    c = Constraints()
    assert c.format_provenance("foo") == ""

    c.add_constraint("foo>=2.0", provenance="/path/to/base.txt")
    assert c.format_provenance("foo") == "/path/to/base.txt"

    c.add_constraint("foo!=2.1.1", provenance="/path/to/override.txt")
    assert c.format_provenance("foo") == "/path/to/base.txt, /path/to/override.txt"
4 changes: 4 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_pip_constraints_args(tmp_path: pathlib.Path) -> None:
"# auto-generated constraints file",
f"# {constraints_file}",
"",
f"# {constraints_file}",
"test==1.0",
"",
)
Expand Down Expand Up @@ -87,7 +88,10 @@ def test_multiple_constraints_files(tmp_path: pathlib.Path) -> None:
f"# {constraints2}",
f"# {constraints3}",
"",
f"# {constraints2}",
f"# {constraints3}",
"foo!=2.1.1,>=2.0",
f"# {constraints1}",
"test==1.0",
"",
)
Expand Down
Loading