Data inspect (#521)
* Initial commit in new branch

* Adds unit test

* Updates json output and multihot calculation

* Updates list processing

* Updates test

* Adds cudf issue

* Data inspector ready

* Test works

* Dataset inspect ready - Tests passing

* Moves dataset inspector script

* Initial inspect-datagen test

* Data gen and data inspect work together

* Initial Stats computation as an operator

* Improves, but still has an error

* Removes list support to simplify

* Different Series type for computations

* Cleans up and uses attributes

* Data Stats Operator working

* Tests inspect-datagen working

* Restructures script and addresses review comments

* All Working
Alberto Alvarez authored Jan 26, 2021
1 parent 408a893 commit 3a9940d
Showing 8 changed files with 688 additions and 158 deletions.
1 change: 1 addition & 0 deletions nvtabular/ops/__init__.py
@@ -19,6 +19,7 @@
from .bucketize import Bucketize
from .categorify import Categorify, _get_embedding_order, get_embedding_sizes
from .clip import Clip
+from .data_stats import DataStats
from .difference_lag import DifferenceLag
from .dropna import Dropna
from .fill import FillMedian, FillMissing
107 changes: 107 additions & 0 deletions nvtabular/ops/data_stats.py
@@ -0,0 +1,107 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dask_cudf
import numpy as np
from nvtx import annotate

from .operator import ColumnNames
from .stat_operator import StatOperator


class DataStats(StatOperator):
    def __init__(self):
        super().__init__()
        self.col_names = []
        self.col_types = []
        self.col_dtypes = []
        self.output = {}

    @annotate("DataStats_fit", color="green", domain="nvt_python")
    def fit(self, columns: ColumnNames, ddf: dask_cudf.DataFrame):
        dask_stats = {}

        ddf_dtypes = ddf.head(1)

        # For each column, calculate the stats
        for col in columns:
            dask_stats[col] = {}
            self.col_names.append(col)
            # Get dtype for all
            dtype = ddf_dtypes[col].dtype
            self.col_dtypes.append(dtype)

            # Identify column type
            if np.issubdtype(dtype, np.floating):
                col_type = "conts"
            else:
                col_type = "cats"
            self.col_types.append(col_type)

            # Get cardinality for cats
            if col_type == "cats":
                dask_stats[col]["cardinality"] = ddf[col].nunique()

            # If string, replace each string with its length for the rest of the computations
            if dtype == "object":
                ddf[col] = ddf[col].map_partitions(lambda x: x.str.len(), meta=("x", int))
            # Add list support when cudf supports it:
            # https://github.com/rapidsai/cudf/issues/7157
            # elif col_type == "cat_mh":
            #     ddf[col] = ddf[col].map_partitions(lambda x: x.list.len())

            # Get min, max, and mean
            dask_stats[col]["min"] = ddf[col].min()
            dask_stats[col]["max"] = ddf[col].max()
            dask_stats[col]["mean"] = ddf[col].mean()

            # Get std only for conts
            if col_type == "conts":
                dask_stats[col]["std"] = ddf[col].std()

            # Get percentage of NaNs for all
            dask_stats[col]["per_nan"] = 100 * (1 - ddf[col].count() / len(ddf[col]))

        return dask_stats

    def fit_finalize(self, dask_stats):
        for i, col in enumerate(self.col_names):
            # Add dtype
            dask_stats[col]["dtype"] = str(self.col_dtypes[i])
            # Cast types for yaml
            if isinstance(dask_stats[col]["mean"], np.floating):
                dask_stats[col]["mean"] = dask_stats[col]["mean"].item()
            if isinstance(dask_stats[col]["per_nan"], np.floating):
                dask_stats[col]["per_nan"] = dask_stats[col]["per_nan"].item()
            if self.col_types[i] == "conts":
                if isinstance(dask_stats[col]["std"], np.floating):
                    dask_stats[col]["std"] = dask_stats[col]["std"].item()
            else:
                if isinstance(dask_stats[col]["cardinality"], np.integer):
                    dask_stats[col]["cardinality"] = dask_stats[col]["cardinality"].item()
        self.output = dask_stats

    def save(self):
        return self.output

    def load(self, data):
        self.output = data

    def clear(self):
        self.output = {}

    # transform.__doc__ = Operator.transform.__doc__
    fit.__doc__ = StatOperator.fit.__doc__
    fit_finalize.__doc__ = StatOperator.fit_finalize.__doc__
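
For context, a minimal usage sketch of the new operator, mirroring how nvtabular/tools/dataset_inspector.py (below) drives it. The dataset path and column names are placeholders, not part of this commit:

from nvtabular import Dataset, Workflow
from nvtabular.ops import DataStats

# Placeholder dataset and columns, for illustration only
dataset = Dataset("data.parquet")
features = ["price", "item_id", "click"] >> DataStats()
workflow = Workflow(features)
workflow.fit(dataset)  # runs DataStats.fit over the underlying dask_cudf DataFrame
workflow.save_stats("stats_output.yaml")  # persists the stats cast in fit_finalize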
1 change: 1 addition & 0 deletions nvtabular/tools/__init__.py
@@ -12,3 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
24 changes: 13 additions & 11 deletions nvtabular/tools/data_gen.py
@@ -112,9 +112,7 @@ def create_labels(self, size, labs_rep):
        df = cudf.DataFrame()
        for col in labs_rep:
            dist = col.distro or self.dist
-            ser = dist.create_col(size, dtype=col.dtype, min_val=1, max_val=col.cardinality).ceil()
-            # bring back down to correct representation because of ceil call
-            ser = ser - 1
+            ser = dist.create_col(size, dtype=col.dtype, min_val=0, max_val=col.cardinality).ceil()
            ser.name = col.name
            df = cudf.concat([df, ser], axis=1)
        return df
@@ -165,7 +163,7 @@ def create_df(
    ):
        conts_rep = cols["conts"] if "conts" in cols else None
        cats_rep = cols["cats"] if "cats" in cols else None
-        labs_rep = cols["labs"] if "labs" in cols else None
+        labs_rep = cols["labels"] if "labels" in cols else None
        df = cudf.DataFrame()
        if conts_rep:
            df = cudf.concat([df, self.create_conts(size, conts_rep)], axis=1)
@@ -361,16 +359,17 @@ def __init__(
        self.min_entry_size = min_entry_size
        self.max_entry_size = max_entry_size
        self.avg_entry_size = avg_entry_size
-        self.per_nan = None
+        self.per_nan = per_nan
        self.multi_min = multi_min
        self.multi_max = multi_max
        self.multi_avg = multi_avg


class LabelCol(Col):
-    def __init__(self, name, dtype, cardinality, distro=None):
+    def __init__(self, name, dtype, cardinality, per_nan=None, distro=None):
        super().__init__(name, dtype, distro)
        self.cardinality = cardinality
+        self.per_nan = per_nan


def _get_cols_from_schema(schema, distros=None):
@@ -381,22 +380,23 @@ def _get_cols_from_schema(schema, distros=None):
    Schema example
    num_rows:
    conts:
        col_name:
            dtype:
            min_val:
            max_val:
            mean:
-            standard deviation:
-            % NaNs:
+            std:
+            per_nan:
    cats:
        col_name:
            dtype:
            cardinality:
            min_entry_size:
            max_entry_size:
            avg_entry_size:
-            % NaNs:
+            per_nan:
            multi_min:
            multi_max:
            multi_avg:
@@ -405,11 +405,13 @@ def _get_cols_from_schema(schema, distros=None):
        col_name:
            dtype:
            cardinality:
-            % NaNs:
+            per_nan:
    """
    cols = {}
-    executor = {"conts": ContCol, "cats": CatCol, "labs": LabelCol}
+    executor = {"conts": ContCol, "cats": CatCol, "labels": LabelCol}
    for section, vals in schema.items():
        if section == "num_rows":
            continue
        cols[section] = []
        for col_name, val in vals.items():
            v_dict = {"name": col_name}
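
With the section renamed from "labs" to "labels" and the stats fields renamed to std/per_nan, a hand-written schema now looks roughly like the sketch below. This is illustrative only: the column names and values are made up, and the exact constructor keywords are assumed from the docstring above:

schema = {
    "num_rows": 1000,
    "conts": {
        "price": {
            "dtype": "float32", "min_val": 0.0, "max_val": 1.0,
            "mean": 0.5, "std": 0.1, "per_nan": 0,
        },
    },
    "cats": {
        "item_id": {
            "dtype": "int64", "cardinality": 100, "min_entry_size": 1,
            "max_entry_size": 5, "avg_entry_size": 3, "per_nan": 0,
        },
    },
    "labels": {
        "click": {"dtype": "int64", "cardinality": 2, "per_nan": 0},
    },
}
# _get_cols_from_schema(schema) now builds ContCol / CatCol / LabelCol
# objects keyed under "labels" instead of the old "labs"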
111 changes: 111 additions & 0 deletions nvtabular/tools/dataset_inspector.py
@@ -0,0 +1,111 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json

import fsspec
import numpy as np
import yaml

from nvtabular.ops import DataStats
from nvtabular.workflow import Workflow


# Helper class so json can serialize numpy data types
class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.bool_):
            return bool(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NpEncoder, self).default(obj)


class DatasetInspector:
    """
    Analyzes an existing dataset to extract its statistics.
    """

    def __init__(self, client=None):
        self.client = client

    def inspect(self, dataset, columns_dict, output_file):
        """
        Parameters
        ----------
        dataset: str, list of str, or <dask.dataframe|cudf|pd>.DataFrame
            Dataset path (or list of paths), or a DataFrame. If string,
            should specify a specific file or directory path. If this is a
            directory path, the directory structure must be flat (nested
            directories are not yet supported).
        columns_dict: dict
            Dictionary indicating the different column types
            ("cats", "conts", "labels")
        output_file: str
            Filename to write the output statistics to
        """

        # Get dataset columns
        cats = columns_dict["cats"]
        conts = columns_dict["conts"]
        labels = columns_dict["labels"]

        # Build the Workflow and compute the stats
        features = cats + conts + labels >> DataStats()
        workflow = Workflow(features, client=self.client)
        workflow.fit(dataset)

        # Save stats in a file and read them back
        stats_file = "stats_output.yaml"
        workflow.save_stats(stats_file)
        with open(stats_file) as f:
            output = yaml.safe_load(f)
        output = output[1]["stats"]

        # Dictionary to store collected information
        data = {}
        # Store num_rows
        data["num_rows"] = dataset.num_rows
        # Store cols
        for col_type in ["conts", "cats", "labels"]:
            data[col_type] = {}
            for col in columns_dict[col_type]:
                data[col_type][col] = {}
                data[col_type][col]["dtype"] = output[col]["dtype"]

                if col_type != "conts":
                    data[col_type][col]["cardinality"] = output[col]["cardinality"]

                if col_type == "cats":
                    data[col_type][col]["min_entry_size"] = output[col]["min"]
                    data[col_type][col]["max_entry_size"] = output[col]["max"]
                    data[col_type][col]["avg_entry_size"] = output[col]["mean"]
                elif col_type == "conts":
                    data[col_type][col]["min_val"] = output[col]["min"]
                    data[col_type][col]["max_val"] = output[col]["max"]
                    data[col_type][col]["mean"] = output[col]["mean"]
                    data[col_type][col]["std"] = output[col]["std"]

                data[col_type][col]["per_nan"] = output[col]["per_nan"]

        # Write json file
        with fsspec.open(output_file, "w") as outfile:
            json.dump(data, outfile, cls=NpEncoder)
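
End to end, the inspector can be pointed at a dataset and asked to emit a generator-compatible description. A minimal sketch, with placeholder paths and column lists:

from nvtabular import Dataset
from nvtabular.tools.dataset_inspector import DatasetInspector

# Placeholder columns and path, for illustration only
columns_dict = {"cats": ["item_id"], "conts": ["price"], "labels": ["click"]}
inspector = DatasetInspector()  # optionally pass a dask distributed client
inspector.inspect(Dataset("data.parquet"), columns_dict, "dataset_info.json")
# dataset_info.json now holds num_rows plus, per column: dtype, per_nan,
# cardinality and entry sizes for cats, and min/max/mean/std for conts,
# i.e. the same shape that nvtabular.tools.data_gen consumes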