Skip to content

Commit

Permalink
Safe import tests, improve install instruction, update gha workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushdg committed Apr 22, 2024
1 parent 12a8386 commit 9255b04
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ jobs:
# Explicitly install cython: https://github.com/VKCOM/YouTokenToMe/issues/94
run: |
pip install wheel cython
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com .
pip install --no-cache-dir .
pip install pytest
- name: Run tests
# TODO: Remove env variable when gpu dependencies are optional
run: |
RAPIDS_NO_INITIALIZE=1 python -m pytest -v --cpu
python -m pytest -v --cpu
4 changes: 2 additions & 2 deletions nemo_curator/utils/gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda-12x]`
or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda-12x]"` if installing from source"""
GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda12x]`
or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source"""


def is_cudf_type(obj):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
from itertools import combinations
from typing import Iterable

import cudf
import dask_cudf
import numpy as np
import pytest
from dask.dataframe.utils import assert_eq

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules import LSH, MinHash
from nemo_curator.utils.import_utils import gpu_only_import

cudf = gpu_only_import("cudf")
dask_cudf = gpu_only_import("dask_cudf")


@pytest.fixture
Expand Down

0 comments on commit 9255b04

Please sign in to comment.