Skip to content

Commit

Permalink
Add extra entrypoints setting for user module injection (PrefectHQ#7179)
Browse files Browse the repository at this point in the history
  • Loading branch information
zanieb committed Dec 6, 2022
1 parent 420139b commit 309b1a9
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/prefect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
import prefect.plugins

prefect.plugins.load_prefect_collections()

prefect.plugins.load_extra_entrypoints()

# Configure logging
import prefect.logging.configuration
Expand Down
85 changes: 69 additions & 16 deletions src/prefect/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,23 @@
- prefect.collections: Identifies this package as a Prefect collection that
should be imported when Prefect is imported.
"""

import sys
from types import ModuleType
from typing import Dict, Union
from typing import Any, Dict, Union

import prefect.settings
from prefect.utilities.compat import EntryPoints, entry_points
from prefect.utilities.compat import EntryPoint, EntryPoints, entry_points


def safe_load_entrypoints(group: str) -> Dict[str, Union[Exception, ModuleType]]:
def safe_load_entrypoints(entrypoints: EntryPoints) -> Dict[str, Union[Exception, Any]]:
"""
Load entry points for a group capturing any exceptions that occur.
"""
entrypoints: EntryPoints = entry_points(group=group)

# TODO: `load()` claims to return module types but could return arbitrary types
# too. We can cast the return type if we want to be more correct. We may
# also want to validate the type for the group for entrypoints that have
# a specific type we expect.

results = {}

for entrypoint in entrypoints:
Expand All @@ -34,7 +33,60 @@ def safe_load_entrypoints(group: str) -> Dict[str, Union[Exception, ModuleType]]
except Exception as exc:
result = exc

results[entrypoint.name] = result
results[entrypoint.name or entrypoint.value] = result

return results


def load_extra_entrypoints() -> Dict[str, Union[Exception, Any]]:
# Note: Return values are only exposed for testing.
results = {}

if not prefect.settings.PREFECT_EXTRA_ENTRYPOINTS.value():
return results

values = {
value.strip()
for value in prefect.settings.PREFECT_EXTRA_ENTRYPOINTS.value().split(",")
}

entrypoints = []
for value in values:
try:
entrypoint = EntryPoint(name=None, value=value, group="prefect-extra")
except Exception as exc:
print(
f"Warning! Failed to parse extra entrypoint {value!r}: {type(result).__name__}: {result}",
file=sys.stderr,
)
results[value] = exc
else:
entrypoints.append(entrypoint)

for value, result in zip(
values, safe_load_entrypoints(EntryPoints(entrypoints)).values()
):
results[value] = result

if isinstance(result, Exception):
print(
f"Warning! Failed to load extra entrypoint {value!r}: {type(result).__name__}: {result}",
file=sys.stderr,
)
elif callable(result):
try:
results[value] = result()
except Exception as exc:
print(
f"Warning! Failed to run callable entrypoint {value!r}: {type(exc).__name__}: {exc}",
file=sys.stderr,
)
results[value] = exc
else:
if prefect.settings.PREFECT_DEBUG_MODE:
print(
"Loaded extra entrypoint {value!r} successfully.", file=sys.stderr
)

return results

Expand All @@ -44,20 +96,21 @@ def load_prefect_collections() -> Dict[str, ModuleType]:
Load all Prefect collections that define an entrypoint in the group
`prefect.collections`.
"""
collections = safe_load_entrypoints(group="prefect.collections")
collection_entrypoints: EntryPoints = entry_points(group="prefect.collections")
collections = safe_load_entrypoints(collection_entrypoints)

# TODO: Consider the utility of this once we've established this pattern.
# We cannot use a logger here because logging is not yet initialized.
# It would be nice if logging was initialized so we could log failures
# at least.
if prefect.settings.PREFECT_TEST_MODE or prefect.settings.PREFECT_DEBUG_MODE:
for name, result in collections.items():
if isinstance(result, Exception):
print(
# TODO: Use exc_info if we have a logger
f"Failed to load collection {name!r}: {type(result).__name__}: {result}"
)
else:
for name, result in collections.items():
if isinstance(result, Exception):
print(
# TODO: Use exc_info if we have a logger
f"Warning! Failed to load collection {name!r}: {type(result).__name__}: {result}"
)
else:
if prefect.settings.PREFECT_DEBUG_MODE:
print(f"Loaded collection {name!r}.")

return collections
12 changes: 12 additions & 0 deletions src/prefect/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,18 @@ def default_cloud_ui_url(settings, value):
directory may be created automatically when required.
"""

PREFECT_EXTRA_ENTRYPOINTS = Setting(
str,
default="",
)
"""
Modules for Prefect to import when Prefect is imported.
Values should be separated by commas, e.g. `my_module,my_other_module`.
Objects within modules may be specified by a ':' partition, e.g. `my_module:my_object`.
If a callable object is provided, it will be called with no arguments on import.
"""

PREFECT_DEBUG_MODE = Setting(
bool,
default=False,
Expand Down
238 changes: 238 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import importlib
import pathlib
import textwrap

import pytest

from prefect.plugins import load_extra_entrypoints
from prefect.settings import PREFECT_EXTRA_ENTRYPOINTS, temporary_settings
from prefect.testing.utilities import exceptions_equal


@pytest.fixture
def module_fixture(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.syspath_prepend(tmp_path)
(tmp_path / ("test_module_name.py")).write_text(
textwrap.dedent(
"""
import os
from unittest.mock import MagicMock
def returns_test():
return "test"
def returns_foo():
return "foo"
def returns_bar():
return "bar"
def raises_value_error():
raise ValueError("test")
def raises_base_exception():
raise BaseException("test")
def mock_function(*args, **kwargs):
mock = MagicMock()
mock(*args, **kwargs)
return mock
"""
)
)

yield

importlib.invalidate_caches()


@pytest.fixture
def raising_module_fixture(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
monkeypatch.syspath_prepend(tmp_path)
(tmp_path / ("raising_module_name.py")).write_text(
textwrap.dedent(
"""
raise RuntimeError("test")
"""
)
)

yield

importlib.invalidate_caches()


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_imports_module():
with temporary_settings({PREFECT_EXTRA_ENTRYPOINTS: "test_module_name"}):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name"}
assert result["test_module_name"] == importlib.import_module("test_module_name")


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_strips_spaces():
with temporary_settings({PREFECT_EXTRA_ENTRYPOINTS: " test_module_name "}):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name"}
assert result["test_module_name"] == importlib.import_module("test_module_name")


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_unparsable_entrypoint(capsys):
with temporary_settings({PREFECT_EXTRA_ENTRYPOINTS: "foo$bar"}):
result = load_extra_entrypoints()

assert set(result.keys()) == {"foo$bar"}
assert exceptions_equal(
result["foo$bar"], AttributeError("'NoneType' object has no attribute 'group'")
)

_, stderr = capsys.readouterr()
assert (
"Warning! Failed to load extra entrypoint 'foo$bar': "
"AttributeError: 'NoneType' object has no attribute 'group'"
) in stderr


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_callable():
with temporary_settings(
{PREFECT_EXTRA_ENTRYPOINTS: "test_module_name:returns_test"}
):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name:returns_test"}
assert result["test_module_name:returns_test"] == "test"


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_multiple_entrypoints():
with temporary_settings(
{
PREFECT_EXTRA_ENTRYPOINTS: "test_module_name:returns_foo,test_module_name:returns_bar"
}
):
result = load_extra_entrypoints()

assert result == {
"test_module_name:returns_foo": "foo",
"test_module_name:returns_bar": "bar",
}


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_callable_given_no_arguments():
with temporary_settings(
{PREFECT_EXTRA_ENTRYPOINTS: "test_module_name:mock_function"}
):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name:mock_function"}
result["test_module_name:mock_function"].assert_called_once_with()


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_callable_that_raises(capsys):
with temporary_settings(
{PREFECT_EXTRA_ENTRYPOINTS: "test_module_name:raises_value_error"}
):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name:raises_value_error"}
assert exceptions_equal(
result["test_module_name:raises_value_error"], ValueError("test")
)

_, stderr = capsys.readouterr()
assert (
"Warning! Failed to run callable entrypoint "
"'test_module_name:raises_value_error': ValueError: test"
) in stderr


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_callable_that_raises_base_exception():
with temporary_settings(
{PREFECT_EXTRA_ENTRYPOINTS: "test_module_name:raises_base_exception"}
):
with pytest.raises(BaseException, match="test"):
load_extra_entrypoints()


@pytest.mark.usefixtures("raising_module_fixture")
def test_load_extra_entrypoints_error_on_import(capsys):
with temporary_settings({PREFECT_EXTRA_ENTRYPOINTS: "raising_module_name"}):
result = load_extra_entrypoints()

assert set(result.keys()) == {"raising_module_name"}
assert exceptions_equal(result["raising_module_name"], RuntimeError("test"))

_, stderr = capsys.readouterr()
assert (
"Warning! Failed to load extra entrypoint 'raising_module_name': "
"RuntimeError: test"
) in stderr


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_missing_module(capsys):
with temporary_settings({PREFECT_EXTRA_ENTRYPOINTS: "nonexistant_module"}):
result = load_extra_entrypoints()

assert set(result.keys()) == {"nonexistant_module"}
assert exceptions_equal(
result["nonexistant_module"],
ModuleNotFoundError("No module named 'nonexistant_module'"),
)

_, stderr = capsys.readouterr()
assert (
"Warning! Failed to load extra entrypoint 'nonexistant_module': "
"ModuleNotFoundError"
) in stderr


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_missing_submodule(capsys):
with temporary_settings(
{PREFECT_EXTRA_ENTRYPOINTS: "test_module_name.missing_module"}
):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name.missing_module"}
assert exceptions_equal(
result["test_module_name.missing_module"],
ModuleNotFoundError(
"No module named 'test_module_name.missing_module'; "
"'test_module_name' is not a package"
),
)

_, stderr = capsys.readouterr()
assert (
"Warning! Failed to load extra entrypoint 'test_module_name.missing_module': "
"ModuleNotFoundError"
) in stderr


@pytest.mark.usefixtures("module_fixture")
def test_load_extra_entrypoints_missing_attribute(capsys):
with temporary_settings(
{PREFECT_EXTRA_ENTRYPOINTS: "test_module_name:missing_attr"}
):
result = load_extra_entrypoints()

assert set(result.keys()) == {"test_module_name:missing_attr"}
assert exceptions_equal(
result["test_module_name:missing_attr"],
AttributeError("module 'test_module_name' has no attribute 'missing_attr'"),
)

_, stderr = capsys.readouterr()
assert (
"Warning! Failed to load extra entrypoint 'test_module_name:missing_attr': "
"AttributeError"
) in stderr

0 comments on commit 309b1a9

Please sign in to comment.