diff --git a/sdks/python/apache_beam/yaml/readme_test.py b/sdks/python/apache_beam/yaml/readme_test.py
index efce92490a359..555d1d0b583f2 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -55,6 +55,8 @@ def expand(self, inputs):
 
     def guess_name_and_type(expr):
       expr = expr.strip().replace('`', '')
+      if expr.endswith('*'):
+        return 'unknown', str
       parts = expr.split()
       if len(parts) >= 2 and parts[-2].lower() == 'as':
         name = parts[-1]
@@ -87,7 +89,7 @@ def guess_name_and_type(expr):
       return name, typ
 
     if m.group(1) == '*':
-      return inputs['PCOLLECTION'] | beam.Filter(lambda _: True)
+      return next(iter(inputs.values())) | beam.Filter(lambda _: True)
     else:
       output_schema = [
           guess_name_and_type(expr) for expr in m.group(1).split(',')
@@ -269,6 +271,26 @@ def test(self):
 
 def parse_test_methods(markdown_lines):
   # pylint: disable=too-many-nested-blocks
+
+  def extract_inputs(input_spec):
+    """Returns the set of names referenced by a str, list or dict input spec."""
+    if not input_spec:
+      return set()
+    elif isinstance(input_spec, str):
+      # Keep only the part before the first '.'.
+      return set([input_spec.split('.')[0]])
+    elif isinstance(input_spec, list):
+      return set.union(*[extract_inputs(v) for v in input_spec])
+    elif isinstance(input_spec, dict):
+      return set.union(*[extract_inputs(v) for v in input_spec.values()])
+    else:
+      # input_spec is never a str here, so it cannot be concatenated directly.
+      raise ValueError("Misformed inputs: " + repr(input_spec))
+
+  def extract_name(input_spec):
+    """Returns the spec's 'name' entry, falling back to its 'type' entry."""
+    return input_spec.get('name', input_spec.get('type'))
+
   code_lines = None
   for ix, line in enumerate(markdown_lines):
     line = line.rstrip()
@@ -280,17 +302,23 @@ def parse_test_methods(markdown_lines):
       else:
         if code_lines:
           if code_lines[0].startswith('- type:'):
-            is_chain = not any('input:' in line for line in code_lines)
+            specs = yaml.load('\n'.join(code_lines), Loader=SafeLoader)
+            is_chain = not any('input' in spec for spec in specs)
+            if is_chain:
+              undefined_inputs = set(['input'])
+            else:
+              undefined_inputs = set.union(
+                  *[extract_inputs(spec.get('input')) for spec in specs]) - set(
+                      extract_name(spec) for spec in specs)
             # Treat this as a fragment of a larger pipeline.
             # pylint: disable=not-an-iterable
             code_lines = [
                 'pipeline:',
                 '  type: chain' if is_chain else '',
                 '  transforms:',
-                '    - type: ReadFromCsv',
-                '      name: input',
-                '      config:',
-                '        path: whatever',
+            ] + [
+                '    - {type: ReadFromCsv, name: "%s", config: {path: x}}' %
+                undefined_input for undefined_input in undefined_inputs
             ] + ['    ' + line for line in code_lines]
           if code_lines[0] == 'pipeline:':
             yaml_pipeline = '\n'.join(code_lines)
@@ -329,6 +357,9 @@ def createTestSuite(name, path):
 InlinePythonTest = createTestSuite(
     'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md'))
 
+JoinTest = createTestSuite(
+    'JoinTest', os.path.join(YAML_DOCS_DIR, 'yaml-join.md'))
+
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument('--render_dir', default=None)
diff --git a/sdks/python/apache_beam/yaml/yaml_join.py b/sdks/python/apache_beam/yaml/yaml_join.py
index 04a24642c2317..5124ef56b49c1 100644
--- a/sdks/python/apache_beam/yaml/yaml_join.py
+++ b/sdks/python/apache_beam/yaml/yaml_join.py
@@ -173,8 +173,9 @@ def _is_connected(edge_list, expected_node_count):
 def _SqlJoinTransform(
     pcolls,
     sql_transform_constructor,
-    type: Union[str, Dict[str, List]],
+    *,
     equalities: Union[str, List[Dict[str, str]]],
+    type: Union[str, Dict[str, List]] = 'inner',
     fields: Optional[Dict[str, Any]] = None):
   """Joins two or more inputs using a specified condition.
 
diff --git a/website/www/site/content/en/documentation/sdks/yaml-join.md b/website/www/site/content/en/documentation/sdks/yaml-join.md
index d207926ff995b..6645d7a945a34 100644
--- a/website/www/site/content/en/documentation/sdks/yaml-join.md
+++ b/website/www/site/content/en/documentation/sdks/yaml-join.md
@@ -48,7 +48,7 @@ inputs, one can use the following shorthand syntax:
     input2: Second Input
     input3: Third Input
   config:
-    equalities: col
+    equalities: col1
 ```
 
 ## Join Types