Spark docs (#5796)
* spark docs

* maddie's comments

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
lhoestq and stevhliu committed Apr 27, 2023
1 parent 4bb172c commit e210dc2
Showing 2 changed files with 59 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -38,6 +38,8 @@
title: Use with PyTorch
- local: use_with_jax
title: Use with JAX
- local: use_with_spark
title: Use with Spark
- local: cache
title: Cache management
- local: filesystems
57 changes: 57 additions & 0 deletions docs/source/use_with_spark.mdx
@@ -0,0 +1,57 @@
# Use with Spark

This document is a quick introduction to using 🤗 Datasets with Spark, with a particular focus on how to load a Spark DataFrame into a [`Dataset`] object.

From there, you have fast access to any element of the dataset, and you can use it as a data loader to train models.

## Load from Spark

A [`Dataset`] object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to PyTorch, TensorFlow and JAX tensors.
The Arrow table is memory-mapped from disk, so you can load datasets bigger than your available RAM.

You can get a [`Dataset`] from a Spark DataFrame using [`Dataset.from_spark`]:

```py
>>> from datasets import Dataset
>>> df = spark.createDataFrame(
... data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
...     schema=["id", "name"],
... )
>>> ds = Dataset.from_spark(df)
```

The Spark workers write the dataset on disk in a cache directory as Arrow files, and the [`Dataset`] is loaded from there.
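
For example, once the conversion is done you can index the dataset directly or set its format for training. This is a minimal sketch, not specific to Spark: `with_format("torch")` and `DataLoader` are the usual 🤗 Datasets/PyTorch pattern, and it assumes PyTorch is installed:

```py
>>> row = ds[0]  # fast random access to a single row, read from the memory-mapped Arrow files
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(ds.with_format("torch"), batch_size=2)  # yields batches of PyTorch tensors
```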

### Caching

The resulting [`Dataset`] is cached; if you call [`Dataset.from_spark`] multiple times on the same DataFrame it won't re-run the Spark
job that writes the dataset as Arrow files on disk.

You can set the cache location by passing `cache_dir=` to [`Dataset.from_spark`].
Make sure to use a disk that is available to both your workers and your current machine (the driver).
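
For example, here is a sketch that points the cache at a shared location. The path is just a placeholder: use any directory that both the driver and the workers can access, such as a mounted distributed filesystem:

```py
>>> ds = Dataset.from_spark(df, cache_dir="/mnt/shared/datasets_cache")
```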

<Tip warning={true}>

In a different session, a Spark DataFrame doesn't have the same [semantic hash](https://spark.apache.org/docs/3.2.0/api/python/reference/api/pyspark.sql.DataFrame.semanticHash.html), so [`Dataset.from_spark`] will rerun the Spark job and store the dataset in a new cache.

</Tip>

### Feature types

If your dataset is made of images, audio data or N-dimensional arrays, you can specify the `features=` argument in [`Dataset.from_spark`]:

```py
>>> from datasets import Dataset, Features, Image, Value
>>> data = [(0, open("image.png", "rb").read())]
>>> df = spark.createDataFrame(data, "idx: int, image: binary")
>>> # Also works if you have arrays
>>> # data = [(0, np.zeros(shape=(32, 32, 3), dtype=np.int32).tolist())]
>>> # df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
>>> features = Features({"idx": Value("int64"), "image": Image()})
>>> dataset = Dataset.from_spark(df, features=features)
>>> dataset[0]
{'idx': 0, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>}
```
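
If your column contains plain N-dimensional arrays that you don't want to decode as images, an [`Array3D`] feature type may be a better fit. This is a sketch under that assumption, using the array-based `df` from the comments above; the shape and dtype are illustrative:

```py
>>> from datasets import Array3D, Dataset, Features, Value
>>> features = Features({"idx": Value("int64"), "image": Array3D(shape=(32, 32, 3), dtype="int32")})
>>> dataset = Dataset.from_spark(df, features=features)
```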

You can check the [`Features`] documentation to learn about all the available feature types.
