Skip to content

Commit

Permalink
Add @experimental_param and @deprecated_param decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Aug 1, 2023
1 parent efbc485 commit 8857f68
Show file tree
Hide file tree
Showing 32 changed files with 800 additions and 349 deletions.
277 changes: 255 additions & 22 deletions python_modules/dagster/dagster/_annotations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import inspect
from dataclasses import dataclass
from typing import Callable, Optional, TypeVar, Union, overload
from typing import Callable, Mapping, Optional, TypeVar, Union, overload

from typing_extensions import Annotated, Final, TypeAlias

from dagster import _check as check
from dagster._core.decorator_utils import (
Decoratable,
apply_pre_call_decorator,
get_decorator_target,
is_resource_def,
)
from dagster._utils.backcompat import (
Expand Down Expand Up @@ -164,9 +165,9 @@ def not_deprecated_function():
additional_warn_text=additional_warn_text,
stacklevel=_get_warning_stacklevel(__obj),
)
return apply_pre_call_decorator(__obj, warning_fn)
return apply_pre_call_decorator(__obj, warning_fn) # type: ignore # (pyright bug)
else:
return __obj
return __obj # type: ignore # (pyright bug)


def is_deprecated(obj: Annotatable) -> bool:
Expand All @@ -179,6 +180,113 @@ def get_deprecated_info(obj: Annotatable) -> DeprecatedInfo:
return getattr(target, _DEPRECATED_ATTR_NAME)


# ########################
# ##### DEPRECATED PARAM
# ########################

_DEPRECATED_PARAM_ATTR_NAME: Final[str] = "_deprecated_params"


@overload
def deprecated_param(
__obj: T_Annotatable,
*,
param: str,
breaking_version: str,
additional_warn_text: Optional[str] = ...,
emit_runtime_warning: bool = ...,
) -> T_Annotatable:
...


@overload
def deprecated_param(
__obj: None = ...,
*,
param: str,
breaking_version: str,
additional_warn_text: Optional[str] = ...,
emit_runtime_warning: bool = ...,
) -> Callable[[T_Annotatable], T_Annotatable]:
...


def deprecated_param(
__obj: Optional[T_Annotatable] = None,
*,
param: str,
breaking_version: str,
additional_warn_text: Optional[str] = None,
emit_runtime_warning: bool = True,
) -> T_Annotatable:
"""Mark a parameter of a class initializer or function/method as deprecated. This appends some
metadata to the decorated object that causes the specified argument to be rendered with a
"deprecated" tag and associated warning in the docs.
If `emit_runtime_warning` is True, a warning will also be emitted when the function is called
and a non-None value is passed for the parameter. For consistency between docs and runtime
warnings, this decorator is preferred to manual calls to `deprecation_warning`. Note that the
warning will only be emitted if the value is passed as a keyword argument.
Args:
param (str): The name of the parameter to deprecate.
breaking_version (str): The version at which the deprecated function will be removed.
additional_warn_text (str): Additional text to display after the deprecation warning.
Typically this should suggest a newer API.
emit_runtime_warning (bool): Whether to emit a warning when the function is called.
"""
if __obj is None:
return lambda obj: deprecated_param( # type: ignore
obj,
param=param,
breaking_version=breaking_version,
additional_warn_text=additional_warn_text,
emit_runtime_warning=emit_runtime_warning,
)
else:
check.invariant(
_annotatable_has_param(__obj, param),
f"Attempted to mark undefined parameter `{param}` deprecated.",
)
target = _get_annotation_target(__obj)
if not hasattr(target, _DEPRECATED_PARAM_ATTR_NAME):
setattr(target, _DEPRECATED_PARAM_ATTR_NAME, {})
getattr(target, _DEPRECATED_PARAM_ATTR_NAME)[param] = DeprecatedInfo(
breaking_version=breaking_version,
additional_warn_text=additional_warn_text,
)

if emit_runtime_warning:
condition = lambda *_, **kwargs: kwargs.get(param) is not None
warning_fn = lambda: deprecation_warning(
_get_subject(__obj, param=param),
breaking_version=breaking_version,
additional_warn_text=additional_warn_text,
stacklevel=4,
)
return apply_pre_call_decorator(__obj, warning_fn, condition=condition) # type: ignore # (pyright bug)
else:
return __obj # type: ignore # (pyright bug)


def has_deprecated_params(obj: Annotatable) -> bool:
return hasattr(_get_annotation_target(obj), _DEPRECATED_PARAM_ATTR_NAME)


def get_deprecated_params(obj: Annotatable) -> Mapping[str, DeprecatedInfo]:
return getattr(_get_annotation_target(obj), _DEPRECATED_PARAM_ATTR_NAME)


def is_deprecated_param(obj: Annotatable, param_name: str) -> bool:
target = _get_annotation_target(obj)
return param_name in getattr(target, _DEPRECATED_PARAM_ATTR_NAME, {})


def get_deprecated_param_info(obj: Annotatable, param_name: str) -> DeprecatedInfo:
target = _get_annotation_target(obj)
return getattr(target, _DEPRECATED_PARAM_ATTR_NAME)[param_name]


