Skip to content

Commit

Permalink
Fix spark imports (huggingface#5795)
Browse files Browse the repository at this point in the history
fix spark imports
  • Loading branch information
lhoestq committed Apr 26, 2023
1 parent cf4a195 commit 5b011a2
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/datasets/packaged_modules/spark/spark.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import posixpath
import uuid
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union

import pyarrow as pa
import pyspark

import datasets
from datasets.arrow_writer import ArrowWriter
from datasets.arrow_writer import ArrowWriter, ParquetWriter
from datasets.config import MAX_SHARD_SIZE
from datasets.filesystems import (
is_remote_filesystem,
Expand All @@ -18,6 +18,9 @@

logger = datasets.utils.logging.get_logger(__name__)

if TYPE_CHECKING:
import pyspark


@dataclass
class SparkConfig(datasets.BuilderConfig):
Expand All @@ -31,10 +34,12 @@ class Spark(datasets.DatasetBuilder):

def __init__(
self,
df: pyspark.sql.DataFrame,
df: "pyspark.sql.DataFrame",
cache_dir: str = None,
**config_kwargs,
):
import pyspark

self._spark = pyspark.sql.SparkSession.builder.getOrCreate()
self.df = df
self._validate_cache_dir(cache_dir)
Expand Down Expand Up @@ -86,6 +91,8 @@ def _prepare_split_single(
file_format: str,
max_shard_size: int,
) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
import pyspark

writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"

Expand Down

0 comments on commit 5b011a2

Please sign in to comment.