Integration Tests for Retrieval models (#537)
* Started the integration tests for retrieval models using fiddle

* Finished first version of retrieval integration tests

* Fixing precommit checks

* Added CI script for integration tests

* Fixed CI integration commands

* Enabling only retrieval integration tests

* Fixing wandb import order

* Refactored retrieval tests to run both as integration and unit tests

* Fixed wandb import

* Changed retrieval unit test to train for a full epoch to check the runtime using the P100 machine available for PR testing

* Fixing wandb import

* Changing memory management for running the retrieval integration tests

* Limiting training steps of unit test on retrieval to 500 steps to check for runtime

* Changing memory management of unit tests to test retrieval on P100

* Increasing test to train for a full epoch with LastFM

* Sets W&B execution as successful only when the run is fully completed

* Evaluating MF retrieval test with smaller batch size and for fewer steps

* Reducing the eval steps and eval batch size of two-tower retrieval test

* Changed unit tests to use synthetic data rather than the real LastFM data, which will be used only by the retrieval integration tests

* Fixed retrieval unit tests after rebasing

* Fixed wandb import

* Adjusting retrieval unit tests to run faster

* Fixing tests

* Fix retrieval test
gabrielspmoreira authored and mengyao00 committed Jul 15, 2022
1 parent ce678a6 commit 5955ee8
Showing 13 changed files with 1,105 additions and 2 deletions.
18 changes: 17 additions & 1 deletion ci/test_integration.sh
@@ -17,4 +17,20 @@
#!/bin/bash
set -e

pytest -rxs tests/integration
# Call this script with:
# 1. Name of container as first parameter
# [merlin-hugectr, merlin-tensorflow, merlin-pytorch]
#
# 2. Devices to use:
# [0; 0,1; 0,1,..,n-1]

# Get last Models version
cd /models/
#git pull origin main

container=$1
devices=$2
if [ "$container" == "merlin-tensorflow" ]; then
CUDA_VISIBLE_DEVICES="$devices" TF_GPU_ALLOCATOR=cuda_malloc_async python -m pytest -rxs tests/integration/tf/retrieval
# TODO: When the example notebooks integration tests are fixed, change to python -m pytest -rxs tests/integration/tf/
fi
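
For reference, the script's own header documents its two parameters (container name and device list), so a typical invocation would look like the following — the specific device list is illustrative, not from the commit:

    bash ci/test_integration.sh merlin-tensorflow 0,1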
2 changes: 1 addition & 1 deletion ci/test_unit.sh
@@ -17,4 +17,4 @@
#!/bin/bash
set -e

pytest -rxs tests/unit
TF_GPU_ALLOCATOR=cuda_malloc_async python -m pytest -rxs tests/unit
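
Note: TF_GPU_ALLOCATOR=cuda_malloc_async switches TensorFlow to CUDA's asynchronous memory allocator, which reduces GPU memory fragmentation; judging by the commit messages above, this is the "memory management" change that lets the retrieval tests fit on the P100 machine used for PR testing.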
2 changes: 2 additions & 0 deletions requirements/dev.txt
@@ -24,3 +24,5 @@ myst-nb==0.13.2
linkify-it-py==1.0.3
sphinx-multiversion@git+https://github.com/mikemckiernan/sphinx-multiversion.git
sphinxcontrib-copydirs@git+https://github.com/mikemckiernan/sphinxcontrib-copydirs.git
fiddle
wandb
15 changes: 15 additions & 0 deletions tests/common/__init__.py
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
15 changes: 15 additions & 0 deletions tests/common/tf/__init__.py
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
15 changes: 15 additions & 0 deletions tests/common/tf/retrieval/__init__.py
@@ -0,0 +1,15 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
72 changes: 72 additions & 0 deletions tests/common/tf/retrieval/retrieval_config.py
@@ -0,0 +1,72 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl

from merlin.io import Dataset
from tests.common.tf.retrieval.retrieval_utils import (
RetrievalTrainEvalRunner,
WandbLogger,
filter_schema,
get_callbacks,
get_dual_encoder_model,
get_item_frequencies,
get_loss,
get_metrics,
get_optimizer,
get_samplers,
get_youtube_dnn_model,
)


def config_retrieval_train_eval_runner(
train_ds: Dataset, eval_ds: Dataset, model_type: str, log_to_wandb: bool, wandb_project: str
):
def make_model(schema, train_ds, model_type):
if model_type == "youtubednn":
model = fdl.Config(get_youtube_dnn_model, schema)
else:
samplers = fdl.Config(get_samplers, schema)
items_frequencies = fdl.Config(get_item_frequencies, schema, train_ds)
model = fdl.Config(
get_dual_encoder_model, schema, samplers, items_frequencies, model_type
)
return model

wandb_logger_cfg = fdl.Config(WandbLogger, enabled=log_to_wandb, wandb_project=wandb_project)

schema_cfg = fdl.Config(filter_schema, schema=train_ds.schema)

model_cfg = make_model(schema_cfg, train_ds, model_type=model_type)
optimizer = fdl.Config(get_optimizer)
metrics = fdl.Config(get_metrics)
loss = fdl.Config(get_loss)
callbacks = fdl.Config(get_callbacks, wandb_logger=wandb_logger_cfg)

runner_cfg = fdl.Config(
RetrievalTrainEvalRunner,
wandb_logger=wandb_logger_cfg,
model_type=model_type,
schema=schema_cfg,
train_ds=train_ds,
eval_ds=eval_ds,
model=model_cfg,
optimizer=optimizer,
metrics=metrics,
loss=loss,
callbacks=callbacks,
)

return runner_cfg
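
To see how the returned fiddle config graph is consumed, here is a minimal sketch; the datasets and the learning-rate override are illustrative assumptions, not part of the commit:

    import fiddle as fdl
    from tests.common.tf.retrieval.retrieval_config import config_retrieval_train_eval_runner

    # train_ds / eval_ds: pre-loaded merlin.io.Dataset objects (assumed available)
    runner_cfg = config_retrieval_train_eval_runner(
        train_ds, eval_ds, model_type="two_tower", log_to_wandb=False, wandb_project=None
    )
    # Nothing is instantiated yet; any field can still be overridden.
    runner_cfg.optimizer.lr = 0.01
    # fdl.build() recursively instantiates all configured callables.
    runner = fdl.build(runner_cfg)

This deferred-construction pattern is what lets the test helpers below tweak model, optimizer, and loss hyperparameters on the config before anything is built.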
213 changes: 213 additions & 0 deletions tests/common/tf/retrieval/retrieval_tests_common.py
@@ -0,0 +1,213 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Callable, Optional

import fiddle as fdl

from merlin.io import Dataset
from tests.common.tf.retrieval.retrieval_config import config_retrieval_train_eval_runner
from tests.common.tf.tests_utils import extract_hparams_from_config


def set_lastfm_two_tower_hparams_config(runner_cfg: fdl.Config):
runner_cfg.model.two_tower_activation = "selu"
runner_cfg.model.two_tower_mlp_layers = "64"
runner_cfg.model.two_tower_dropout = 0.2
runner_cfg.model.two_tower_embedding_sizes_multiplier = 4.0
runner_cfg.model.logits_temperature = 0.6
runner_cfg.model.l2_reg = 2e-06
runner_cfg.model.embeddings_l2_reg = 1e-05
runner_cfg.model.logq_correction_factor = 0.7

runner_cfg.model.samplers.neg_sampling = "inbatch"

runner_cfg.loss.loss = "categorical_crossentropy"
runner_cfg.loss.xe_label_smoothing = 0.3

runner_cfg.optimizer.lr = 0.02
runner_cfg.optimizer.lr_decay_rate = 0.97
runner_cfg.optimizer.lr_decay_steps = 50
runner_cfg.optimizer.optimizer = "adam"

runner_cfg.metrics.topk_metrics_cutoffs = "10,50,100"

runner_cfg.train_batch_size = 4096
runner_cfg.eval_batch_size = 512

runner_cfg.train_epochs = 20
runner_cfg.train_steps_per_epoch = None
runner_cfg.eval_steps = 5000

runner_cfg.callbacks.train_batch_size = runner_cfg.train_batch_size


def set_lastfm_mf_hparams_config(runner_cfg: fdl.Config):
runner_cfg.model.mf_dim = 64
runner_cfg.model.logits_temperature = 1.4
runner_cfg.model.embeddings_l2_reg = 3e-07
runner_cfg.model.logq_correction_factor = 0.9

runner_cfg.model.samplers.neg_sampling = "inbatch"

runner_cfg.loss.loss = "categorical_crossentropy"
runner_cfg.loss.xe_label_smoothing = 0.0

runner_cfg.optimizer.lr = 0.005
runner_cfg.optimizer.lr_decay_rate = 0.98
runner_cfg.optimizer.lr_decay_steps = 50
runner_cfg.optimizer.optimizer = "adam"

runner_cfg.metrics.topk_metrics_cutoffs = "10,50,100"

runner_cfg.train_batch_size = 4096
runner_cfg.eval_batch_size = 512

runner_cfg.train_epochs = 20
runner_cfg.train_steps_per_epoch = None
runner_cfg.eval_steps = 5000


def train_eval_two_tower(
train_ds: Dataset,
eval_ds: Dataset,
train_epochs: int = 1,
train_steps_per_epoch: Optional[int] = None,
eval_steps: Optional[int] = 2000,
train_batch_size: int = 512,
eval_batch_size: int = 512,
topk_metrics_cutoffs: str = "10,50,100",
log_to_wandb: bool = False,
wandb_project: str = None,
config_callback: Callable = None,
):
runner_cfg = config_retrieval_train_eval_runner(
train_ds,
eval_ds,
model_type="two_tower",
log_to_wandb=log_to_wandb,
wandb_project=wandb_project,
)

if config_callback:
config_callback(runner_cfg)

runner_cfg.train_epochs = train_epochs
runner_cfg.train_steps_per_epoch = train_steps_per_epoch
runner_cfg.eval_steps = eval_steps
runner_cfg.train_batch_size = train_batch_size
runner_cfg.eval_batch_size = eval_batch_size
runner_cfg.metrics.topk_metrics_cutoffs = topk_metrics_cutoffs

hparams = extract_hparams_from_config(runner_cfg)

runner = fdl.build(runner_cfg)
metrics = runner.run(hparams)
return metrics


def train_eval_mf(
train_ds: Dataset,
eval_ds: Dataset,
train_epochs: int = 1,
train_steps_per_epoch: Optional[int] = None,
eval_steps: Optional[int] = 2000,
train_batch_size: int = 512,
eval_batch_size: int = 512,
topk_metrics_cutoffs: str = "10,50,100",
log_to_wandb: bool = False,
wandb_project: str = None,
config_callback: Callable = None,
):
runner_cfg = config_retrieval_train_eval_runner(
train_ds,
eval_ds,
model_type="mf",
log_to_wandb=log_to_wandb,
wandb_project=wandb_project,
)

if config_callback:
config_callback(runner_cfg)

runner_cfg.train_epochs = train_epochs
runner_cfg.train_steps_per_epoch = train_steps_per_epoch
runner_cfg.eval_steps = eval_steps
runner_cfg.train_batch_size = train_batch_size
runner_cfg.eval_batch_size = eval_batch_size
runner_cfg.metrics.topk_metrics_cutoffs = topk_metrics_cutoffs

runner_cfg.callbacks.train_batch_size = runner_cfg.train_batch_size

hparams = extract_hparams_from_config(runner_cfg)

runner = fdl.build(runner_cfg)
metrics = runner.run(hparams)
return metrics


def train_eval_two_tower_for_lastfm(
train_ds: Dataset,
eval_ds: Dataset,
train_epochs: int = 1,
train_steps_per_epoch: Optional[int] = None,
eval_steps: Optional[int] = 2000,
train_batch_size: int = 512,
eval_batch_size: int = 512,
topk_metrics_cutoffs: str = "10,50,100",
log_to_wandb: bool = False,
wandb_project: str = None,
):
return train_eval_two_tower(
train_ds,
eval_ds,
train_epochs,
train_steps_per_epoch,
eval_steps,
train_batch_size,
eval_batch_size,
topk_metrics_cutoffs,
log_to_wandb,
wandb_project,
config_callback=set_lastfm_two_tower_hparams_config,
)


def train_eval_mf_for_lastfm(
train_ds: Dataset,
eval_ds: Dataset,
train_epochs: int = 1,
train_steps_per_epoch: Optional[int] = None,
eval_steps: Optional[int] = 2000,
train_batch_size: int = 512,
eval_batch_size: int = 512,
topk_metrics_cutoffs: str = "10,50,100",
log_to_wandb: bool = False,
wandb_project: str = None,
):
return train_eval_mf(
train_ds,
eval_ds,
train_epochs,
train_steps_per_epoch,
eval_steps,
train_batch_size,
eval_batch_size,
topk_metrics_cutoffs,
log_to_wandb,
wandb_project,
config_callback=set_lastfm_mf_hparams_config,
)
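
End-to-end, the LastFM helpers are called roughly as follows — a sketch only; the parquet paths and the small step counts are placeholders, not values from the commit:

    from merlin.io import Dataset
    from tests.common.tf.retrieval.retrieval_tests_common import train_eval_two_tower_for_lastfm

    train_ds = Dataset("/data/lastfm/train/*.parquet")  # placeholder path
    eval_ds = Dataset("/data/lastfm/valid/*.parquet")   # placeholder path
    metrics = train_eval_two_tower_for_lastfm(
        train_ds, eval_ds, train_epochs=1, eval_steps=100, log_to_wandb=False
    )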