# ########################
# ##### EXPERIMENTAL
# ########################
Expand Down Expand Up @@ -265,9 +373,9 @@ class MyExperimentalClass:
subject or _get_subject(__obj),
stacklevel=_get_warning_stacklevel(__obj),
)
return apply_pre_call_decorator(__obj, warning_fn)
return apply_pre_call_decorator(__obj, warning_fn) # type: ignore # (pyright bug)
else:
return __obj
return __obj # type: ignore # (pyright bug)


def is_experimental(obj: Annotatable) -> bool:
Expand All @@ -280,6 +388,107 @@ def get_experimental_info(obj: Annotatable) -> ExperimentalInfo:
return getattr(target, _EXPERIMENTAL_ATTR_NAME)


# ########################
# ##### EXPERIMENTAL PARAM
# ########################

_EXPERIMENTAL_PARAM_ATTR_NAME: Final[str] = "_experimental_params"


@overload
def experimental_param(
__obj: T_Annotatable,
*,
param: str,
additional_warn_text: Optional[str] = ...,
emit_runtime_warning: bool = ...,
) -> T_Annotatable:
...


@overload
def experimental_param(
__obj: None = ...,
*,
param: str,
additional_warn_text: Optional[str] = ...,
emit_runtime_warning: bool = ...,
) -> Callable[[T_Annotatable], T_Annotatable]:
...


def experimental_param(
__obj: Optional[T_Annotatable] = None,
*,
param: str,
additional_warn_text: Optional[str] = None,
emit_runtime_warning: bool = True,
) -> Union[T_Annotatable, Callable[[T_Annotatable], T_Annotatable]]:
"""Mark a parameter of a class initializer or function/method as experimental. This appends some
metadata to the decorated object that causes the specified argument to be rendered with an
"experimental" tag and associated warning in the docs.
If `emit_runtime_warning` is True, a warning will also be emitted when the function is called
and a non-None value is passed for the parameter. For consistency between docs and runtime
warnings, this decorator is preferred to manual calls to `experimental_warning`. Note that the
warning will only be emitted if the value is passed as a keyword argument.
Args:
param (str): The name of the parameter to mark experimental.
additional_warn_text (str): Additional text to display after the deprecation warning.
Typically this should suggest a newer API.
emit_runtime_warning (bool): Whether to emit a warning when the function is called.
"""
if __obj is None:
return lambda obj: experimental_param(
obj,
param=param,
additional_warn_text=additional_warn_text,
emit_runtime_warning=emit_runtime_warning,
)
else:
check.invariant(
_annotatable_has_param(__obj, param),
f"Attempted to mark undefined parameter `{param}` experimental.",
)
target = _get_annotation_target(__obj)

if not hasattr(target, _EXPERIMENTAL_PARAM_ATTR_NAME):
setattr(target, _EXPERIMENTAL_PARAM_ATTR_NAME, {})
getattr(target, _EXPERIMENTAL_PARAM_ATTR_NAME)[param] = ExperimentalInfo(
additional_warn_text=additional_warn_text,
)

if emit_runtime_warning:
condition = lambda *_, **kwargs: kwargs.get(param) is not None
warning_fn = lambda: experimental_warning(
_get_subject(__obj, param=param),
additional_warn_text=additional_warn_text,
stacklevel=4,
)
return apply_pre_call_decorator(__obj, warning_fn, condition=condition) # type: ignore # (pyright bug)
else:
return __obj # type: ignore # (pyright bug)


def has_experimental_params(obj: Annotatable) -> bool:
return hasattr(_get_annotation_target(obj), _EXPERIMENTAL_PARAM_ATTR_NAME)


def get_experimental_params(obj: Annotatable) -> Mapping[str, ExperimentalInfo]:
return getattr(_get_annotation_target(obj), _EXPERIMENTAL_PARAM_ATTR_NAME)


def is_experimental_param(obj: Annotatable, param_name: str) -> bool:
target = _get_annotation_target(obj)
return param_name in getattr(target, _EXPERIMENTAL_PARAM_ATTR_NAME, {})


def get_experimental_param_info(obj: Annotatable, param_name: str) -> ExperimentalInfo:
target = _get_annotation_target(obj)
return getattr(target, _EXPERIMENTAL_PARAM_ATTR_NAME)[param_name]


