Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic error handling for yaml. #27145

Merged
merged 2 commits into from
Jun 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,9 @@ def expand(self, pcoll):
*self._args,
**self._kwargs).with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
result[self._main_tag].element_type = self._fn.infer_output_type(
pcoll.element_type)

if self._threshold < 1.0:

Expand Down Expand Up @@ -2244,6 +2247,81 @@ def process(self, *args, **kwargs):
traceback.format_exception(*sys.exc_info()))))


# Idea adapted from https://github.com/tosun-si/asgarde.
# TODO(robertwb): Consider how this could fit into the public API.
# TODO(robertwb): Generalize to all PValue types.
class _PValueWithErrors(object):
"""This wraps a PCollection such that transforms can be chained in a linear
manner while still accumulating any errors."""
def __init__(self, pcoll, exception_handling_args, upstream_errors=()):
self._pcoll = pcoll
self._exception_handling_args = exception_handling_args
self._upstream_errors = upstream_errors

def main_output_tag(self):
return self._exception_handling_args.get('main_tag', 'good')

def error_output_tag(self):
return self._exception_handling_args.get('dead_letter_tag', 'bad')

def __or__(self, transform):
return self.apply(transform)

def apply(self, transform):
result = self._pcoll | transform.with_exception_handling(
**self._exception_handling_args)
if result[self.main_output_tag()].element_type == typehints.Any:
result[self.main_output_tag()].element_type = transform.infer_output_type(
self._pcoll.element_type)
# TODO(BEAM-18957): Add support for tagged type hints.
result[self.error_output_tag()].element_type = typehints.Any
return _PValueWithErrors(
result[self.main_output_tag()],
self._exception_handling_args,
self._upstream_errors + (result[self.error_output_tag()], ))

def accumulated_errors(self):
if len(self._upstream_errors) == 1:
return self._upstream_errors[0]
else:
return self._upstream_errors | Flatten()
Comment on lines +2283 to +2287
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to separately handle a case when self._upstream_errors contains one element? Shouldn't Flatten() handle that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but that might have issues with streaming pipeline upgrades. Keeping as is for now as not all runners handle one-element flattens optimally.


def as_result(self, error_post_processing=None):
return {
self.main_output_tag(): self._pcoll,
self.error_output_tag(): self.accumulated_errors()
if error_post_processing is None else self.accumulated_errors()
| error_post_processing,
}


class _MaybePValueWithErrors(object):
"""This is like _PValueWithErrors, but only wraps values if
exception_handling_args is non-trivial. It is useful for handling
error-catching and non-error-catching code in a uniform manner.
"""
def __init__(self, pvalue, exception_handling_args=None):
if isinstance(pvalue, _PValueWithErrors):
assert exception_handling_args is None
self._pvalue = pvalue
elif exception_handling_args is None:
self._pvalue = pvalue
else:
self._pvalue = _PValueWithErrors(pvalue, exception_handling_args)

def __or__(self, transform):
return self.apply(transform)

def apply(self, transform):
return _MaybePValueWithErrors(self._pvalue | transform)

def as_result(self, error_post_processing=None):
if isinstance(self._pvalue, _PValueWithErrors):
return self._pvalue.as_result(error_post_processing)
else:
return self._pvalue


class _SubprocessDoFn(DoFn):
"""Process method run in a subprocess, turning hard crashes into exceptions.
"""
Expand Down Expand Up @@ -3232,14 +3310,21 @@ def __init__(
_expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
] + [(name, _expr_to_callable(expr, name))
for (name, expr) in kwargs.items()]
self._exception_handling_args = None

def with_exception_handling(self, **kwargs):
self._exception_handling_args = kwargs
return self

def default_label(self):
return 'ToRows(%s)' % ', '.join(name for name, _ in self._fields)

