Skip to content

Commit

Permalink
Make the json script more flexible (huggingface#372)
Browse files Browse the repository at this point in the history
* make the json script more flexible

* style and quality

* better error message
  • Loading branch information
thomwolf committed Jul 10, 2020
1 parent 0a74a92 commit 485e5b5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
53 changes: 45 additions & 8 deletions datasets/json/json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# coding=utf-8

import json
from dataclasses import dataclass
from io import BytesIO
from typing import List, Union

import pyarrow as pa
import pyarrow.json as paj
Expand All @@ -12,23 +15,30 @@
class JsonConfig(nlp.BuilderConfig):
    """BuilderConfig for JSON files.

    Attributes:
        features: optional ``nlp.Features`` describing the dataset schema;
            when set it is converted to a pyarrow schema used for parsing.
        field: optional top-level key of the JSON file that holds the records
            (for files shaped like ``{"data": [...]}``).
        use_threads: whether pyarrow may parse the file with multiple threads.
        block_size: chunk size in bytes that pyarrow reads at once
            (``None`` lets pyarrow pick its default).
        newlines_in_values: whether JSON string values may contain newlines.
    """

    # NOTE(review): the stale pre-change attributes (`read_options`,
    # `parse_options`) and their old property returns — diff artifacts that
    # shadowed the new implementation — have been removed.
    features: nlp.Features = None
    field: str = None
    use_threads: bool = True
    block_size: int = None
    newlines_in_values: bool = None

    @property
    def pa_read_options(self):
        # Build pyarrow ReadOptions lazily from the plain config fields.
        return paj.ReadOptions(use_threads=self.use_threads, block_size=self.block_size)

    @property
    def pa_parse_options(self):
        # `explicit_schema` pins parsing to the user-provided features (or
        # None, letting pyarrow infer the schema).
        return paj.ParseOptions(explicit_schema=self.schema, newlines_in_values=self.newlines_in_values)

    @property
    def schema(self):
        # pyarrow schema derived from `features`, or None when no features
        # were provided.
        return pa.schema(self.features.type) if self.features is not None else None


class Json(nlp.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = JsonConfig

def _info(self):
return nlp.DatasetInfo()
return nlp.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
""" We handle string, list and dicts in datafiles
Expand All @@ -49,7 +59,34 @@ def _split_generators(self, dl_manager):

def _generate_tables(self, files):
for i, file in enumerate(files):
pa_table = paj.read_json(
file, read_options=self.config.pa_read_options, parse_options=self.config.pa_parse_options,
)
if self.config.field is not None:
with open(file, encoding="utf-8") as f:
dataset = json.load(f)

# We keep only the field we are interested in
dataset = dataset[self.config.field]

# We accept two format: a list of dicts or a dict of lists
if isinstance(dataset, (list, tuple)):
pa_table = paj.read_json(
BytesIO("\n".join(json.dumps(row) for row in dataset).encode("utf-8")),
read_options=self.config.pa_read_options,
parse_options=self.config.pa_parse_options,
)
else:
pa_table = pa.Table.from_pydict(mapping=dataset, schema=self.config.schema)
else:
try:
pa_table = paj.read_json(
file, read_options=self.config.pa_read_options, parse_options=self.config.pa_parse_options,
)
except pa.ArrowInvalid:
with open(file, encoding="utf-8") as f:
dataset = json.load(f)
raise ValueError(
f"Not able to read records in the JSON file at {file}. "
f"You should probably indicate the field of the JSON file containing your records. "
f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
f"Select the correct one and provide it as `field='XXX'` to the `load_dataset` method. "
)
yield i, pa_table
6 changes: 3 additions & 3 deletions src/nlp/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ def add_faiss_index(
device: Optional[int] = None,
string_factory: Optional[str] = None,
metric_type: Optional[int] = None,
custom_index: Optional["faiss.Index"] = None,
custom_index: Optional["faiss.Index"] = None, # noqa: F821
train_size: Optional[int] = None,
faiss_verbose: bool = False,
dtype=np.float32,
Expand Down Expand Up @@ -1367,7 +1367,7 @@ def add_faiss_index_from_external_arrays(
device: Optional[int] = None,
string_factory: Optional[str] = None,
metric_type: Optional[int] = None,
custom_index: Optional["faiss.Index"] = None,
custom_index: Optional["faiss.Index"] = None, # noqa: F821
train_size: Optional[int] = None,
faiss_verbose: bool = False,
dtype=np.float32,
Expand Down Expand Up @@ -1407,7 +1407,7 @@ def add_elasticsearch_index(
index_name: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
es_client: Optional["elasticsearch.Elasticsearch"] = None,
es_client: Optional["elasticsearch.Elasticsearch"] = None, # noqa: F821
es_index_name: Optional[str] = None,
es_index_config: Optional[dict] = None,
):
Expand Down

0 comments on commit 485e5b5

Please sign in to comment.