Deprecate Beam API and download from HF GCS bucket #6474

Merged
merged 15 commits on Mar 12, 2024
Fix tests
mariosasko committed Mar 11, 2024
commit 9f8e33800dd561a8533177d3cb4c335cbf5b92d7
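For context, the tests below move from building dataset builder classes by hand (dataset_module_factory / import_main_class) to the public load_dataset_builder API, pinned to a fixed loading-script revision so the prebuilt files on the HF GCS bucket keep matching the builder's layout. A minimal sketch of that pattern outside the test suite, not part of this PR — the cache directory is hypothetical, and the call assumes the pinned revision is still mirrored on the HF GCS bucket:

from datasets import load_dataset_builder

# Pin the wikipedia loading script to the revision the tests use, so the
# prebuilt Arrow files on the HF GCS bucket match the builder's expected layout.
builder = load_dataset_builder(
    "wikipedia",
    "20220301.frr",
    revision="4d013bdd32c475c8536aae00a56efc774f061649",
    cache_dir="./hf_gcp_cache",  # hypothetical cache directory
)

# Fetch the already-processed data from the HF GCS mirror instead of running
# the (now deprecated) Apache Beam preprocessing locally.
builder.download_and_prepare(try_from_hf_gcs=True)
ds = builder.as_dataset()
print(ds)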
114 changes: 58 additions & 56 deletions tests/test_hf_gcp.py
@@ -5,67 +5,79 @@
 import pytest
 from absl.testing import parameterized
 
-from datasets import config
+from datasets import config, load_dataset_builder
 from datasets.arrow_reader import HF_GCP_BASE_URL
-from datasets.builder import DatasetBuilder
 from datasets.dataset_dict import IterableDatasetDict
 from datasets.iterable_dataset import IterableDataset
-from datasets.load import dataset_module_factory, import_main_class
 from datasets.utils.file_utils import cached_path
 
 
 DATASETS_ON_HF_GCP = [
-    {"dataset": "wikipedia", "config_name": "20220301.de"},
-    {"dataset": "wikipedia", "config_name": "20220301.en"},
-    {"dataset": "wikipedia", "config_name": "20220301.fr"},
-    {"dataset": "wikipedia", "config_name": "20220301.frr"},
-    {"dataset": "wikipedia", "config_name": "20220301.it"},
-    {"dataset": "wikipedia", "config_name": "20220301.simple"},
-    {"dataset": "wiki40b", "config_name": "en"},
-    {"dataset": "wiki_dpr", "config_name": "psgs_w100.nq.compressed"},
-    {"dataset": "wiki_dpr", "config_name": "psgs_w100.nq.no_index"},
-    {"dataset": "wiki_dpr", "config_name": "psgs_w100.multiset.no_index"},
-    {"dataset": "natural_questions", "config_name": "default"},
+    {"dataset": "wikipedia", "config_name": "20220301.de", "revision": "4d013bdd32c475c8536aae00a56efc774f061649"},
+    {"dataset": "wikipedia", "config_name": "20220301.en", "revision": "4d013bdd32c475c8536aae00a56efc774f061649"},
+    {"dataset": "wikipedia", "config_name": "20220301.fr", "revision": "4d013bdd32c475c8536aae00a56efc774f061649"},
+    {"dataset": "wikipedia", "config_name": "20220301.frr", "revision": "4d013bdd32c475c8536aae00a56efc774f061649"},
+    {"dataset": "wikipedia", "config_name": "20220301.it", "revision": "4d013bdd32c475c8536aae00a56efc774f061649"},
+    {"dataset": "wikipedia", "config_name": "20220301.simple", "revision": "4d013bdd32c475c8536aae00a56efc774f061649"},
+    {"dataset": "wiki40b", "config_name": "en", "revision": "7b21a2e64b90323b2d3d1b81aa349bb4bc76d9bf"},
+    {
+        "dataset": "wiki_dpr",
+        "config_name": "psgs_w100.nq.compressed",
+        "revision": "b24a417d802a583f8922946c1c75210290e93108",
+    },
+    {
+        "dataset": "wiki_dpr",
+        "config_name": "psgs_w100.nq.no_index",
+        "revision": "b24a417d802a583f8922946c1c75210290e93108",
+    },
+    {
+        "dataset": "wiki_dpr",
+        "config_name": "psgs_w100.multiset.no_index",
+        "revision": "b24a417d802a583f8922946c1c75210290e93108",
+    },
+    {"dataset": "natural_questions", "config_name": "default", "revision": "19ba7767b174ad046a84f46af056517a3910ee57"},
 ]
 
 
-def list_datasets_on_hf_gcp_parameters(with_config=True):
+def list_datasets_on_hf_gcp_parameters(with_config=True, with_revision=True):
+    columns = ["dataset"]
     if with_config:
-        return [
-            {
-                "testcase_name": d["dataset"] + "/" + d["config_name"],
-                "dataset": d["dataset"],
-                "config_name": d["config_name"],
-            }
-            for d in DATASETS_ON_HF_GCP
-        ]
-    else:
-        return [
-            {"testcase_name": dataset, "dataset": dataset} for dataset in {d["dataset"] for d in DATASETS_ON_HF_GCP}
-        ]
-
-
-@parameterized.named_parameters(list_datasets_on_hf_gcp_parameters(with_config=True))
+        columns.append("config_name")
+    if with_revision:
+        columns.append("revision")
+    dataset_list = [{col: dataset[col] for col in columns} for dataset in DATASETS_ON_HF_GCP]
+
+    def get_testcase_name(dataset):
+        testcase_name = dataset["dataset"]
+        if with_config:
+            testcase_name += "/" + dataset["config_name"]
+        if with_revision:
+            testcase_name += "@" + dataset["revision"]
+        return testcase_name
+
+    dataset_list = [{"testcase_name": get_testcase_name(dataset), **dataset} for dataset in dataset_list]
+    return dataset_list
+
+
+@parameterized.named_parameters(list_datasets_on_hf_gcp_parameters(with_config=True, with_revision=True))
 class TestDatasetOnHfGcp(TestCase):
     dataset = None
     config_name = None
+    revision = None
 
-    def test_dataset_info_available(self, dataset, config_name):
+    def test_dataset_info_available(self, dataset, config_name, revision):
         with TemporaryDirectory() as tmp_dir:
-            dataset_module = dataset_module_factory(dataset, cache_dir=tmp_dir)
-
-            builder_cls = import_main_class(dataset_module.module_path, dataset=True)
-
-            builder_instance: DatasetBuilder = builder_cls(
+            builder = load_dataset_builder(
+                dataset,
+                config_name,
+                revision=revision,
                 cache_dir=tmp_dir,
-                config_name=config_name,
-                hash=dataset_module.hash,
             )
 
             dataset_info_url = "/".join(
                 [
                     HF_GCP_BASE_URL,
-                    builder_instance._relative_data_dir(with_hash=False).replace(os.sep, "/"),
+                    builder._relative_data_dir(with_hash=False).replace(os.sep, "/"),
                     config.DATASET_INFO_FILENAME,
                 ]
             )
@@ -76,30 +88,20 @@ def test_dataset_info_available(self, dataset, config_name):
 @pytest.mark.integration
 def test_as_dataset_from_hf_gcs(tmp_path_factory):
     tmp_dir = tmp_path_factory.mktemp("test_hf_gcp") / "test_wikipedia_simple"
-    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_dir)
-    builder_cls = import_main_class(dataset_module.module_path)
-    builder_instance: DatasetBuilder = builder_cls(
-        cache_dir=tmp_dir,
-        config_name="20220301.frr",
-        hash=dataset_module.hash,
-    )
+    builder = load_dataset_builder("wikipedia", "20220301.frr", cache_dir=tmp_dir)
     # use the HF cloud storage, not the original download_and_prepare that uses apache-beam
-    builder_instance._download_and_prepare = None
-    builder_instance.download_and_prepare(try_from_hf_gcs=True)
-    ds = builder_instance.as_dataset()
+    builder._download_and_prepare = None
+    builder.download_and_prepare(try_from_hf_gcs=True)
+    ds = builder.as_dataset()
     assert ds
 
 
 @pytest.mark.integration
 def test_as_streaming_dataset_from_hf_gcs(tmp_path):
-    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_path)
-    builder_cls = import_main_class(dataset_module.module_path, dataset=True)
-    builder_instance: DatasetBuilder = builder_cls(
-        cache_dir=tmp_path,
-        config_name="20220301.frr",
-        hash=dataset_module.hash,
+    builder = load_dataset_builder(
+        "wikipedia", "20220301.frr", revision="4d013bdd32c475c8536aae00a56efc774f061649", cache_dir=tmp_path
     )
-    ds = builder_instance.as_streaming_dataset()
+    ds = builder.as_streaming_dataset()
     assert ds
     assert isinstance(ds, IterableDatasetDict)
     assert "train" in ds