From 9255b0421d6461c60d8c98f29fe3ded2b5177396 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Mon, 22 Apr 2024 12:23:43 -0700 Subject: [PATCH] Safe import tests, improve install instruction, update gha workflow --- .github/workflows/test.yml | 5 ++--- nemo_curator/utils/gpu_utils.py | 4 ++-- tests/test_fuzzy_dedup.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d179a2a57..baa968f47 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,9 +40,8 @@ jobs: # Explicitly install cython: https://github.com/VKCOM/YouTokenToMe/issues/94 run: | pip install wheel cython - pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com . + pip install --no-cache-dir . pip install pytest - name: Run tests - # TODO: Remove env variable when gpu dependencies are optional run: | - RAPIDS_NO_INITIALIZE=1 python -m pytest -v --cpu + python -m pytest -v --cpu diff --git a/nemo_curator/utils/gpu_utils.py b/nemo_curator/utils/gpu_utils.py index 7d11e2d7a..86ba888fc 100644 --- a/nemo_curator/utils/gpu_utils.py +++ b/nemo_curator/utils/gpu_utils.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda-12x]` -or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda-12x]"` if installing from source""" +GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda12x]` +or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source""" def is_cudf_type(obj): diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py index 3c6a32754..a1acb901f 100644 --- a/tests/test_fuzzy_dedup.py +++ b/tests/test_fuzzy_dedup.py @@ -16,14 +16,16 @@ from itertools import combinations from typing import Iterable -import cudf -import dask_cudf import numpy as np import pytest from dask.dataframe.utils import assert_eq from nemo_curator.datasets import DocumentDataset from nemo_curator.modules import LSH, MinHash +from nemo_curator.utils.import_utils import gpu_only_import + +cudf = gpu_only_import("cudf") +dask_cudf = gpu_only_import("dask_cudf") @pytest.fixture