Skip to content

Commit

Permalink
[core][experimental] Catch errors for DAG API (#46264)
Browse files Browse the repository at this point in the history
Catch unhandled errors in DAG API:
- when Ray actor tasks with multiple returns are used, throw error
- throw error when dag.execute inputs don't match the expected number of
args or kwargs
- supporting binding kwargs to normal Python values
- throw error when user tries to bind DAG nodes to kwargs (not supported
yet)
- When raising RayTaskError, raise it as instance of the original
exception

Closes #46222.

---------

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
  • Loading branch information
stephanie-wang committed Jun 27, 2024
1 parent 1e2599f commit 9a2f2c2
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 42 deletions.
8 changes: 8 additions & 0 deletions python/ray/dag/class_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,11 @@ def _get_actor_handle(self) -> Optional["ray.actor.ActorHandle"]:
if not isinstance(self._parent_class_node, ray.actor.ActorHandle):
return None
return self._parent_class_node

@property
def num_returns(self) -> int:
num_returns = self._bound_options.get("num_returns", None)
if num_returns is None:
method = self._get_remote_method(self._method_name)
num_returns = method.__getstate__()["num_returns"]
return num_returns
129 changes: 114 additions & 15 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _exec_task(self, task: "ExecutableTask", idx: int) -> bool:
resolved_inputs.append(task_input.resolve(res))

try:
output_val = method(*resolved_inputs)
output_val = method(*resolved_inputs, **task.resolved_kwargs)
except Exception as exc:
output_val = _wrap_exception(exc)

Expand Down Expand Up @@ -216,6 +216,10 @@ def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
def args(self) -> Tuple[Any]:
return self.dag_node.get_args()

@property
def kwargs(self) -> Dict[str, Any]:
return self.dag_node.get_kwargs()

@property
def num_readers(self) -> int:
return len(self.downstream_node_idxs)
Expand Down Expand Up @@ -259,7 +263,7 @@ def extract_arg(args_tuple):
return extract_arg

if input_attr_node:
key = input_attr_node.get_other_args_to_resolve()["key"]
key = input_attr_node.key
else:
key = 0
self._adapt_method = extractor(key)
Expand Down Expand Up @@ -316,6 +320,7 @@ def __init__(
self,
task: "CompiledTask",
resolved_args: List[Any],
resolved_kwargs: Dict[str, Any],
):
"""
Args:
Expand All @@ -324,6 +329,9 @@ def __init__(
not Channels will get passed through to the actor method.
If the argument is a channel, it will be replaced by the
value read from the channel before the method executes.
resolved_kwargs: The keyword arguments to the method. Currently, we
do not support binding kwargs to other DAG nodes, so the values
of the dictionary cannot be Channels.
"""
self.method_name = task.dag_node.get_method_name()
self.bind_index = task.dag_node._get_bind_index()
Expand All @@ -333,6 +341,7 @@ def __init__(

self.input_channels: List[ChannelInterface] = []
self.task_inputs: List[_ExecutableTaskInput] = []
self.resolved_kwargs: Dict[str, Any] = resolved_kwargs

# Reverse map for input_channels: maps an input channel to
# its index in input_channels.
Expand Down Expand Up @@ -360,6 +369,11 @@ def __init__(
task_input = _ExecutableTaskInput(arg, None)
self.task_inputs.append(task_input)

# Currently DAGs do not support binding kwargs to other DAG nodes.
for val in self.resolved_kwargs.values():
assert not isinstance(val, ChannelInterface)
assert not isinstance(val, DAGInputAdapter)


@DeveloperAPI
class CompiledDAG:
Expand Down Expand Up @@ -461,6 +475,10 @@ def __init__(
self.input_task_idx: Optional[int] = None
self.output_task_idx: Optional[int] = None
self._has_single_output: bool = False
# Number of expected positional args and kwargs that may be passed to
# dag.execute.
self._input_num_positional_args: Optional[int] = None
self._input_kwargs: Tuple[str, ...] = None
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int)

# Cached attributes that are set during compilation.
Expand Down Expand Up @@ -560,6 +578,14 @@ def _preprocess(self) -> None:
"Compiled DAGs currently require exactly one InputNode"
)

# Whether the DAG binds directly to the InputNode(), versus binding to
# a positional arg or kwarg of the input. For example, a.foo.bind(inp)
# instead of a.foo.bind(inp[0]) or a.foo.bind(inp.key).
direct_input: Optional[bool] = None
# Collect the set of InputNode keys bound to DAG node args.
input_positional_args: Set[int] = set()
input_kwargs: Set[str] = set()

# For each task node, set its upstream and downstream task nodes.
# Also collect the set of tasks that produce torch.tensors.
for node_idx, task in self.idx_to_task.items():
Expand Down Expand Up @@ -587,6 +613,13 @@ def _preprocess(self) -> None:
"Compiled DAGs can only bind methods to an actor "
"that is already created with Actor.remote()"
)

if dag_node.num_returns != 1:
raise ValueError(
"Compiled DAGs only supports actor methods with "
"num_returns=1"
)

self.actor_task_count[actor_handle._actor_id] += 1

if dag_node.type_hint.requires_nccl():
Expand All @@ -604,6 +637,13 @@ def _preprocess(self) -> None:
# with the default type hint for this DAG.
task.dag_node.with_type_hint(self._default_type_hint)

for kwarg, val in task.kwargs.items():
if isinstance(val, DAGNode):
raise ValueError(
"Compiled DAG currently does not support binding to "
"other DAG nodes as kwargs"
)

for arg_idx, arg in enumerate(task.args):
if not isinstance(arg, DAGNode):
continue
Expand All @@ -614,11 +654,40 @@ def _preprocess(self) -> None:
if isinstance(task.dag_node, ClassMethodNode):
downstream_actor_handle = task.dag_node._get_actor_handle()

# If the upstream node is an InputAttributeNode, treat the
# DAG's input node as the actual upstream node
if isinstance(upstream_node.dag_node, InputAttributeNode):
# Record all of the keys used to index the InputNode.
# During execution, we will check that the user provides
# the same args and kwargs.
if isinstance(upstream_node.dag_node.key, int):
input_positional_args.add(upstream_node.dag_node.key)
elif isinstance(upstream_node.dag_node.key, str):
input_kwargs.add(upstream_node.dag_node.key)
else:
raise ValueError(
"InputNode() can only be indexed using int "
"for positional args or str for kwargs."
)

if direct_input is not None and direct_input:
raise ValueError(
"All tasks must either use InputNode() "
"directly, or they must index to specific args or "
"kwargs."
)
direct_input = False

# If the upstream node is an InputAttributeNode, treat the
# DAG's input node as the actual upstream node
upstream_node = self.idx_to_task[self.input_task_idx]

elif isinstance(upstream_node.dag_node, InputNode):
if direct_input is not None and not direct_input:
raise ValueError(
"All tasks must either use InputNode() directly, "
"or they must index to specific args or kwargs."
)
direct_input = True

upstream_node.downstream_node_idxs[node_idx] = downstream_actor_handle
task.arg_type_hints.append(upstream_node.dag_node.type_hint)

Expand Down Expand Up @@ -659,6 +728,14 @@ def _preprocess(self) -> None:
if nccl_actors and self._nccl_group_id is None:
self._nccl_group_id = _init_nccl_group(nccl_actors)

if direct_input:
self._input_num_positional_args = 1
elif not input_positional_args:
self._input_num_positional_args = 0
else:
self._input_num_positional_args = max(input_positional_args) + 1
self._input_kwargs = tuple(input_kwargs)

def _get_or_compile(
self,
) -> None:
Expand Down Expand Up @@ -769,17 +846,6 @@ def _get_or_compile(
for idx in task.downstream_node_idxs:
frontier.append(idx)

from ray.dag.constants import RAY_ADAG_ENABLE_DETECT_DEADLOCK

if RAY_ADAG_ENABLE_DETECT_DEADLOCK and not self._detect_deadlock():
raise ValueError(
"This DAG cannot be compiled because it will deadlock on NCCL "
"calls. If you believe this is a false positive, please disable "
"the graph verification by setting the environment variable "
"RAY_ADAG_ENABLE_DETECT_DEADLOCK to 0 and file an issue at "
"https://github.com/ray-project/ray/issues/new/."
)

# Validate input channels for tasks that have not been visited
for node_idx, task in self.idx_to_task.items():
if (
Expand All @@ -799,6 +865,17 @@ def _get_or_compile(
"or at least one other DAGNode as an input"
)

from ray.dag.constants import RAY_ADAG_ENABLE_DETECT_DEADLOCK

if RAY_ADAG_ENABLE_DETECT_DEADLOCK and not self._detect_deadlock():
raise ValueError(
"This DAG cannot be compiled because it will deadlock on NCCL "
"calls. If you believe this is a false positive, please disable "
"the graph verification by setting the environment variable "
"RAY_ADAG_ENABLE_DETECT_DEADLOCK to 0 and file an issue at "
"https://github.com/ray-project/ray/issues/new/."
)

input_task = self.idx_to_task[self.input_task_idx]
# Register custom serializers for inputs provided to dag.execute().
input_task.dag_node.type_hint.register_custom_serializer()
Expand Down Expand Up @@ -838,6 +915,7 @@ def _get_or_compile(
executable_task = ExecutableTask(
task,
resolved_args,
task.kwargs,
)
executable_tasks.append(executable_task)
if worker_fn is None:
Expand Down Expand Up @@ -1187,13 +1265,33 @@ def execute(

self._get_or_compile()

self._check_inputs(args, kwargs)
inp = (args, kwargs)
self._dag_submitter.write(inp)

ref = CompiledDAGRef(self, self._execution_index)
self._execution_index += 1
return ref

def _check_inputs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
"""
Helper method to check that the DAG args provided by the user during
execution are valid according to the defined DAG.
"""
if len(args) != self._input_num_positional_args:
raise ValueError(
"dag.execute() or dag.execute_async() must be "
f"called with {self._input_num_positional_args} positional args, got "
f"{len(args)}"
)

for kwarg in self._input_kwargs:
if kwarg not in kwargs:
raise ValueError(
"dag.execute() or dag.execute_async() "
f"must be called with kwarg `{kwarg}`"
)

async def execute_async(
self,
*args,
Expand All @@ -1214,6 +1312,7 @@ async def execute_async(
raise ValueError("Use execute if enable_asyncio=False")

self._get_or_compile()
self._check_inputs(args, kwargs)
async with self._dag_submission_lock:
inp = (args, kwargs)
await self._dag_submitter.write(inp)
Expand Down
6 changes: 5 additions & 1 deletion python/ray/dag/input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class InputAttributeNode(DAGNode):
def __init__(
self,
dag_input_node: InputNode,
key: str,
key: Union[int, str],
accessor_method: str,
input_type: str = None,
):
Expand Down Expand Up @@ -291,6 +291,10 @@ def get_result_type(self) -> str:
if "result_type_string" in self._bound_other_args_to_resolve:
return self._bound_other_args_to_resolve["result_type_string"]

@property
def key(self) -> Union[int, str]:
return self._key


@DeveloperAPI
class DAGInputData:
Expand Down
Loading

0 comments on commit 9a2f2c2

Please sign in to comment.