Skip to content

Replace pkg resources #3772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import shutil
import subprocess
import sys
from importlib.metadata import (
PackageNotFoundError,
distribution,
distributions,
)
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand All @@ -40,7 +45,6 @@
)

import click
import pkg_resources
import yaml
from pydantic import BaseModel, SecretStr
from rich import box, table
Expand Down Expand Up @@ -1120,9 +1124,9 @@ def is_installed_in_python_environment(package: str) -> bool:
True if the package is installed, False otherwise.
"""
try:
pkg_resources.get_distribution(package)
distribution(package)
return True
except pkg_resources.DistributionNotFound:
except PackageNotFoundError:
return False


Expand Down Expand Up @@ -2511,16 +2515,18 @@ def get_package_information(
A dictionary of the name:version for the package names passed in or
all packages and their respective versions.
"""
import pkg_resources
all_packages = {
dist.metadata["name"].lower(): dist.version for dist in distributions()
}

if package_names:
return {
pkg.key: pkg.version
for pkg in pkg_resources.working_set
if pkg.key in package_names
name: version
for name, version in all_packages.items()
if name in package_names
}

return {pkg.key: pkg.version for pkg in pkg_resources.working_set}
return all_packages


def print_user_info(info: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -2618,9 +2624,9 @@ def is_jupyter_installed() -> bool:
bool: True if Jupyter notebook is installed, False otherwise.
"""
try:
pkg_resources.get_distribution("notebook")
distribution("notebook")
return True
except pkg_resources.DistributionNotFound:
except PackageNotFoundError:
return False


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

import os
import sys
from importlib.metadata import distribution
from typing import Any, List, Set

import pkg_resources

from zenml.entrypoints.step_entrypoint_configuration import (
StepEntrypointConfiguration,
)
Expand Down Expand Up @@ -81,17 +80,19 @@ def run(self) -> None:
"""Runs the step."""
# Get the wheel package and add it to the sys path
wheel_package = self.entrypoint_args[WHEEL_PACKAGE_OPTION]
distribution = pkg_resources.get_distribution(wheel_package)
project_root = os.path.join(distribution.location, wheel_package)
dist = distribution(wheel_package)
# Get the location from distribution files (first file's parent directory)
location = str(dist.locate_file("")).parent if dist.files else ""
project_root = os.path.join(location, wheel_package)
if project_root not in sys.path:
sys.path.insert(0, project_root)
sys.path.insert(-1, project_root)

# Get the job id and add it to the environment
databricks_job_id = self.entrypoint_args[DATABRICKS_JOB_ID_OPTION]
os.environ[ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID] = (
databricks_job_id
)
os.environ[
ENV_ZENML_DATABRICKS_ORCHESTRATOR_RUN_ID
] = databricks_job_id

# Run the step
super().run()
71 changes: 49 additions & 22 deletions src/zenml/integrations/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
"""Base and meta classes for ZenML integrations."""

import re
from importlib.metadata import PackageNotFoundError, distribution
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast

import pkg_resources
from pkg_resources import Requirement
from packaging.requirements import Requirement

from zenml.integrations.registry import integration_registry
from zenml.logger import get_logger
Expand Down Expand Up @@ -72,57 +72,84 @@ def check_installation(cls) -> bool:
for r in cls.get_requirements():
try:
# First check if the base package is installed
dist = pkg_resources.get_distribution(r)
package_name, extras = parse_requirement(r)
dist = distribution(package_name)

# Next, check if the dependencies (including extras) are
# installed
deps: List[Requirement] = []
deps: List[str] = []

_, extras = parse_requirement(r)
if extras:
extra_list = extras[1:-1].split(",")
for extra in extra_list:
try:
requirements = dist.requires(extras=[extra]) # type: ignore[arg-type]
except pkg_resources.UnknownExtra as e:
logger.debug(f"Unknown extra: {str(e)}")
# Get requires for specific extra
if (
dist.requires
and extra
in dist.metadata.get_all("Provides-Extra", [])
):
extra_deps = [
req
for req in dist.requires
if f'extra == "{extra}"' in req
]
deps.extend(extra_deps)
else:
logger.debug(f"Unknown extra: {extra}")
return False
except Exception as e:
logger.debug(
f"Error processing extra {extra}: {str(e)}"
)
return False
deps.extend(requirements)
else:
deps = dist.requires()
deps = list(dist.requires or [])

for ri in deps:
try:
# Remove the "extra == ..." part from the requirement string
cleaned_req = re.sub(
r"; extra == \"\w+\"", "", str(ri)
)
pkg_resources.get_distribution(cleaned_req)
except pkg_resources.DistributionNotFound as e:
req_obj = Requirement(cleaned_req)
dep_dist = distribution(req_obj.name)

# Check version compatibility
if (
req_obj.specifier
and not req_obj.specifier.contains(
dep_dist.version
)
):
logger.debug(
f"Package version '{dep_dist.version}' does not match "
f"version '{req_obj.specifier}' required by '{r}' "
f"necessary for integration '{cls.NAME}'."
)
return False
except PackageNotFoundError:
logger.debug(
f"Unable to find required dependency "
f"'{e.req}' for requirement '{r}' "
f"'{cleaned_req}' for requirement '{r}' "
f"necessary for integration '{cls.NAME}'."
)
return False
except pkg_resources.VersionConflict as e:
except Exception as e:
logger.debug(
f"Package version '{e.dist}' does not match "
f"version '{e.req}' required by '{r}' "
f"necessary for integration '{cls.NAME}'."
f"Error checking dependency '{cleaned_req}': {str(e)}"
)
return False

except pkg_resources.DistributionNotFound as e:
except PackageNotFoundError:
logger.debug(
f"Unable to find required package '{e.req}' for "
f"Unable to find required package '{package_name}' for "
f"integration {cls.NAME}."
)
return False
except pkg_resources.VersionConflict as e:
except Exception as e:
logger.debug(
f"Package version '{e.dist}' does not match version "
f"'{e.req}' necessary for integration {cls.NAME}."
f"Error checking package '{package_name}': {str(e)}"
)
return False

Expand Down
Loading