diff --git a/nvtabular/workflow.py b/nvtabular/workflow.py index 8e501171330..adb7c9a4eea 100644 --- a/nvtabular/workflow.py +++ b/nvtabular/workflow.py @@ -24,8 +24,6 @@ from nvtabular.ops import StatOperator from nvtabular.worker import clean_worker_cache -# import yaml - LOG = logging.getLogger("nvtabular") @@ -158,7 +156,8 @@ def _transform_ddf(ddf, column_groups): columns = list(flatten(cg.flattened_columns for cg in column_groups)) return ddf.map_partitions( - lambda gdf: _transform_partition(gdf, column_groups), + _transform_partition, + column_groups, meta=cudf.DataFrame({k: [] for k in columns}), )