Integration Tests for Retrieval models (#537)
* Started the integration tests for retrieval models using fiddle
* Finished first version of retrieval integration tests
* Fixing precommit checks
* Added CI script for integration tests
* Fixed CI integration commands
* Enabling only retrieval integration tests
* Fixing wandb import order
* Refactored retrieval tests to run both as integration and unit tests
* Fixed wandb import
* Changed retrieval unit test to train for a full epoch to check the runtime using the P100 machine available for PR testing
* Fixing wandb import
* Changing memory management for running the retrieval integration tests
* Limiting training steps of unit test on retrieval to 500 steps to check for runtime
* Changing memory management of unit tests to test retrieval on P100
* Increasing test to train for a full epoch with LastFM
* Sets W&B execution as successful only when the run is completely performed
* Evaluating MF retrieval test with smaller batch size and for fewer steps
* Reducing the eval steps and eval batch size of two-tower retrieval test
* Changed unit tests to use synthetic data rather than the real LastFM data, which will be used only by the retrieval integration tests
* Fixed retrieval unit tests after rebasing
* Fixed wandb import
* Adjusting retrieval unit tests to run faster
* Fixing tests
* Fix retrieval test
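The tests compose their training runs as fiddle configs. A minimal sketch of that pattern follows; the factory and hyperparameter names here are illustrative only, while the real builders live in tests/common/tf/retrieval/retrieval_utils.py:

import fiddle as fdl


def make_optimizer(lr: float = 0.01, optimizer: str = "adam"):
    # Illustrative factory; the tests wrap real Merlin Models builders this way.
    return {"lr": lr, "optimizer": optimizer}


opt_cfg = fdl.Config(make_optimizer)   # declares the call lazily, nothing runs yet
opt_cfg.lr = 0.02                      # hyperparameter overrides are attribute assignments
optimizer = fdl.build(opt_cfg)         # only now is make_optimizer(lr=0.02) actually called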
1 parent ce678a6 · commit 5955ee8
Showing 13 changed files with 1,105 additions and 2 deletions.
@@ -17,4 +17,4 @@
 #!/bin/bash
 set -e

-pytest -rxs tests/unit
+TF_GPU_ALLOCATOR=cuda_malloc_async python -m pytest -rxs tests/unit
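The unit-test command now exports TF_GPU_ALLOCATOR=cuda_malloc_async, which switches TensorFlow to the CUDA asynchronous allocator (the memory-management change the commit message refers to) for the P100 CI machine. As a sketch, the same setting can be applied from Python, provided it is set before TensorFlow initializes the GPU:

import os

# Must be set before TensorFlow touches the GPU for the allocator choice to take effect.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

import tensorflow as tf  # noqa: E402  (import deliberately placed after the env var)

print(tf.config.list_physical_devices("GPU"))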
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,72 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl

from merlin.io import Dataset
from tests.common.tf.retrieval.retrieval_utils import (
    RetrievalTrainEvalRunner,
    WandbLogger,
    filter_schema,
    get_callbacks,
    get_dual_encoder_model,
    get_item_frequencies,
    get_loss,
    get_metrics,
    get_optimizer,
    get_samplers,
    get_youtube_dnn_model,
)


def config_retrieval_train_eval_runner(
    train_ds: Dataset, eval_ds: Dataset, model_type: str, log_to_wandb: bool, wandb_project: str
):
    def make_model(schema, train_ds, model_type):
        if model_type == "youtubednn":
            model = fdl.Config(get_youtube_dnn_model, schema)
        else:
            samplers = fdl.Config(get_samplers, schema)
            items_frequencies = fdl.Config(get_item_frequencies, schema, train_ds)
            model = fdl.Config(
                get_dual_encoder_model, schema, samplers, items_frequencies, model_type
            )
        return model

    wandb_logger_cfg = fdl.Config(WandbLogger, enabled=log_to_wandb, wandb_project=wandb_project)

    schema_cfg = fdl.Config(filter_schema, schema=train_ds.schema)

    model_cfg = make_model(schema_cfg, train_ds, model_type=model_type)
    optimizer = fdl.Config(get_optimizer)
    metrics = fdl.Config(get_metrics)
    loss = fdl.Config(get_loss)
    callbacks = fdl.Config(get_callbacks, wandb_logger=wandb_logger_cfg)

    runner_cfg = fdl.Config(
        RetrievalTrainEvalRunner,
        wandb_logger=wandb_logger_cfg,
        model_type=model_type,
        schema=schema_cfg,
        train_ds=train_ds,
        eval_ds=eval_ds,
        model=model_cfg,
        optimizer=optimizer,
        metrics=metrics,
        loss=loss,
        callbacks=callbacks,
    )

    return runner_cfg
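A hypothetical way to exercise the factory above, mirroring how the train_eval_* helpers in retrieval_utils.py use it; the parquet paths and override values are placeholders:

import fiddle as fdl

from merlin.io import Dataset
from tests.common.tf.retrieval.retrieval_config import config_retrieval_train_eval_runner
from tests.common.tf.tests_utils import extract_hparams_from_config

# Placeholder dataset locations; any train/eval Dataset pair with a compatible schema works.
train_ds = Dataset("/workspace/data/lastfm/train/*.parquet")
eval_ds = Dataset("/workspace/data/lastfm/valid/*.parquet")

runner_cfg = config_retrieval_train_eval_runner(
    train_ds, eval_ds, model_type="two_tower", log_to_wandb=False, wandb_project=None
)
runner_cfg.train_epochs = 1          # hyperparameters are overridden on the config
runner_cfg.train_batch_size = 4096

hparams = extract_hparams_from_config(runner_cfg)  # hyperparameters passed to runner.run below
runner = fdl.build(runner_cfg)       # the lazy fdl.Config graph is materialized here
metrics = runner.run(hparams)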
@@ -0,0 +1,213 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Callable, Optional

import fiddle as fdl

from merlin.io import Dataset
from tests.common.tf.retrieval.retrieval_config import config_retrieval_train_eval_runner
from tests.common.tf.tests_utils import extract_hparams_from_config


def set_lastfm_two_tower_hparams_config(runner_cfg: fdl.Config):
    runner_cfg.model.two_tower_activation = "selu"
    runner_cfg.model.two_tower_mlp_layers = "64"
    runner_cfg.model.two_tower_dropout = 0.2
    runner_cfg.model.two_tower_embedding_sizes_multiplier = 4.0
    runner_cfg.model.logits_temperature = 0.6
    runner_cfg.model.l2_reg = 2e-06
    runner_cfg.model.embeddings_l2_reg = 1e-05
    runner_cfg.model.logq_correction_factor = 0.7

    runner_cfg.model.samplers.neg_sampling = "inbatch"

    runner_cfg.loss.loss = "categorical_crossentropy"
    runner_cfg.loss.xe_label_smoothing = 0.3

    runner_cfg.optimizer.lr = 0.02
    runner_cfg.optimizer.lr_decay_rate = 0.97
    runner_cfg.optimizer.lr_decay_steps = 50
    runner_cfg.optimizer.optimizer = "adam"

    runner_cfg.metrics.topk_metrics_cutoffs = "10,50,100"

    runner_cfg.train_batch_size = 4096
    runner_cfg.eval_batch_size = 512

    runner_cfg.train_epochs = 20
    runner_cfg.train_steps_per_epoch = None
    runner_cfg.eval_steps = 5000

    runner_cfg.callbacks.train_batch_size = runner_cfg.train_batch_size


def set_lastfm_mf_hparams_config(runner_cfg: fdl.Config):
    runner_cfg.model.mf_dim = 64
    runner_cfg.model.logits_temperature = 1.4
    runner_cfg.model.embeddings_l2_reg = 3e-07
    runner_cfg.model.logq_correction_factor = 0.9

    runner_cfg.model.samplers.neg_sampling = "inbatch"

    runner_cfg.loss.loss = "categorical_crossentropy"
    runner_cfg.loss.xe_label_smoothing = 0.0

    runner_cfg.optimizer.lr = 0.005
    runner_cfg.optimizer.lr_decay_rate = 0.98
    runner_cfg.optimizer.lr_decay_steps = 50
    runner_cfg.optimizer.optimizer = "adam"

    runner_cfg.metrics.topk_metrics_cutoffs = "10,50,100"

    runner_cfg.train_batch_size = 4096
    runner_cfg.eval_batch_size = 512

    runner_cfg.train_epochs = 20
    runner_cfg.train_steps_per_epoch = None
    runner_cfg.eval_steps = 5000


def train_eval_two_tower(
    train_ds: Dataset,
    eval_ds: Dataset,
    train_epochs: int = 1,
    train_steps_per_epoch: Optional[int] = None,
    eval_steps: Optional[int] = 2000,
    train_batch_size: int = 512,
    eval_batch_size: int = 512,
    topk_metrics_cutoffs: str = "10,50,100",
    log_to_wandb: bool = False,
    wandb_project: str = None,
    config_callback: Callable = None,
):
    runner_cfg = config_retrieval_train_eval_runner(
        train_ds,
        eval_ds,
        model_type="two_tower",
        log_to_wandb=log_to_wandb,
        wandb_project=wandb_project,
    )

    if config_callback:
        config_callback(runner_cfg)

    runner_cfg.train_epochs = train_epochs
    runner_cfg.train_steps_per_epoch = train_steps_per_epoch
    runner_cfg.eval_steps = eval_steps
    runner_cfg.train_batch_size = train_batch_size
    runner_cfg.eval_batch_size = eval_batch_size
    runner_cfg.metrics.topk_metrics_cutoffs = topk_metrics_cutoffs

    hparams = extract_hparams_from_config(runner_cfg)

    runner = fdl.build(runner_cfg)
    metrics = runner.run(hparams)
    return metrics


def train_eval_mf(
    train_ds: Dataset,
    eval_ds: Dataset,
    train_epochs: int = 1,
    train_steps_per_epoch: Optional[int] = None,
    eval_steps: Optional[int] = 2000,
    train_batch_size: int = 512,
    eval_batch_size: int = 512,
    topk_metrics_cutoffs: str = "10,50,100",
    log_to_wandb: bool = False,
    wandb_project: str = None,
    config_callback: Callable = None,
):
    runner_cfg = config_retrieval_train_eval_runner(
        train_ds,
        eval_ds,
        model_type="mf",
        log_to_wandb=log_to_wandb,
        wandb_project=wandb_project,
    )

    if config_callback:
        config_callback(runner_cfg)

    runner_cfg.train_epochs = train_epochs
    runner_cfg.train_steps_per_epoch = train_steps_per_epoch
    runner_cfg.eval_steps = eval_steps
    runner_cfg.train_batch_size = train_batch_size
    runner_cfg.eval_batch_size = eval_batch_size
    runner_cfg.metrics.topk_metrics_cutoffs = topk_metrics_cutoffs

    runner_cfg.callbacks.train_batch_size = runner_cfg.train_batch_size

    hparams = extract_hparams_from_config(runner_cfg)

    runner = fdl.build(runner_cfg)
    metrics = runner.run(hparams)
    return metrics


def train_eval_two_tower_for_lastfm(
    train_ds: Dataset,
    eval_ds: Dataset,
    train_epochs: int = 1,
    train_steps_per_epoch: Optional[int] = None,
    eval_steps: Optional[int] = 2000,
    train_batch_size: int = 512,
    eval_batch_size: int = 512,
    topk_metrics_cutoffs: str = "10,50,100",
    log_to_wandb: bool = False,
    wandb_project: str = None,
):
    return train_eval_two_tower(
        train_ds,
        eval_ds,
        train_epochs,
        train_steps_per_epoch,
        eval_steps,
        train_batch_size,
        eval_batch_size,
        topk_metrics_cutoffs,
        log_to_wandb,
        wandb_project,
        config_callback=set_lastfm_two_tower_hparams_config,
    )


def train_eval_mf_for_lastfm(
    train_ds: Dataset,
    eval_ds: Dataset,
    train_epochs: int = 1,
    train_steps_per_epoch: Optional[int] = None,
    eval_steps: Optional[int] = 2000,
    train_batch_size: int = 512,
    eval_batch_size: int = 512,
    topk_metrics_cutoffs: str = "10,50,100",
    log_to_wandb: bool = False,
    wandb_project: str = None,
):
    return train_eval_mf(
        train_ds,
        eval_ds,
        train_epochs,
        train_steps_per_epoch,
        eval_steps,
        train_batch_size,
        eval_batch_size,
        topk_metrics_cutoffs,
        log_to_wandb,
        wandb_project,
        config_callback=set_lastfm_mf_hparams_config,
    )
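As a usage example, a hypothetical integration-test call to the LastFM-tuned two-tower helper above; the dataset locations are placeholders:

from merlin.io import Dataset

# Placeholder LastFM parquet locations; the integration job points these at the real dataset.
train_ds = Dataset("/workspace/data/lastfm/train/*.parquet")
eval_ds = Dataset("/workspace/data/lastfm/valid/*.parquet")

metrics = train_eval_two_tower_for_lastfm(
    train_ds,
    eval_ds,
    train_epochs=1,
    eval_steps=2000,
    log_to_wandb=False,  # set True and pass wandb_project to log the run to W&B
)
print(metrics)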