Unittest ops + bugfix in Bucketize (#496)
* test_minmix

* updates test

* unittest ops
bschifferer authored Dec 16, 2020
1 parent 5467d86 commit f216edf
Showing 4 changed files with 215 additions and 583 deletions.
3 changes: 1 addition & 2 deletions nvtabular/ops/bucketize.py
@@ -45,6 +45,5 @@ def transform(self, columns, gdf: cudf.DataFrame):
             val = 0
             for boundary in b:
                 val += (gdf[col] >= boundary).astype("int")
-            new_col = f"{col}_{self._id}"
-            new_gdf[new_col] = val
+            new_gdf[col] = val
         return new_gdf
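For reference, here is a minimal standalone sketch of the bucketize logic after this fix: the binned values are now written back under the original column name rather than a "{col}_{id}"-suffixed name. It uses pandas as a stand-in for cudf and a hand-picked boundaries dict purely for illustration; it is not NVTabular's API.

import pandas as pd

# Illustrative per-column bucket boundaries (stand-in for the op's configuration).
boundaries = {"age": [18, 35, 65]}

gdf = pd.DataFrame({"age": [12, 20, 40, 70]})
new_gdf = pd.DataFrame()

for col, b in boundaries.items():
    val = 0
    for boundary in b:
        # Each boundary the value meets or exceeds bumps the bucket index by one.
        val += (gdf[col] >= boundary).astype("int")
    # After the fix, the result keeps the original column name.
    new_gdf[col] = val

print(new_gdf)  # age column becomes bucket indices 0, 1, 2, 3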
17 changes: 8 additions & 9 deletions nvtabular/ops/hashed_cross.py
@@ -25,16 +25,15 @@ def __init__(self, num_buckets):
         self.num_buckets = num_buckets
 
     @annotate("HashedCross_op", color="darkgreen", domain="nvt_python")
-    def op_logic(self, columns, gdf: cudf.DataFrame):
+    def transform(self, columns, gdf: cudf.DataFrame):
         new_gdf = cudf.DataFrame()
-        for cross in columns:
-            val = 0
-            for column in cross:
-                val ^= gdf[column].hash_values()  # or however we want to do this aggregation
-            # TODO: support different size buckets per cross
-            val = val % self.bucket_size
-            new_gdf["_X_".join(cross)] = val
+        val = 0
+        for column in columns:
+            val ^= gdf[column].hash_values()  # or however we want to do this aggregation
+        # TODO: support different size buckets per cross
+        val = val % self.num_buckets
+        new_gdf["_X_".join(columns)] = val
         return new_gdf
 
     def output_column_names(self, columns):
-        return ["_X_".join(cross) for cross in columns]
+        return ["_X_".join(columns)]
8 changes: 4 additions & 4 deletions nvtabular/ops/join_external.py
@@ -88,7 +88,7 @@ def __init__(
         cache="host",
         **kwargs,
     ):
-        super().__init__(replace=False)
+        super(JoinExternal).__init__()
         self.on = on
         self.df_ext = df_ext
         self.on_ext = on_ext or self.on
@@ -155,9 +155,9 @@ def transform(
         return new_gdf
 
     def output_column_names(self, columns):
-        if self.ext_columns:
-            return columns + self.ext_columns
-        return columns + self._ext.columns
+        if self.columns_ext:
+            return list(set(columns + self.columns_ext))
+        return list(set(columns + list(self._ext.columns)))
 
 
 def _detect_format(data):
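A small sketch of what the updated output_column_names computes: a de-duplicated union of the workflow's columns and the external table's columns, so a join key present in both lists is reported only once. The column names here are made up for illustration.

columns = ["user_id", "age"]          # columns already in the workflow (illustrative)
columns_ext = ["user_id", "country"]  # columns brought in by the external join (illustrative)

# list(set(...)) drops the duplicated join key; note that set() does not
# preserve the original column order.
output = list(set(columns + columns_ext))
print(sorted(output))  # ['age', 'country', 'user_id']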