diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py
index b3c595cf..89f1e761 100644
--- a/nemo_curator/datasets/doc_dataset.py
+++ b/nemo_curator/datasets/doc_dataset.py
@@ -43,6 +43,8 @@ def read_json(
         files_per_partition: int = 1,
         add_filename: bool = False,
         input_meta: Union[str, dict] = None,
+        columns: Optional[List[str]] = None,
+        **kwargs,
     ):
         return cls(
             _read_json_or_parquet(
@@ -52,6 +54,8 @@ def read_json(
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
                 input_meta=input_meta,
+                columns=columns,
+                **kwargs,
             )
         )
 
@@ -62,6 +66,8 @@ def read_parquet(
         backend="pandas",
         files_per_partition=1,
         add_filename=False,
+        columns: Optional[List[str]] = None,
+        **kwargs,
     ):
         return cls(
             _read_json_or_parquet(
@@ -70,6 +76,8 @@ def read_parquet(
                 backend=backend,
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
+                columns=columns,
+                **kwargs,
             )
         )
 
@@ -80,6 +88,8 @@ def read_pickle(
         backend="pandas",
         files_per_partition=1,
         add_filename=False,
+        columns: Optional[List[str]] = None,
+        **kwargs,
     ):
         return cls(
             read_data(
@@ -88,6 +98,8 @@ def read_pickle(
                 backend=backend,
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
+                columns=columns,
+                **kwargs,
             )
         )
 
@@ -175,6 +187,8 @@ def _read_json_or_parquet(
     files_per_partition: int,
     add_filename: bool,
     input_meta: Union[str, dict] = None,
+    columns: Optional[List[str]] = None,
+    **kwargs,
 ):
     """
     `input_files` may be a list or a string type.
@@ -205,6 +219,8 @@ def _read_json_or_parquet(
                 files_per_partition=files_per_partition,
                 add_filename=add_filename,
                 input_meta=input_meta,
+                columns=columns,
+                **kwargs,
             )
 
         # List of directories
@@ -222,6 +238,8 @@ def _read_json_or_parquet(
                     files_per_partition=files_per_partition,
                     add_filename=add_filename,
                     input_meta=input_meta,
+                    columns=columns,
+                    **kwargs,
                 )
                 dfs.append(df)
 
@@ -245,6 +263,8 @@ def _read_json_or_parquet(
             files_per_partition=files_per_partition,
             add_filename=add_filename,
             input_meta=input_meta,
+            columns=columns,
+            **kwargs,
         )
 
     else:
diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py
index c41b5ea9..8f36139b 100644
--- a/nemo_curator/utils/distributed_utils.py
+++ b/nemo_curator/utils/distributed_utils.py
@@ -260,6 +260,8 @@ def read_single_partition(
     filetype="jsonl",
     add_filename=False,
     input_meta: Union[str, dict] = None,
+    columns: Optional[List[str]] = None,
+    **kwargs,
 ) -> Union[cudf.DataFrame, pd.DataFrame]:
     """
     This function reads a file with cuDF, sorts the columns of the DataFrame
@@ -271,6 +273,7 @@ def read_single_partition(
         add_filename: Whether to add a "filename" column to the DataFrame.
         input_meta: A dictionary or a string formatted as a dictionary, which outlines
             the field names and their respective data types within the JSONL input file.
+        columns: If not None, only these columns will be read from the file.
 
     Returns:
         A cudf DataFrame or a pandas DataFrame.
@@ -296,12 +299,14 @@ def read_single_partition(
             read_kwargs["dtype"] = (
                 ast.literal_eval(input_meta) if type(input_meta) == str else input_meta
             )
+
     elif filetype == "parquet":
-        read_kwargs = {}
+        read_kwargs = {"columns": columns}
         if backend == "cudf":
             read_f = cudf.read_parquet
         else:
             read_f = pd.read_parquet
+
     else:
         raise RuntimeError("Could not read data, please check file type")
 
@@ -312,7 +317,7 @@ def read_single_partition(
         # cuDF supports reading multiple files at once
         read_files_one_at_a_time = False
     else:
-        # pandas does not support reading multiple files at once
+        # Pandas does not support reading multiple files at once
         read_files_one_at_a_time = True
 
     if read_files_one_at_a_time:
@@ -322,31 +327,43 @@ def read_single_partition(
             concat_f = pd.concat
         df_ls = []
         for file in files:
-            df = read_f(file, **read_kwargs)
+            df = read_f(file, **read_kwargs, **kwargs)
             if add_filename:
                 df["filename"] = os.path.basename(file)
             df_ls.append(df)
         df = concat_f(df_ls, ignore_index=True)
     else:
-        df = read_f(files, **read_kwargs)
+        df = read_f(files, **read_kwargs, **kwargs)
+
+    if filetype in ["jsonl", "json"] and columns is not None:
+        if add_filename and "filename" not in columns:
+            columns.append("filename")
+        df = df[columns]
+
     df = df[sorted(df.columns)]
     return df
 
 
-def read_pandas_pickle(file, add_filename=False) -> pd.DataFrame:
+def read_pandas_pickle(
+    file, add_filename=False, columns=None, **kwargs
+) -> pd.DataFrame:
     """
-    This function reads a pickle file with pandas and adds a "filename" column.
+    This function reads a pickle file with Pandas.
 
     Args:
         file: The path to the pickle file to read.
         add_filename: Whether to add a "filename" column to the DataFrame.
 
     Returns:
-        A pandas DataFrame.
+        A Pandas DataFrame.
     """
     if add_filename:
         warnings.warn("add_filename is not supported for pickle files")
-    return pd.read_pickle(file)
+
+    if columns is not None:
+        return pd.read_pickle(file, **kwargs)[columns]
+    else:
+        return pd.read_pickle(file, **kwargs)
 
 
 def read_data(
@@ -356,6 +373,8 @@ def read_data(
     files_per_partition: int = 1,
     add_filename: bool = False,
     input_meta: Union[str, dict] = None,
+    columns: Optional[List[str]] = None,
+    **kwargs,
 ) -> Union[dd.DataFrame, dask_cudf.DataFrame]:
     """
     This function can read multiple data formats and returns a Dask-cuDF DataFrame.
@@ -368,6 +387,7 @@ def read_data(
         add_filename: Whether to add a "filename" column to the DataFrame.
         input_meta: A dictionary or a string formatted as a dictionary, which outlines
             the field names and their respective data types within the JSONL input file.
+        columns: If not None, only these columns will be read from the file.
 
     Returns:
         A Dask-cuDF or a Dask-pandas DataFrame.
@@ -378,7 +398,9 @@ def read_data(
         test_obj = cudf.Series
 
     if file_type == "pickle":
-        df = read_pandas_pickle(input_files[0], add_filename=add_filename)
+        df = read_pandas_pickle(
+            input_files[0], add_filename=add_filename, columns=columns, **kwargs
+        )
         df = dd.from_pandas(df, npartitions=16)
         if backend == "cudf":
             df = df.to_backend("cudf")
@@ -401,6 +423,8 @@ def read_data(
             add_filename=add_filename,
             input_meta=input_meta,
             enforce_metadata=False,
+            columns=columns,
+            **kwargs,
         )
     else:
         raise RuntimeError("Could not read data, please check file type")
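
For context, a minimal usage sketch of the new `columns` argument added by this diff. It assumes the `DocumentDataset` class exported from `nemo_curator.datasets` (the module patched above); the input paths and column names are placeholders, and the snippet is illustrative only, not part of the change.

from nemo_curator.datasets import DocumentDataset

# Project only the columns that are needed. For Parquet, the list is passed
# to the reader's native column selection; for JSONL, the DataFrame is
# subset to these columns after reading (plus "filename" if add_filename=True).
dataset = DocumentDataset.read_json(
    "/path/to/jsonl_dir",        # placeholder input path
    backend="pandas",
    columns=["id", "text"],      # placeholder column names
)

# Any additional keyword arguments are forwarded to the underlying
# cudf/pandas read function via **kwargs.
parquet_dataset = DocumentDataset.read_parquet(
    "/path/to/parquet_dir",      # placeholder input path
    backend="pandas",
    columns=["id", "text"],
)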