Merge branch 'main' into r0.5.0

ryantwolf committed Sep 20, 2024
2 parents 5bcfd6e + 4197d02 commit 30b0b3d
Showing 6 changed files with 1,075 additions and 22 deletions.
README.md (22 changes: 21 additions & 1 deletion)

@@ -83,7 +83,12 @@ Before installing NeMo Curator, ensure that the following requirements are met:
- Volta™ or higher ([compute capability 7.0+](https://developer.nvidia.com/cuda-gpus))
- CUDA 12 (or above)

You can install NeMo-Curator from PyPI, from source or get it through the NeMo Framework container.
You can install NeMo-Curator:
1. from PyPI
2. from source
3. through the [NeMo Framework container](https://github.com/NVIDIA/NeMo?tab=readme-ov-file#docker-containers)



#### From PyPI

@@ -126,6 +131,21 @@ pip install --extra-index-url https://pypi.nvidia.com nemo-curator[cuda12x]
pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"
```

#### Using Nightly Dependencies for RAPIDS

You can also install NeMo Curator using the RAPIDS nightly builds. To do so, set the environment variable `RAPIDS_NIGHTLY=1`.


```bash
# Installing from PyPI
RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple "nemo-curator[cuda12x]"
# Installing from source
RAPIDS_NIGHTLY=1 pip install --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple ".[cuda12x]"
```

When the environment variable is set to 0 or is not set (the default behavior), the stable version of RAPIDS is used.
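
Given the `setup.py` logic in this commit, you can also request the stable behavior explicitly; for example, when installing from source:

```bash
# RAPIDS_NIGHTLY=0 is parsed by strtobool() as False, selecting the stable
# requirements_cuda12x.txt pins.
RAPIDS_NIGHTLY=0 pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"
```
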
#### From the NeMo Framework Container
The latest release of NeMo Curator comes preinstalled in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags). If you want the latest commit inside the container, you can reinstall NeMo Curator using:
nemo_curator/modules/semantic_dedup.py (42 changes: 31 additions & 11 deletions)

@@ -39,7 +39,7 @@
from nemo_curator.utils.distributed_utils import write_to_disk
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
from nemo_curator.utils.semdedup_utils import (
    _assign_and_sort_clusters,
    assign_and_sort_clusters,
    extract_dedup_data,
    get_semantic_matches_per_cluster,
)
@@ -123,6 +123,7 @@ def __init__(
        embedding_batch_size: int,
        embedding_output_dir: str,
        input_column: str = "text",
        embedding_column: str = "embeddings",
        write_embeddings_to_disk: bool = True,
        write_to_filename: bool = False,
        logger: Union[logging.Logger, str] = "./",
@@ -160,6 +161,7 @@ def __init__(
        self.logger = self._setup_logger(logger)
        self.embedding_output_dir = embedding_output_dir
        self.input_column = input_column
        self.embedding_column = embedding_column
        self.model = EmbeddingCrossFitModel(self.embeddings_config)
        self.write_embeddings_to_disk = write_embeddings_to_disk
        self.write_to_filename = write_to_filename
@@ -190,7 +192,7 @@ def create_embeddings(
                self.model,
                sorted_data_loader=True,
                batch_size=self.batch_size,
                pred_output_col="embeddings",
                pred_output_col=self.embedding_column,
            ),
            keep_cols=ddf.columns.tolist(),
        )
@@ -215,12 +217,14 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:


### Clustering Module
def get_embedding_ar(df: "cudf.DataFrame") -> cp.ndarray:
    return df["embeddings"].list.leaves.values.reshape(len(df), -1)
def get_embedding_ar(df: "cudf.DataFrame", embedding_col: str) -> cp.ndarray:
    return df[embedding_col].list.leaves.values.reshape(len(df), -1)


def add_dist_to_cents(df: "cudf.DataFrame", centroids: cp.ndarray) -> "cudf.DataFrame":
    embed_array = get_embedding_ar(df)
def add_dist_to_cents(
    df: "cudf.DataFrame", embedding_col: str, centroids: cp.ndarray
) -> "cudf.DataFrame":
    embed_array = get_embedding_ar(df, embedding_col)
    centroids_ar = centroids[df["nearest_cent"].values]
    dist_to_cents = cp.sqrt(np.sum((embed_array - centroids_ar) ** 2, axis=1))
    df["dist_to_cent"] = dist_to_cents
@@ -234,6 +238,7 @@ def __init__(
        max_iter: int,
        n_clusters: int,
        clustering_output_dir: str,
        embedding_col: str = "embeddings",
        sim_metric: str = "cosine",
        which_to_keep: str = "hard",
        sort_clusters: bool = True,
@@ -249,6 +254,7 @@ def __init__(
            max_iter (int): Maximum number of iterations for the clustering algorithm.
            n_clusters (int): The number of clusters to form.
            clustering_output_dir (str): Directory path where clustering results will be saved.
            embedding_col (str): Column name where the embeddings are stored.
            sim_metric (str): Similarity metric to use for clustering, default is "cosine".
            which_to_keep (str): Strategy to decide which duplicates to keep; default is "hard".
            sort_clusters (bool): Whether to sort clusters, default is True.
@@ -262,6 +268,7 @@ def __init__(
        self.max_iter = max_iter
        self.n_clusters = n_clusters
        self.clustering_output_dir = clustering_output_dir
        self.embedding_col = embedding_col
        self.sim_metric = sim_metric
        self.keep_hard = which_to_keep == "hard"
        self.kmeans_with_cos_dist = kmeans_with_cos_dist
@@ -291,15 +298,20 @@ def _setup_logger(self, logger):
    def __call__(self, embeddings_dataset: DocumentDataset):
        embeddings_df = embeddings_dataset.df

assert "embeddings" in embeddings_df.columns
embeddings_df = embeddings_df[[self.id_col, "embeddings"]]
if self.embedding_col not in embeddings_df.columns:
raise ValueError(
f"Expected embedding column '{self.embedding_col}'"
f" to be in dataset. Only found columns {embeddings_df.columns}"
)

embeddings_df = embeddings_df[[self.id_col, self.embedding_col]]

embeddings_df = embeddings_df.to_backend("pandas").persist()
embeddings_df = embeddings_df.repartition(partition_size=self.partition_size)
embeddings_df = embeddings_df.to_backend("cudf")

        cupy_darr = embeddings_df.map_partitions(
            get_embedding_ar, meta=cp.ndarray([1, 1])
            get_embedding_ar, self.embedding_col, meta=cp.ndarray([1, 1])
        )
        cupy_darr.compute_chunk_sizes()

@@ -317,7 +329,10 @@ def __call__(self, embeddings_dataset: DocumentDataset):
        meta_df = embeddings_df._meta.copy()
        meta_df["dist_to_cent"] = cp.zeros(1)
        embeddings_df = embeddings_df.map_partitions(
            add_dist_to_cents, centroids=kmeans.cluster_centers_, meta=meta_df
            add_dist_to_cents,
            embedding_col=self.embedding_col,
            centroids=kmeans.cluster_centers_,
            meta=meta_df,
        )
        centroids = kmeans.cluster_centers_
        embeddings_df = embeddings_df.reset_index(drop=True)
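
The explicit `meta_df` built above matters because Dask evaluates lazily: `map_partitions` needs a zero-row prototype of the output frame to infer the schema without running the function. A CPU-only sketch of the same pattern with pandas-backed Dask (the cuDF path in this module works analogously):

```python
import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"x": [1.0, 2.0, 3.0]}), npartitions=2)

# Zero-row prototype: copy the existing schema, then declare the new column.
meta = ddf._meta.copy()
meta["dist_to_cent"] = pd.Series(dtype="float64")

def add_col(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df["dist_to_cent"] = df["x"] * 2.0  # stand-in for the real distance math
    return df

print(ddf.map_partitions(add_col, meta=meta).compute())
```
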
@@ -348,13 +363,14 @@ def __call__(self, embeddings_dataset: DocumentDataset):
        del embeddings_df

        if self.sort_clusters:
            _assign_and_sort_clusters(
            assign_and_sort_clusters(
                id_col=self.id_col,
                kmeans_centroids_file=kmeans_centroids_file,
                nearest_cent_dir=clustering_output_dir,
                output_sorted_clusters_dir=os.path.join(
                    self.clustering_output_dir, "sorted"
                ),
                embedding_col=self.embedding_col,
                sim_metric=self.sim_metric,
                keep_hard=self.keep_hard,
                kmeans_with_cos_dist=self.kmeans_with_cos_dist,
@@ -380,6 +396,7 @@ def __init__(
        id_col_type: str,
        which_to_keep: str,
        output_dir: str,
        embedding_col: str = "embeddings",
        logger: Union[logging.Logger, str] = "./",
    ) -> None:
        """
@@ -393,6 +410,7 @@ def __init__(
            id_col_type (str): Data type of the ID column.
            which_to_keep (str): Strategy for which duplicate to keep.
            output_dir (str): Directory to save output files.
            embedding_col (str): Column where the embeddings are stored.
            logger (Union[logging.Logger, str]): Logger instance or path to the log file directory.
        """
        self.n_clusters = n_clusters
@@ -406,6 +424,7 @@ def __init__(
            output_dir, "semdedup_pruning_tables"
        )
        self.computed_semantic_match_dfs = False
        self.embedding_col = embedding_col
        self.logger = self._setup_logger(logger)

    def _setup_logger(self, logger: Union[logging.Logger, str]) -> logging.Logger:
@@ -461,6 +480,7 @@ def compute_semantic_match_dfs(
                id_col_type=self.id_col_type,
                eps_list=eps_list,
                output_dir=self.semdedup_pruning_tables_dir,
                embedding_col=self.embedding_col,
                which_to_keep=self.which_to_keep,
            )
        )
nemo_curator/utils/semdedup_utils.py (26 changes: 18 additions & 8 deletions)

@@ -31,12 +31,13 @@
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir


def _assign_and_sort_clusters(
def assign_and_sort_clusters(
    id_col: str,
    kmeans_centroids_file: str,
    nearest_cent_dir: str,
    output_sorted_clusters_dir: str,
    cluster_ids=List[int],
    cluster_ids: List[int],
    embedding_col: str,
    sim_metric: str = "cosine",
    keep_hard: bool = True,
    kmeans_with_cos_dist: bool = True,
@@ -78,6 +79,7 @@ def _assign_and_sort_clusters(
        nearest_cent_dir=nearest_cent_dir,
        output_sorted_clusters_dir=output_sorted_clusters_dir,
        centroids=kmeans_centroids,
        embedding_col=embedding_col,
        sim_metric=sim_metric,
        keep_hard=keep_hard,
        kmeans_with_cos_dist=kmeans_with_cos_dist,
@@ -98,6 +100,7 @@ def rank_within_cluster(
    nearest_cent_dir: str,
    output_sorted_clusters_dir: str,
    centroids: np.ndarray,
    embedding_col: str,
    sim_metric: str = "cosine",
    keep_hard: bool = True,
    kmeans_with_cos_dist: bool = False,
@@ -131,10 +134,10 @@ def rank_within_cluster(
            continue

        cluster_df = cudf.read_parquet(
            cluster_c_path, columns=[id_col, "dist_to_cent", "embeddings"]
            cluster_c_path, columns=[id_col, "dist_to_cent", embedding_col]
        )
        embeds = torch.as_tensor(
            cluster_df["embeddings"].list.leaves.values.reshape(
            cluster_df[embedding_col].list.leaves.values.reshape(
                cluster_df.shape[0], -1
            ),
            device="cuda",
@@ -188,11 +191,15 @@ def _semdedup(


def get_cluster_reps(
    cluster_id: int, emb_by_clust_dir: str, id_col: str, sorted_ids: np.ndarray
    cluster_id: int,
    emb_by_clust_dir: str,
    id_col: str,
    embedding_col: str,
    sorted_ids: np.ndarray,
) -> torch.Tensor:
    cluster_i_path = os.path.join(emb_by_clust_dir, f"nearest_cent={cluster_id}")
    cluster_reps = cudf.read_parquet(
        cluster_i_path, columns=["embeddings", id_col]
        cluster_i_path, columns=[embedding_col, id_col]
    ).sort_values(by=id_col)
    num = cluster_reps.shape[0]

@@ -203,7 +210,7 @@ def get_cluster_reps(
    cluster_reps = cluster_reps.sort_values(by="inverse_sort_id")

    cluster_reps = torch.as_tensor(
        cluster_reps["embeddings"].list.leaves.values.reshape(len(cluster_reps), -1),
        cluster_reps[embedding_col].list.leaves.values.reshape(len(cluster_reps), -1),
        device="cuda",
    )
    return cluster_reps
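
Note how `get_cluster_reps` (like `rank_within_cluster` above) hands the reshaped CuPy array straight to `torch.as_tensor`. CuPy arrays expose `__cuda_array_interface__`, so the tensor is created on the GPU without a host round-trip. A minimal sketch (assumes a CUDA GPU with `cupy` and `torch` installed):

```python
import cupy as cp
import torch

x = cp.arange(6, dtype=cp.float32).reshape(3, 2)  # already on the GPU
t = torch.as_tensor(x, device="cuda")             # wraps the GPU buffer directly
assert t.is_cuda and t.shape == (3, 2)
```
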
@@ -217,6 +224,7 @@ def get_semantic_matches_per_cluster(
    id_col_type: str,
    eps_list: List[float],
    output_dir: str,
    embedding_col: str,
    which_to_keep: str,
) -> None:

@@ -251,7 +259,9 @@

    text_ids = cluster_i[:, 0].astype(id_col_type)

    cluster_reps = get_cluster_reps(cluster_id, emb_by_clust_dir, id_col, text_ids)
    cluster_reps = get_cluster_reps(
        cluster_id, emb_by_clust_dir, id_col, embedding_col, text_ids
    )
    M, M1 = _semdedup(cluster_reps, "cuda")
    assert cluster_reps.shape[0] == len(text_ids)

requirements/requirements_rapids_nightly.txt (6 changes: 6 additions & 0 deletions)

@@ -0,0 +1,6 @@
cudf-cu12>=24.10.0a0,<=24.10
cugraph-cu12>=24.10.0a0,<=24.10
cuml-cu12>=24.10.0a0,<=24.10
dask-cuda>=24.10.0a0,<=24.10
dask-cudf-cu12>=24.10.0a0,<=24.10
spacy[cuda12x]>=3.6.0, <4.0.0
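
The `>=24.10.0a0` lower bounds explicitly name a pre-release, which is what lets pip resolve nightly builds for the 24.10 series without a global `--pre` flag, while `<=24.10` caps the series. For example, a hypothetical single-package install against the nightly index:

```bash
pip install --extra-index-url=https://pypi.anaconda.org/rapidsai-wheels-nightly/simple "cudf-cu12>=24.10.0a0,<=24.10"
```
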
setup.py (17 changes: 15 additions & 2 deletions)

@@ -21,6 +21,13 @@
long_description = (here / "README.md").read_text(encoding="utf-8")


def strtobool(value: str) -> bool:
    value = value.lower()
    if value in ("y", "yes", "1", "true"):
        return True
    return False


def req_file(filename, folder="requirements"):
    with open(os.path.join(folder, filename), encoding="utf-8") as f:
        content = f.readlines()
@@ -29,8 +36,14 @@ def req_file(filename, folder="requirements"):

install_requires = req_file("requirements.txt")

extras_require = {
    "cuda12x": req_file("requirements_cuda12x.txt"),
cuda12_requirements_filename = (
    "requirements_rapids_nightly.txt"
    if strtobool(os.getenv("RAPIDS_NIGHTLY", "false"))
    else "requirements_cuda12x.txt"
)

extras_require: dict = {
    "cuda12x": req_file(cuda12_requirements_filename),
    "image": req_file("requirements_image.txt"),
}
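
A quick, self-contained check of the `strtobool` semantics introduced above. Unlike the old `distutils.util.strtobool`, unrecognized values return `False` instead of raising `ValueError`, so any junk in `RAPIDS_NIGHTLY` silently falls back to the stable requirements:

```python
def strtobool(value: str) -> bool:
    value = value.lower()
    if value in ("y", "yes", "1", "true"):
        return True
    return False

assert strtobool("YES") is True      # matching is case-insensitive
assert strtobool("0") is False
assert strtobool("banana") is False  # no ValueError, just the stable default
```
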