# ########################
# ##### HELPERS
# ########################
Expand All @@ -289,12 +498,24 @@ def copy_annotations(dest: Annotatable, src: Annotatable) -> None:
"""Copy all Dagster annotations from one object to another object."""
dest_target = _get_annotation_target(dest)
src_target = _get_annotation_target(src)
if hasattr(src_target, _DEPRECATED_ATTR_NAME):
setattr(dest_target, _DEPRECATED_ATTR_NAME, getattr(src_target, _DEPRECATED_ATTR_NAME))
if hasattr(src_target, _PUBLIC_ATTR_NAME):
setattr(dest_target, _PUBLIC_ATTR_NAME, getattr(src_target, _PUBLIC_ATTR_NAME))
if hasattr(src_target, _DEPRECATED_ATTR_NAME):
setattr(dest_target, _DEPRECATED_ATTR_NAME, getattr(src_target, _DEPRECATED_ATTR_NAME))
if hasattr(src_target, _DEPRECATED_PARAM_ATTR_NAME):
setattr(
dest_target,
_DEPRECATED_PARAM_ATTR_NAME,
getattr(src_target, _DEPRECATED_PARAM_ATTR_NAME),
)
if hasattr(src_target, _EXPERIMENTAL_ATTR_NAME):
setattr(dest_target, _EXPERIMENTAL_ATTR_NAME, getattr(src_target, _EXPERIMENTAL_ATTR_NAME))
if hasattr(src_target, _EXPERIMENTAL_PARAM_ATTR_NAME):
setattr(
dest_target,
_EXPERIMENTAL_PARAM_ATTR_NAME,
getattr(src_target, _EXPERIMENTAL_PARAM_ATTR_NAME),
)


def _get_annotation_target(obj: Annotatable) -> object:
Expand All @@ -309,25 +530,32 @@ def _get_annotation_target(obj: Annotatable) -> object:
return obj


def _get_subject(obj: Annotatable) -> str:
def _get_subject(obj: Annotatable, param: Optional[str] = None) -> str:
"""Get the string representation of an annotated object that will appear in
annotation-generated warnings about the object.
"""
if isinstance(obj, type):
return f"Class `{obj.__qualname__}`"
elif isinstance(obj, property):
return f"Property `{obj.fget.__qualname__ if obj.fget else obj}`"
# classmethod and staticmethod don't themselves get a `__qualname__` attr until Python 3.10.
elif isinstance(obj, classmethod):
return f"Class method `{_get_annotation_target(obj).__qualname__}`" # type: ignore
elif isinstance(obj, staticmethod):
return f"Static method `{_get_annotation_target(obj).__qualname__}`" # type: ignore
elif inspect.isfunction(obj):
return f"Function `{obj.__qualname__}`"
elif is_resource_def(obj):
return f"Dagster resource `{obj.__qualname__}`" # type: ignore # (bad stubs)
if param:
if isinstance(obj, type):
return f"Parameter `{param}` of initializer `{obj.__qualname__}.__init__`"
else:
fn_subject = _get_subject(obj)
return f"Parameter `{param}` of {fn_subject[:1].lower() + fn_subject[1:]}"
else:
check.failed(f"Unexpected object type: {type(obj)}")
if isinstance(obj, type):
return f"Class `{obj.__qualname__}`"
elif isinstance(obj, property):
return f"Property `{obj.fget.__qualname__ if obj.fget else obj}`"
# classmethod and staticmethod don't themselves get a `__qualname__` attr until Python 3.10.
elif isinstance(obj, classmethod):
return f"Class method `{_get_annotation_target(obj).__qualname__}`" # type: ignore
elif isinstance(obj, staticmethod):
return f"Static method `{_get_annotation_target(obj).__qualname__}`" # type: ignore
elif inspect.isfunction(obj):
return f"Function `{obj.__qualname__}`"
elif is_resource_def(obj):
return f"Dagster resource `{obj.__qualname__}`" # type: ignore # (bad stubs)
else:
check.failed(f"Unexpected object type: {type(obj)}")


def _get_warning_stacklevel(obj: Annotatable):
Expand All @@ -343,3 +571,8 @@ def _get_warning_stacklevel(obj: Annotatable):
return 6
else:
return 4


def _annotatable_has_param(obj: Annotatable, param: str) -> bool:
target_fn = get_decorator_target(obj)
return param in inspect.signature(target_fn).parameters
12 changes: 9 additions & 3 deletions python_modules/dagster/dagster/_core/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,19 +173,25 @@ def get_decorator_target(obj: Decoratable) -> Callable:
return obj


def apply_pre_call_decorator(obj: T_Decoratable, pre_call_fn: Callable[[], None]) -> T_Decoratable:
def apply_pre_call_decorator(
obj: T_Decoratable,
pre_call_fn: Callable[[], None],
condition: Optional[Callable[..., bool]] = None,
) -> T_Decoratable:
target = get_decorator_target(obj)
new_fn = _wrap_with_pre_call_fn(target, pre_call_fn)
new_fn = _wrap_with_pre_call_fn(target, pre_call_fn, condition)
return _update_decoratable(obj, new_fn)


def _wrap_with_pre_call_fn(
fn: T_Callable,
pre_call_fn: Callable[[], None],
condition: Optional[Callable[..., bool]] = None,
) -> T_Callable:
@functools.wraps(fn)
def wrapped_with_pre_call_fn(*args, **kwargs):
pre_call_fn()
if condition is None or condition(*args, **kwargs):
pre_call_fn()
return fn(*args, **kwargs)

return cast(T_Callable, wrapped_with_pre_call_fn)
Expand Down
Loading

0 comments on commit 8857f68

Please sign in to comment.