def expand(self, pcoll):
return pcoll | Map(
lambda x: pvalue.Row(**{name: expr(x)
for name, expr in self._fields}))
return (
_MaybePValueWithErrors(pcoll, self._exception_handling_args) | Map(
lambda x: pvalue.Row(
**{name: expr(x)
for name, expr in self._fields}))).as_result()

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint.from_fields([
Expand Down Expand Up @@ -3430,6 +3515,9 @@ def process(
new_windows = self.windowing.windowfn.assign(context)
yield WindowedValue(element, context.timestamp, new_windows)

def infer_output_type(self, input_type):
return input_type

def __init__(
self,
windowfn, # type: typing.Union[Windowing, WindowFn]
Expand Down
48 changes: 41 additions & 7 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class _Explode(beam.PTransform):
def __init__(self, fields, cross_product):
self._fields = fields
self._cross_product = cross_product
self._exception_handling_args = None

def expand(self, pcoll):
all_fields = [
Expand All @@ -86,28 +87,53 @@ def explode_zip(base, fields):
copy[field] = values[ix]
yield beam.Row(**copy)

return pcoll | beam.FlatMap(
lambda row: (
explode_cross_product if self._cross_product else explode_zip)(
{name: getattr(row, name) for name in all_fields}, # yapf break
to_explode))
return (
beam.core._MaybePValueWithErrors(
pcoll, self._exception_handling_args)
| beam.FlatMap(
lambda row: (
explode_cross_product if self._cross_product else explode_zip)(
{name: getattr(row, name) for name in all_fields}, # yapf
to_explode))
).as_result()

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint.from_fields([(
name,
trivial_inference.element_type(typ) if name in self._fields else
typ) for (name, typ) in named_fields_from_element_type(input_type)])

def with_exception_handling(self, **kwargs):
# It's possible there's an error in iteration...
self._exception_handling_args = kwargs
return self


# TODO(yaml): Should Filter and Explode be distinct operations from Project?
# We'll want these per-language.
@beam.ptransform.ptransform_fn
def _PythonProjectionTransform(
pcoll, *, fields, keep=None, explode=(), cross_product=True):
pcoll,
*,
fields,
keep=None,
explode=(),
cross_product=True,
error_handling=None):
original_fields = [
name for (name, _) in named_fields_from_element_type(pcoll.element_type)
]

if error_handling is None:
error_handling_args = None
else:
error_handling_args = {
'dead_letter_tag' if k == 'output' else k: v
for (k, v) in error_handling.items()
}

pcoll = beam.core._MaybePValueWithErrors(pcoll, error_handling_args)

if keep:
if isinstance(keep, str) and keep in original_fields:
keep_fn = lambda row: getattr(row, keep)
Expand All @@ -131,7 +157,11 @@ def _PythonProjectionTransform(
else:
result = projected

return result
return result.as_result(
beam.MapTuple(
lambda element,
exc_info: beam.Row(
element=element, msg=str(exc_info[1]), stack=str(exc_info[2]))))


@beam.ptransform.ptransform_fn
Expand All @@ -146,6 +176,7 @@ def MapToFields(
append=False,
drop=(),
language=None,
error_handling=None,
**language_keywords):

if isinstance(explode, str):
Expand Down Expand Up @@ -192,6 +223,8 @@ def MapToFields(
language = "python"

if language in ("sql", "calcite"):
if error_handling:
raise ValueError('Error handling unsupported for sql.')
selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
if keep:
Expand All @@ -215,6 +248,7 @@ def MapToFields(
'keep': keep,
'explode': explode,
'cross_product': cross_product,
'error_handling': error_handling,
},
**language_keywords
})
Expand Down
37 changes: 36 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ def get_pcollection(self, name):
if len(outputs) == 1:
return only_element(outputs.values())
else:
error_output = self._transforms_by_uuid[self.get_transform_id(
name)].get('error_handling', {}).get('output')
if error_output and error_output in outputs and len(outputs) == 2:
return next(
output for tag, output in outputs.items() if tag != error_output)
raise ValueError(
f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
f'{name} has outputs {list(outputs.keys())}')
Expand Down Expand Up @@ -620,6 +625,34 @@ def all_inputs(t):
return dict(spec, transforms=new_transforms)


def ensure_transforms_have_types(spec):
if 'type' not in spec:
raise ValueError(f'Missing type specification in {identify_object(spec)}')
return spec


def ensure_errors_consumed(spec):
if spec['type'] == 'composite':
scope = LightweightScope(spec['transforms'])
to_handle = {}
consumed = set(
scope.get_transform_id_and_output_name(output)
for output in spec['output'].values())
for t in spec['transforms']:
if 'error_handling' in t:
if 'output' not in t['error_handling']:
raise ValueError(
f'Missing output in error_handling of {identify_object(t)}')
to_handle[t['__uuid__'], t['error_handling']['output']] = t
for _, input in t['input'].items():
if input not in spec['input']:
consumed.add(scope.get_transform_id_and_output_name(input))
for error_pcoll, t in to_handle.items():
if error_pcoll not in consumed:
raise ValueError(f'Unconsumed error output for {identify_object(t)}.')
return spec


def preprocess(spec, verbose=False):
if verbose:
pprint.pprint(spec)
Expand All @@ -631,10 +664,12 @@ def apply(phase, spec):
spec, transforms=[apply(phase, t) for t in spec['transforms']])
return spec

for phase in [preprocess_source_sink,
for phase in [ensure_transforms_have_types,
preprocess_source_sink,
preprocess_chain,
normalize_inputs_outputs,
preprocess_flattened_inputs,
ensure_errors_consumed,
preprocess_windowing]:
spec = apply(phase, spec)
if verbose:
Expand Down
Loading
Loading