Skip to content

Commit

Permalink
Set base path to hub url for canonical datasets (huggingface#3709)
Browse files Browse the repository at this point in the history
* set base_path to hub URL for canonical datasets

* add test

* minor
  • Loading branch information
lhoestq committed Feb 16, 2022
1 parent 643e3e0 commit 8116080
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def get_module(self) -> DatasetModule:
)
# make the new module to be noticed by the import system
importlib.invalidate_caches()
builder_kwargs = {"hash": hash, "base_path": hf_github_url(self.name, "", revision=revision)}
builder_kwargs = {"hash": hash, "base_path": hf_hub_url(self.name, "", revision=self.revision)}
return DatasetModule(module_path, hash, builder_kwargs)


Expand Down
7 changes: 6 additions & 1 deletion tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import requests

import datasets
from datasets import SCRIPTS_VERSION, load_dataset, load_from_disk
from datasets import SCRIPTS_VERSION, config, load_dataset, load_from_disk
from datasets.arrow_dataset import Dataset
from datasets.builder import DatasetBuilder
from datasets.data_files import DataFilesDict
Expand Down Expand Up @@ -189,6 +189,7 @@ def test_CanonicalDatasetModuleFactory(self):
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

def test_CanonicalMetricModuleFactory_with_internal_import(self):
# "squad_v2" requires additional imports (internal)
Expand Down Expand Up @@ -221,11 +222,13 @@ def test_LocalDatasetModuleFactoryWithScript(self):
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
assert os.path.isdir(module_factory_result.builder_kwargs["base_path"])

def test_LocalDatasetModuleFactoryWithoutScript(self):
factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
assert os.path.isdir(module_factory_result.builder_kwargs["base_path"])

def test_PackagedDatasetModuleFactory(self):
factory = PackagedDatasetModuleFactory(
Expand All @@ -240,6 +243,7 @@ def test_CommunityDatasetModuleFactoryWithoutScript(self):
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

def test_CommunityDatasetModuleFactoryWithScript(self):
factory = CommunityDatasetModuleFactoryWithScript(
Expand All @@ -249,6 +253,7 @@ def test_CommunityDatasetModuleFactoryWithScript(self):
)
module_factory_result = factory.get_module()
assert importlib.import_module(module_factory_result.module_path) is not None
assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT)

def test_CachedDatasetModuleFactory(self):
path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py")
Expand Down

0 comments on commit 8116080

Please sign in to comment.