# Use with Spark

This document is a quick introduction to using 🤗 Datasets with Spark, with a particular focus on how to load a Spark DataFrame into a [`Dataset`] object.

From there, you have fast access to any element and you can use it as a data loader to train models.

## Load from Spark

A [`Dataset`] object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to PyTorch, TensorFlow, and JAX tensors.
The Arrow table is memory-mapped from disk, so you can load datasets bigger than your available RAM.

You can get a [`Dataset`] from a Spark DataFrame using [`Dataset.from_spark`]:

```py
>>> from datasets import Dataset
>>> # assumes an existing SparkSession named `spark`
>>> df = spark.createDataFrame(
...     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
...     schema=["id", "name"],
... )
>>> ds = Dataset.from_spark(df)
```
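
For example, indexing the resulting dataset returns a row as a plain dictionary (output shown for the toy DataFrame above):

```py
>>> ds[0]
{'id': 1, 'name': 'Elia'}
```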

The Spark workers write the dataset on disk in a cache directory as Arrow files, and the [`Dataset`] is loaded from there.

### Caching

The resulting [`Dataset`] is cached; if you call [`Dataset.from_spark`] multiple times on the same DataFrame, it won't re-run the Spark job that writes the dataset as Arrow files on disk.

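In other words, only the first call pays for the Spark job:

```py
>>> ds = Dataset.from_spark(df)  # runs the Spark job and writes Arrow files to the cache
>>> ds = Dataset.from_spark(df)  # reloads from the cached Arrow files, no new Spark job
```
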
You can set the cache location by passing `cache_dir=` to [`Dataset.from_spark`].
Make sure to use a disk that is available to both your workers and your current machine (the driver).

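For example, on a cluster with shared storage you might point the cache at a common mount (the path below is hypothetical):

```py
>>> # "/mnt/shared" stands in for any path visible to both the driver and all workers
>>> ds = Dataset.from_spark(df, cache_dir="/mnt/shared/datasets_cache")
```
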
<Tip warning={true}>

In a different session, a Spark DataFrame doesn't have the same [semantic hash](https://spark.apache.org/docs/3.2.0/api/python/reference/api/pyspark.sql.DataFrame.semanticHash.html), so the Spark job will re-run and its output will be stored in a new cache.

</Tip>

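You can inspect this hash yourself with PySpark's `DataFrame.semanticHash` method:

```py
>>> df.semanticHash()  # an integer; per the tip above, it differs for the same DataFrame in a new session
```
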
### Feature types

If your dataset is made of images, audio data, or N-dimensional arrays, you can specify the `features=` argument in [`Dataset.from_spark`]:

```py
>>> from datasets import Dataset, Features, Image, Value
>>> data = [(0, open("image.png", "rb").read())]
>>> df = spark.createDataFrame(data, "idx: int, image: binary")
>>> # Also works if you have arrays
>>> # data = [(0, np.zeros(shape=(32, 32, 3), dtype=np.int32).tolist())]
>>> # df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
>>> features = Features({"idx": Value("int64"), "image": Image()})
>>> dataset = Dataset.from_spark(df, features=features)
>>> dataset[0]
{'idx': 0, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>}
```
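
Audio follows the same pattern; here is a minimal sketch, assuming a local `audio.wav` file (a hypothetical file name) and a 16 kHz sampling rate:

```py
>>> from datasets import Dataset, Features, Audio, Value
>>> data = [(0, open("audio.wav", "rb").read())]  # raw bytes of a hypothetical local file
>>> df = spark.createDataFrame(data, "idx: int, audio: binary")
>>> features = Features({"idx": Value("int64"), "audio": Audio(sampling_rate=16_000)})
>>> dataset = Dataset.from_spark(df, features=features)
```
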
You can check the [`Features`] documentation to learn about all the available feature types.