Commit: Version 2.0

Tagar committed Feb 10, 2018
1 parent ca811ef commit 59b8fa7
Showing 2 changed files with 140 additions and 31 deletions.
169 changes: 139 additions & 30 deletions abalon/spark/sparkutils.py
@@ -27,28 +27,39 @@

###########################################################################################################

sparkutils_init_complete = False

def sparkutils_init (i_spark, i_debug=False):
def sparkutils_init (i_spark=None, i_debug=False):
'''
Initialize module-level variables
:param i_spark: an object of pyspark.sql.session.SparkSession
:param i_debug: enable debug output in the functions below?
'''

global sparkutils_init_complete
if sparkutils_init_complete: return

global spark, debug
(spark, debug) = (i_spark, i_debug)

if i_spark:
spark = i_spark
else:
assert 'spark' in main_ns.globals(), "'spark' variable doesn't exist in global namespace"
spark = main_ns.globals()['spark']

debug = i_debug

from pyspark.sql.session import SparkSession
if not isinstance(spark, SparkSession):
raise TypeError("spark parameter should be of type SparkSession")
raise TypeError("'spark' variable should be of type SparkSession")

global sc
sc = spark.sparkContext

global hadoop, conf, fs

hadoop = sc._jvm.org.apache.hadoop # Get a reference to org.apache.hadoop through py4j object
sc = spark.sparkContext

hadoop = sc._jvm.org.apache.hadoop # Get a reference to org.apache.hadoop through py4j interface

# Create Configuration object
# see - https://hadoop.apache.org/docs/r2.6.4/api/org/apache/hadoop/conf/Configuration.html
@@ -58,12 +69,43 @@ def sparkutils_init (i_spark, i_debug=False):
# see - https://hadoop.apache.org/docs/stable/api/org/apache/hadoop/fs/FileSystem.html#get(org.apache.hadoop.conf.Configuration)
fs = hadoop.fs.FileSystem.get(conf)

sparkutils_init_complete = True
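
For reference, a minimal usage sketch of the new optional-argument form (the app name is illustrative; assumes pyspark is installed and abalon is on the PYTHONPATH):

from pyspark.sql import SparkSession
from abalon.spark.sparkutils import sparkutils_init

spark = SparkSession.builder.appName('abalon-example').getOrCreate()
sparkutils_init(spark, i_debug=True)   # pass the session explicitly
# or, if a 'spark' variable already exists in the driver's global namespace
# (e.g. in a pyspark shell), simply call sparkutils_init()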


###########################################################################################################


from pyspark.sql.types import LongType, StructField, StructType


###
def dfPartitionSampler(df, percent_sample=5):
'''
Fast sampler. Great for huge datasets with lots of partitions.
Returns the first X partitions of the original dataframe df,
preserving partitioning and the order of data.
:param df: source dataframe
:param percent_sample: e.g., 5 means only the first 5% of partitions will be returned
:returns: new dataframe
'''

sparkutils_init()

full_partitions_num = df.rdd.getNumPartitions()
sample_partitions_num = int(full_partitions_num * percent_sample / 100)

def partition_sampler(partitionIndex, iterator):
if partitionIndex < sample_partitions_num:
return iterator
else:
return iter(())

rdd = df.rdd.mapPartitionsWithIndex(partition_sampler, preservesPartitioning=True)

return spark.createDataFrame(rdd, df.schema)
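
A hedged usage sketch (events_df is a hypothetical dataframe): keep roughly the first 10% of its partitions without shuffling any data.

sampled_df = dfPartitionSampler(events_df, percent_sample=10)
print(events_df.rdd.getNumPartitions(), sampled_df.rdd.getNumPartitions())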


###
def dfZipWithIndex(df, offset=1, colName="rowId"):
'''
@@ -75,6 +117,8 @@ def dfZipWithIndex(df, offset=1, colName="rowId"):
:param colName: name of the index column
'''

sparkutils_init()

new_schema = StructType(
[StructField(colName, LongType(), True)] # new added field in front
+ df.schema.fields # previous schema
@@ -86,10 +130,15 @@ def dfZipWithIndex(df, offset=1, colName="rowId"):

return spark.createDataFrame(new_rdd, new_schema)
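
A usage sketch (sales_df is a hypothetical dataframe): prepend a consecutive rowId column starting at 1 while keeping the original row order.

indexed_df = dfZipWithIndex(sales_df, offset=1, colName='rowId')
indexed_df.select('rowId').show(3)   # rowId values start at 1 and follow the original row order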


###
def file_to_df(df_name, file_path, header=True, delimiter='|', inferSchema=True, columns=None
, cache=False, zipWithIndex=None, partitions=None, cluster_by=None
, quote='"', escape='\\'
def file_to_df(df_name, file_path, delimiter='|', quote='"', escape='\\'
, header=True, inferSchema=True, columns=None
, where_clause=None, partitions_sampled_percent=None
, zipWithIndex=None
, partitions=None
, partition_by=None, sort_by=None, cluster_by=None
, cache=False
):
'''
Reads in a delimited file and sets up a Spark dataframe
@@ -104,15 +153,21 @@ def file_to_df(df_name, file_path, header=True, delimiter='|', inferSchema=True,
:param delimiter: one character
:param cache: boolean - cache this dataframe?
:param zipWithIndex: new column name to be assigned by zipWithIndex()
:param where_clause: filter condition, applied before zipWithIndex/repartition/sort by/cluster by etc.
:param partitions_sampled_percent: if specified, will be passed as percent_sample to dfPartitionSampler()
:param partitions: number - if specified, will repartition to that number
:param partition_by: if partitions is specified, this parameter controls which columns to partition by
:param sort_by: if specified, sort by that key within partitions (no shuffling)
:param cluster_by: string - list of columns (comma separated) to run CLUSTER BY on
:param quote: character - by default the quote character is ", but can be set to any character.
Delimiters inside quotes are ignored; set to '\0' to disable quoting
:param escape: character - by default the escape character is \, but can be set to any character.
Escaped quote characters are ignored
'''

df = (sqlc.read.format('csv')
sparkutils_init()

df = (spark.read.format('csv')
.option('header', header)
.option('delimiter', delimiter)
.option('inferSchema', inferSchema)
@@ -124,17 +179,34 @@ def file_to_df(df_name, file_path, header=True, delimiter='|', inferSchema=True,
if columns:
df = df.toDF(*columns)

if where_clause:
before_where_clause = df_name + '_before_where_clause'
df.registerTempTable(before_where_clause)
df = spark.sql("SELECT * FROM " + before_where_clause + " WHERE " + where_clause)
spark.catalog.dropTempView(before_where_clause)

if partitions_sampled_percent:
df = dfPartitionSampler(df, partitions_sampled_percent)

if zipWithIndex:
# zipWithIndex has to happen before repartition()
# zipWithIndex has to happen before repartition() or any other shuffling
df = dfZipWithIndex(df, colName=zipWithIndex)

if partitions:
df = df.repartition(partitions)
if partition_by:
df = df.repartition(partitions, partition_by)
else:
df = df.repartition(partitions)

if sort_by:
df = df.sortWithinPartitions(sort_by, ascending=True)

if cluster_by:
# "CLUSTER BY" is equivalent to repartition(x, cols) + sortWithinPartitions(cols)

before_cluster_by = df_name + '_before_cluster_by'
df.registerTempTable(before_cluster_by)
df = sqlc.sql("SELECT * FROM " + before_cluster_by + " CLUSTER BY " + cluster_by)
df = spark.sql("SELECT * FROM " + before_cluster_by + " CLUSTER BY " + cluster_by)
spark.catalog.dropTempView(before_cluster_by)

if cache:
@@ -145,22 +217,33 @@ def file_to_df(df_name, file_path, header=True, delimiter='|', inferSchema=True,

return df
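
A usage sketch of the expanded signature (the file path, column names, and filter are all illustrative):

orders = file_to_df('orders', '/data/raw/orders.psv',
                    delimiter='|', header=True,
                    where_clause="status = 'SHIPPED'",        # applied before any shuffling
                    zipWithIndex='rowId',                     # row ids assigned before repartition
                    partitions=200, partition_by='customer_id', sort_by='order_date',
                    cache=True)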


###
def sql_to_df(df_name, sql, cache=False, partitions=None):
def sql_to_df(df_name, sql, cache=False, partitions=None, partition_by=None, sort_by=None):
'''
Runs a Spark SQL query and sets up a Spark dataframe
:param df_name: registers this dataframe as a tempTable/view for SQL access;
important: it also registers a global variable under that name
:param sql: Spark SQL query to run
:param partitions: number - if specified, will repartition to that number
:param partition_by: if partitions is specified, this parameter controls which columns to partition by
:param sort_by: if specified, sort by that key within partitions (no shuffling)
:param cache: cache this dataframe?
'''

df = sqlc.sql(sql)
sparkutils_init()

df = spark.sql(sql)

if partitions:
df = df.repartition(partitions)
if partition_by:
df = df.repartition(partitions, partition_by)
else:
df = df.repartition(partitions)

if sort_by:
df = df.sortWithinPartitions(sort_by, ascending=True)

if cache:
df = df.cache()
@@ -172,32 +255,51 @@ def sql_to_df(df_name, sql, cache=False, partitions=None):
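
A usage sketch for sql_to_df (assumes an 'orders' view is already registered, e.g. by the file_to_df call above; column names are illustrative):

daily_totals = sql_to_df('daily_totals',
                         "SELECT order_date, SUM(amount) AS total FROM orders GROUP BY order_date",
                         partitions=50, sort_by='order_date', cache=True)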


###########################################################################################################
## below three are not implemented yet
## Basic wrappers around certain hadoop.fs.FileSystem API calls:
## https://hadoop.apache.org/docs/r2.8.2/api/org/apache/hadoop/fs/FileSystem.html

def HDFSfileExists (file_path):
def hdfs_exists (file_path):
'''
Returns True if HDFS file exists
:param file_path: file path
:return: boolean
'''
return True

def dropHDFSfile (file_path):
sparkutils_init()

# https://hadoop.apache.org/docs/r2.8.2/api/org/apache/hadoop/fs/FileSystem.html#exists(org.apache.hadoop.fs.Path)
return fs.exists(hadoop.fs.Path(file_path))


def hdfs_drop (file_path, recursive=True):
'''
Drop HDFS file
Drop HDFS file/dir
:param file_path: HDFS file path
:param file_path: HDFS file/directory path
:param recursive: drop subdirectories too
'''

def renameHDFSfile (src_name, dst_name):
sparkutils_init()

# https://hadoop.apache.org/docs/r2.8.2/api/org/apache/hadoop/fs/FileSystem.html#delete(org.apache.hadoop.fs.Path,%20boolean)
return fs.delete(hadoop.fs.Path(file_path), recursive)


def hdfs_rename (src_name, dst_name):
'''
Renames src file to dst file name
:param src_name: source name
:param dst_name: target name
'''

sparkutils_init()

# https://hadoop.apache.org/docs/r2.8.2/api/org/apache/hadoop/fs/FileSystem.html#rename(org.apache.hadoop.fs.Path,%20org.apache.hadoop.fs.Path)
return fs.rename(hadoop.fs.Path(src_name), hadoop.fs.Path(dst_name))
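
These thin wrappers all go through the same py4j FileSystem handle set up in sparkutils_init(). A sketch with illustrative paths:

if hdfs_exists('/tmp/demo/output.csv'):
    hdfs_drop('/tmp/demo/output.csv')
hdfs_rename('/tmp/demo/output.csv.tmp', '/tmp/demo/output.csv')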


###########################################################################################################

def HDFScopyMerge (src_dir, dst_file, overwrite=False, deleteSource=False):
@@ -209,10 +311,12 @@ def HDFScopyMerge (src_dir, dst_file, overwrite=False, deleteSource=False):
:param src_dir: source directory to get files from
:param dst_file: destination file to merge the files into
:param overwrite: overwrite destination file if already exists?
:param overwrite: overwrite destination file if it already exists? this also overwrites the temp file if it exists
:param deleteSource: drop source directory after merge is complete
"""

sparkutils_init()

def debug_print (message):
if debug:
print("HDFScopyMerge(): " + message)
@@ -227,8 +331,9 @@ def debug_print (message):
# determine order of files in which they will be written:
files.sort(key=lambda f: str(f))

if not overwrite and HDFSfileExists(dst_file):
dropHDFSfile(dst_file)
if overwrite and hdfs_exists(dst_file):
hdfs_drop(dst_file)
debug_print("Target file {} dropped".format(dst_file))

# use temp file for the duration of the merge operation
dst_file_tmp = "{}.IN_PROGRESS.tmp".format(dst_file)
@@ -250,14 +355,14 @@ def debug_print (message):
out_stream.close()

if deleteSource:
fs.delete(hadoop.fs.Path(src_dir), True) # True=recursive
hdfs_drop(src_dir)
debug_print("Source directory {} removed.".format(src_dir))

try:
renameHDFSfile(dst_file_tmp, dst_file)
hdfs_rename(dst_file_tmp, dst_file)
debug_print("Temp file renamed to {}".format(dst_file))
except:
dropHDFSfile(dst_file_tmp) # drop temp file if we can't rename it to target name
hdfs_drop(dst_file_tmp) # drop temp file if we can't rename it to target name
raise
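
A usage sketch (paths are illustrative): merge the part-files of a directory into a single CSV and remove the source directory afterwards.

HDFScopyMerge('/data/out/report.csv.tmpdir',   # directory holding part-0000* files
              '/data/out/report.csv',
              overwrite=True, deleteSource=True)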


Expand All @@ -273,6 +378,8 @@ def HDFSwriteString (dst_file, content, overwrite=True, appendEOL=True):
:param appendEOL: append new line character?
"""

sparkutils_init()

out_stream = fs.create(hadoop.fs.Path(dst_file), overwrite)

if appendEOL:
@@ -309,7 +416,9 @@ def dataframeToHDFSfile (dataframe, dst_file, overwrite=False
:param quoteMode: https://commons.apache.org/proper/commons-csv/apidocs/org/apache/commons/csv/QuoteMode.html
"""

if not overwrite and HDFSfileExists(dst_file):
sparkutils_init()

if not overwrite and hdfs_exists(dst_file):
raise ValueError("Target file {} already exists and Overwrite is not requested".format(dst_file))

dst_dir = dst_file + '.tmpdir'
2 changes: 1 addition & 1 deletion abalon/version.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

version = '1.6.0'
version = '2.0.0'
