Skip to content

Commit

Permalink
fix bad type in overflow check (huggingface#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Aug 14, 2020
1 parent 388e566 commit 70e935c
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/nlp/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,26 @@ def write_on_file(self):
type = None if self.update_features and self.pa_writer is None else self._type
if self.current_rows:
pa_array = pa.array(self.current_rows, type=type)
first_example = pa.array(self.current_rows[0:1], type=type)[0]
inferred_type = pa_array.type
first_example = pa.array(self.current_rows[0:1], type=inferred_type)[0]
# Sanity check
if pa_array[0] != first_example:
# There was an Overflow in StructArray. Let's reduce the batch_size
new_batch_size = self.writer_batch_size
while pa_array[0] != first_example:
if new_batch_size < 2:
raise RuntimeError("The given example is too big (>2GB) to fit in an array.")
new_batch_size = self.writer_batch_size // 2
pa_array = pa.array(self.current_rows[:new_batch_size], type=type)
pa_array = pa.array(self.current_rows[:new_batch_size], type=inferred_type)
logger.warning(
"Batch size is too big (>2GB). Reducing it from {} to {}".format(
self.writer_batch_size, new_batch_size
)
)
self.writer_batch_size = new_batch_size
n_batches = len(self.current_rows) // new_batch_size
n_batches += int(len(self.current_rows) % new_batch_size != 0)
for i in range(n_batches):
pa_array = pa.array(self.current_rows[i * new_batch_size : (i + 1) * new_batch_size], type=type,)
self._write_array_on_file(pa_array)
for i in range(0, len(self.current_rows), new_batch_size):
rows_batch = self.current_rows[i, i + new_batch_size]
self._write_array_on_file(pa.array(rows_batch, type=inferred_type))
else:
# All good
self._write_array_on_file(pa_array)
Expand Down

0 comments on commit 70e935c

Please sign in to comment.