Commit

fix __new__
Tagar committed Feb 18, 2018
1 parent 90af0e0 commit 1cdf77d
Showing 2 changed files with 18 additions and 14 deletions.
29 changes: 16 additions & 13 deletions abalon/spark/pivoter.py
@@ -46,12 +46,12 @@
 
 ###########################################################################################################
 
-class BasicSparkPivoter:
+from pyspark.sql.types import *
 
 
-    def __init__(self):
-        pass
+class BasicSparkPivoter:
 
-    def __new__ (df, idx_col, all_vars=None):
+    def __new__ (cls, df, idx_col=None, all_vars=None):
         '''
         Pivots a dataframe without aggregation.
@@ -60,13 +60,15 @@ def __new__ (df, idx_col, all_vars=None):
         (no aggregation happens for values - use AggSparkPivoter instead if that is needed)
 
         :param df: dataframe to pivot (see expected schema of the df above)
-        :param idx_col: name of the index column
+        :param idx_col: name of the index column; if not specified, it will be taken from df
         :param all_vars: list of all distinct values of `colname` column;
                 the only reason it's passed to this function is so you can redefine the order of pivoted columns;
                 if not specified, the dataset will be scanned for all possible colnames
         :return: resulting dataframe
         '''
 
+        self = super(BasicSparkPivoter, cls).__new__(cls)
+
         return self.pivot_df(df, idx_col, all_vars)
 
     def merge_two_dicts(x, y):
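Note on usage (not part of the diff): `BasicSparkPivoter` is written as a factory, so `__new__` builds an instance and returns the pivoted DataFrame rather than the instance itself. A minimal sketch of a call, assuming an active SparkSession named `spark` and illustrative column names `id`, `colname`, `value` that are not taken from this commit:

```python
# Hedged sketch - column names and data are illustrative assumptions.
long_df = spark.createDataFrame(
    [('u1', 'age', 34.0), ('u1', 'height', 180.0), ('u2', 'age', 27.0)],
    ['id', 'colname', 'value'])

# Because __new__ returns the result of pivot_df, this call yields a DataFrame,
# not a BasicSparkPivoter instance.
wide_df = BasicSparkPivoter(long_df, idx_col='id')
# Expected shape: one row per id, one DoubleType column per distinct colname.
```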
@@ -86,18 +88,21 @@ def pivot_df (df, idx_col, all_vars):
         if not all_vars:
             # get list of variables from the dataset:
             all_vars = sorted([row[0] for row in df.rdd.map(lambda (idx, k, v): k).distinct().collect()])
 
+        self.all_vars = all_vars
+
+        if not idx_col:
+            idx_col = df.columns[1]  # take 2nd column name
 
         pivoted_rdd = (df.rdd
                        .map(lambda (idx, k, v): (idx, {k: v}))     # convert k,v to a 1-element dict
                        .reduceByKey(self.merge_two_dicts)          # merge into a single dict for all vars for this idx
                        .map(lambda (idx, d): list(self.map_dict_to_denseArray(idx, d)))
                                                                    # create final rdd with dense array of all variables
                        )
 
-        fields = [StructField(idx_col, StringType(), False)]
+        fields = [StructField(idx_col, StringType(), False)]
         fields += [StructField(field_name, DoubleType(), True) for field_name in self.all_vars]
 
         schema = StructType(fields)
 
         pivoted_df = spark.createDataFrame(pivoted_rdd, schema)
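For comparison only (this is not what the commit does), the same first-value pivot can be sketched with the stock Spark 2.x DataFrame API; the column names follow the assumptions above:

```python
from pyspark.sql import functions as F

# groupBy/pivot keeps one value per (id, colname) cell via first(),
# mirroring the no-aggregation behaviour of BasicSparkPivoter.
wide_df = (long_df
           .groupBy('id')
           .pivot('colname')   # an explicit list of values can be passed to fix column order
           .agg(F.first('value')))
```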
@@ -110,23 +115,21 @@ def pivot_df (df, idx_col, all_vars):
 
 class AggSparkPivoter (BasicSparkPivoter):
 
-    def __init__(self):
-        BasicSparkPivoter.__init__(self)
-        pass
-
-    def __new__ (df, idx_col, all_vars=None, agg_op=operator.add):
+    def __new__ (cls, df, idx_col=None, all_vars=None, agg_op=operator.add):
         '''
         Pivots a dataframe with aggregation.
 
         :param df: dataframe to pivot (see expected schema of the df above)
-        :param idx_col: name of the index column
+        :param idx_col: name of the index column; if not specified, it will be taken from df
         :param all_vars: list of all distinct values of `colname` column;
                 the only reason it's passed to this function is so you can redefine the order of pivoted columns;
                 if not specified, the dataset will be scanned for all possible colnames
         :param agg_op: aggregation operation/function, defaults to `add`
         :return: resulting dataframe
         '''
 
+        self = super(AggSparkPivoter, cls).__new__(cls)
+
+        self.agg_op = agg_op
+
         return self.pivot_df(df, idx_col, all_vars)
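A hedged usage sketch for the aggregating variant, assuming duplicate (id, colname) pairs are combined with `agg_op` (the aggregation itself happens in code below the visible part of this diff); `spark` and the column names remain illustrative assumptions:

```python
import operator

dup_df = spark.createDataFrame(
    [('u1', 'clicks', 1.0), ('u1', 'clicks', 2.0), ('u2', 'clicks', 5.0)],
    ['id', 'colname', 'value'])

# agg_op defaults to operator.add, so duplicate cells would be summed:
summed_df = AggSparkPivoter(dup_df, idx_col='id', agg_op=operator.add)
# Expected: u1 -> clicks = 3.0, u2 -> clicks = 5.0
```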
3 changes: 2 additions & 1 deletion abalon/version.py
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-version = '2.1.1'
+version = '2.1.2'

