Skip to content

Commit

Permalink
chore: convert linear interpolation to use spark native functions
Browse files Browse the repository at this point in the history
  • Loading branch information
guanjieshen committed Dec 23, 2021
1 parent 32220b0 commit 2fc47fc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
45 changes: 22 additions & 23 deletions python/tempo/interpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
last,
lit,
to_timestamp,
udf,
unix_timestamp,
when,
)
from pyspark.sql.types import DoubleType, FloatType
from pyspark.sql.window import Window

from tempo import *
Expand All @@ -25,18 +23,11 @@


class Interpolation:
def __init__(self):
"""
Constructor
:linear_udf register linear calculation UDF
"""
self.__linear_udf = udf(Interpolation.__calc_linear, DoubleType())

@staticmethod
def __calc_linear(epoch, epoch_ff, epoch_bf, value_ff, value_bf, value):
def __calc_linear_spark(
self, df: DataFrame, epoch, epoch_ff, epoch_bf, value_ff, value_bf, value
):
"""
User defined function for calculating linear interpolation on a DataFrame.
Native Spark function for calculating linear interpolation on a DataFrame.
:param epoch - Original epoch timestamp of the column to be interpolated.
:param epoch_ff - Forward filled epoch timestamp of the column to be interpolated.
Expand All @@ -45,12 +36,20 @@ def __calc_linear(epoch, epoch_ff, epoch_bf, value_ff, value_bf, value):
:param value_bf - Backfilled value of the column to be interpolated.
:param value - Original value of the column to be interpolated.
"""
if epoch_bf == epoch_ff:
return value
else:
m = (value_ff - value_bf) / (epoch_ff - epoch_bf)
value_linear = value_bf + m * (epoch - epoch_bf)
return value_linear
cols: List[str] = df.columns
cols.remove(value)
expr: str = f"""
case when {value_bf} = {value_ff} then {value}
else
({value_ff}-{value_bf})
/({epoch_ff}-{epoch_bf})
*({epoch}-{epoch_bf})
+ {value_bf}
end as {value}
"""
interpolated: DataFrame = df.selectExpr(*cols, expr)
# Preserve column order
return interpolated.select(*df.columns)

# TODO: Currently not being used. But will be useful for interpolating arbitrary ranges.
def get_time_range(self, df: DataFrame, ts_col: str) -> Tuple[str]:
Expand Down Expand Up @@ -213,16 +212,16 @@ def __interpolate_column(

# Handle linear fill
if fill == "linear":
output_df = output_df.withColumn(
target_col,
self.__linear_udf(
output_df = output_df.transform(
lambda df: self.__calc_linear_spark(
df,
ts_col,
"readtime_ff",
"readtime_bf",
"readvalue_ff",
"readvalue_bf",
target_col,
),
)
)

return output_df
Expand Down
2 changes: 1 addition & 1 deletion python/tests/interpol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def test_forward_fill_interpolation(self):
assert_df_equality(expected_df, actual_df)

def test_linear_fill_interpolation(self):
"""Test forward fill interpolation."""
"""Test linear fill interpolation."""
self.buildTestingDataFrame()

expected_schema = StructType(
Expand Down

0 comments on commit 2fc47fc

Please sign in to comment.