Skip to content

Commit

Permalink
interpolation: add the ability to call interpolate after resample
Browse files Browse the repository at this point in the history
  • Loading branch information
guanjieshen committed Jan 18, 2022
1 parent 015879b commit f871c2f
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 73 deletions.
25 changes: 16 additions & 9 deletions python/tempo/interpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@


class Interpolation:
def __init__(self, is_resampled: bool):
self.is_resampled = is_resampled

def __validate_fill(self, method: str):
"""
Validate if the fill provided is within the allowed list of values.
Expand Down Expand Up @@ -145,7 +148,7 @@ def __interpolate_column(
col(f"previous_{target_col}"),
).otherwise(col(target_col)),
)
# Handle backwards fill
# Handle backwards fill
if method == "bfill":
output_df = output_df.withColumn(
target_col,
Expand Down Expand Up @@ -246,8 +249,7 @@ def __generate_target_fill(
.orderBy(ts_col)
.rowsBetween(0, sys.maxsize)
),
)
.withColumn(
).withColumn(
f"next_{target_col}",
lead(df[target_col]).over(
Window.partitionBy(*partition_cols).orderBy(ts_col)
Expand Down Expand Up @@ -277,20 +279,25 @@ def interpolate(
:param func - aggregate function used for sampling to the specified interval
:param method - interpolation function usded to fill missing values
:param show_interpolated - show if row is interpolated?
:return DataFrame
:return DataFrame
"""
# Validate input parameters
self.__validate_fill(method)
self.__validate_col(tsdf.df, partition_cols, target_cols, ts_col)

# Resample and Normalize Input
resampled_input: DataFrame = tsdf.resample(
freq=freq, func=func, metricCols=target_cols
).df
# Only select required columns for interpolation
input_cols: List[str] = [*partition_cols, ts_col, *target_cols]
sampled_input: DataFrame = tsdf.df.select(*input_cols)

if self.is_resampled is False:
# Resample and Normalize Input
sampled_input: DataFrame = tsdf.resample(
freq=freq, func=func, metricCols=target_cols
).df

# Fill timeseries for nearest values
time_series_filled = self.__generate_time_series_fill(
resampled_input, partition_cols, ts_col
sampled_input, partition_cols, ts_col
)

# Generate surrogate timestamps for each target column
Expand Down
2 changes: 1 addition & 1 deletion python/tempo/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def aggregate(tsdf, freq, func, metricCols = None, prefix = None, fill = None):
if fill:
res = imputes.join(res, tsdf.partitionCols + [tsdf.ts_col], "leftouter").na.fill(0, metrics)

return(tempo.TSDF(res, ts_col = tsdf.ts_col, partition_cols = tsdf.partitionCols))
return res


def checkAllowableFreq(tsdf, freq):
Expand Down
51 changes: 47 additions & 4 deletions python/tempo/tsdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ def resample(self, freq, func=None, metricCols = None, prefix=None, fill = None)
:return: TSDF object with sample data using aggregate function
"""
rs.validateFuncExists(func)
enriched_tsdf = rs.aggregate(self, freq, func, metricCols, prefix, fill)
return(enriched_tsdf)
enriched_df = rs.aggregate(self, freq, func, metricCols, prefix, fill)
return (_ResampledTSDF(enriched_df, ts_col = self.ts_col, partition_cols = self.partitionCols, freq = freq, func = func))

def interpolate(self, freq: str, func: str, method: str, target_cols: List[str] = None,ts_col: str = None, partition_cols: List[str]=None, show_interpolated:bool = False):
"""
Expand Down Expand Up @@ -615,9 +615,9 @@ def interpolate(self, freq: str, func: str, method: str, target_cols: List[str]
((datatype[1] in summarizable_types) and
(datatype[0].lower() not in prohibited_cols))]

interpolate_service: Interpolation = Interpolation()
interpolate_service: Interpolation = Interpolation(is_resampled=False)
tsdf_input = TSDF(self.df, ts_col = ts_col, partition_cols=partition_cols)
interpolated_df:DataFrame = interpolate_service.interpolate(tsdf_input,ts_col, partition_cols,target_cols, freq, func, method, show_interpolated)
interpolated_df = interpolate_service.interpolate(tsdf_input,ts_col, partition_cols,target_cols, freq, func, method, show_interpolated)

return TSDF(interpolated_df, ts_col = ts_col, partition_cols=partition_cols)

Expand Down Expand Up @@ -711,3 +711,46 @@ def tempo_fourier_util(pdf):
result = result.drop("tdval", "tpoints")

return TSDF(result, self.ts_col, self.partitionCols, self.sequence_col)


class _ResampledTSDF(TSDF):
def __init__(self, df, ts_col="event_ts", partition_cols=None, sequence_col = None, freq = None, func = None):
super(_ResampledTSDF, self).__init__(df, ts_col, partition_cols, sequence_col)
self.__freq = freq
self.__func = func

def interpolate(self, method: str, target_cols: List[str] = None, show_interpolated:bool = False):
"""
function to interpolate based on frequency, aggregation, and fill similar to pandas. Data will first be aggregated using resample, then missing values
will be filled based on the fill calculation.
:param method: function used to fill missing values e.g. linear, null, zero, bfill, ffill
:param target_cols [optional]: columns that should be interpolated, by default interpolates all numeric columns
:param show_interpolated [optional]: if true will include an additional column to show which rows have been fully interpolated.
:return: new TSDF object containing interpolated data
"""

# Set defaults for target columns, timestamp column and partition columns when not provided
if target_cols is None:
prohibited_cols: List[str] = self.partitionCols + [self.ts_col]
summarizable_types = ['int', 'bigint', 'float', 'double']

# get summarizable find summarizable columns
target_cols:List[str] = [datatype[0] for datatype in self.df.dtypes if
((datatype[1] in summarizable_types) and
(datatype[0].lower() not in prohibited_cols))]

interpolate_service: Interpolation = Interpolation(is_resampled=True)
tsdf_input = TSDF(self.df, ts_col = self.ts_col, partition_cols=self.partitionCols)
interpolated_df = interpolate_service.interpolate(
tsdf=tsdf_input,
ts_col=self.ts_col,
partition_cols=self.partitionCols,
target_cols=target_cols,
freq=self.__freq,
func=self.__func,
method=method,
show_interpolated=show_interpolated,
)

return TSDF(interpolated_df, ts_col = self.ts_col, partition_cols=self.partitionCols)
156 changes: 97 additions & 59 deletions python/tests/interpol_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class InterpolationTest(SparkTest):
def buildTestingDataFrame(self):
def buildTestingDataFrame(self):
schema = StructType(
[
StructField("partition_a", StringType()),
Expand Down Expand Up @@ -70,10 +70,10 @@ def buildTestingDataFrame(self):
)

# register interpolation helper
self.interpolate_helper = Interpolation()
self.interpolate_helper = Interpolation(is_resampled=False)


class InterpolationUnitTest(InterpolationTest):

def test_fill_validation(self):
"""Test fill parameter is valid."""
self.buildTestingDataFrame()
Expand Down Expand Up @@ -399,15 +399,15 @@ def test_show_interpolated(self):

assert_df_equality(expected_df, actual_df, ignore_nullable=True)


class InterpolationIntegrationTest(InterpolationTest):
def test_interpolation_using_default_tsdf_params(self):
"""
Verify that interpolate uses the ts_col and partition_col from TSDF if not explicitly specified,
Verify that interpolate uses the ts_col and partition_col from TSDF if not explicitly specified,
and all columns numeric are automatically interpolated if target_col is not specified.
"""
self.buildTestingDataFrame()


expected_data = [
["A", "A-1", "2020-01-01 00:00:00", 0.0, None],
["A", "A-1", "2020-01-01 00:00:30", 1.0, None],
Expand Down Expand Up @@ -436,72 +436,110 @@ def test_interpolation_using_default_tsdf_params(self):
expected_df: DataFrame = self.buildTestDF(expected_schema, expected_data)

actual_df: DataFrame = self.simple_input_tsdf.interpolate(
freq="30 seconds",
func="mean",
method="linear"
freq="30 seconds", func="mean", method="linear"
).df

assert_df_equality(expected_df, actual_df, ignore_nullable=True)

def test_interpolation_using_custom_params(self):
"""Verify that by specifying optional paramters it will change the result of the interpolation based on those modified params."""
self.buildTestingDataFrame()

expected_data = [
["A", "A-1", "2020-01-01 00:00:00", 0.0, False, False],
["A", "A-1", "2020-01-01 00:00:30", 1.0, True, True],
["A", "A-1", "2020-01-01 00:01:00", 2.0, False, False],
["A", "A-1", "2020-01-01 00:01:30", 3.0, False, True],
["A", "A-1", "2020-01-01 00:02:00", 4.0, False, True],
["A", "A-1", "2020-01-01 00:02:30", 5.0, True, True],
["A", "A-1", "2020-01-01 00:03:00", 6.0, True, True],
["A", "A-1", "2020-01-01 00:03:30", 7.0, False, True],
["A", "A-1", "2020-01-01 00:04:00", 8.0, False, False],
["A", "A-1", "2020-01-01 00:04:30", 9.0, True, True],
["A", "A-1", "2020-01-01 00:05:00", 10.0, True, True],
["A", "A-1", "2020-01-01 00:05:30", 11.0, False, False],
]
"""Verify that by specifying optional paramters it will change the result of the interpolation based on those modified params."""
self.buildTestingDataFrame()

expected_data = [
["A", "A-1", "2020-01-01 00:00:00", 0.0, False, False],
["A", "A-1", "2020-01-01 00:00:30", 1.0, True, True],
["A", "A-1", "2020-01-01 00:01:00", 2.0, False, False],
["A", "A-1", "2020-01-01 00:01:30", 3.0, False, True],
["A", "A-1", "2020-01-01 00:02:00", 4.0, False, True],
["A", "A-1", "2020-01-01 00:02:30", 5.0, True, True],
["A", "A-1", "2020-01-01 00:03:00", 6.0, True, True],
["A", "A-1", "2020-01-01 00:03:30", 7.0, False, True],
["A", "A-1", "2020-01-01 00:04:00", 8.0, False, False],
["A", "A-1", "2020-01-01 00:04:30", 9.0, True, True],
["A", "A-1", "2020-01-01 00:05:00", 10.0, True, True],
["A", "A-1", "2020-01-01 00:05:30", 11.0, False, False],
]

expected_schema = StructType(
[
StructField("partition_a", StringType()),
StructField("partition_b", StringType()),
StructField("event_ts", StringType(), False),
StructField("value_a", DoubleType()),
StructField("is_ts_interpolated", BooleanType(), False),
StructField("is_interpolated_value_a", BooleanType(), False),
]
)
expected_schema = StructType(
[
StructField("partition_a", StringType()),
StructField("partition_b", StringType()),
StructField("event_ts", StringType(), False),
StructField("value_a", DoubleType()),
StructField("is_ts_interpolated", BooleanType(), False),
StructField("is_interpolated_value_a", BooleanType(), False),
]
)

expected_df: DataFrame = self.buildTestDF(expected_schema, expected_data)
expected_df: DataFrame = self.buildTestDF(expected_schema, expected_data)

actual_df: DataFrame = self.simple_input_tsdf.interpolate(
ts_col="event_ts",
show_interpolated=True,
partition_cols=["partition_a", "partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
method="linear"
).df
actual_df: DataFrame = self.simple_input_tsdf.interpolate(
ts_col="event_ts",
show_interpolated=True,
partition_cols=["partition_a", "partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
method="linear",
).df

assert_df_equality(expected_df, actual_df, ignore_nullable=True)
assert_df_equality(expected_df, actual_df, ignore_nullable=True)

def test_tsdf_constructor_params_are_updated(self):
"""Verify that resulting TSDF class has the correct values for ts_col and partition_col based on the interpolation."""
self.buildTestingDataFrame()
"""Verify that resulting TSDF class has the correct values for ts_col and partition_col based on the interpolation."""
self.buildTestingDataFrame()

actual_tsdf:TSDF = self.simple_input_tsdf.interpolate(
ts_col="event_ts",
show_interpolated=True,
partition_cols=["partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
method="linear"
)
actual_tsdf: TSDF = self.simple_input_tsdf.interpolate(
ts_col="event_ts",
show_interpolated=True,
partition_cols=["partition_b"],
target_cols=["value_a"],
freq="30 seconds",
func="mean",
method="linear",
)

self.assertEqual(actual_tsdf.ts_col, "event_ts")
self.assertEqual(actual_tsdf.partitionCols, ["partition_b"])

def test_interpolation_on_sampled_data(self):
"""Verify interpolation can be chained with resample within the TSDF class"""
self.buildTestingDataFrame()

expected_data = [
["A", "A-1", "2020-01-01 00:00:00", 0.0, False, False],
["A", "A-1", "2020-01-01 00:00:30", 1.0, True, True],
["A", "A-1", "2020-01-01 00:01:00", 2.0, False, False],
["A", "A-1", "2020-01-01 00:01:30", 3.0, False, True],
["A", "A-1", "2020-01-01 00:02:00", 4.0, False, True],
["A", "A-1", "2020-01-01 00:02:30", 5.0, True, True],
["A", "A-1", "2020-01-01 00:03:00", 6.0, True, True],
["A", "A-1", "2020-01-01 00:03:30", 7.0, False, True],
["A", "A-1", "2020-01-01 00:04:00", 8.0, False, False],
["A", "A-1", "2020-01-01 00:04:30", 9.0, True, True],
["A", "A-1", "2020-01-01 00:05:00", 10.0, True, True],
["A", "A-1", "2020-01-01 00:05:30", 11.0, False, False],
]

self.assertEqual(actual_tsdf.ts_col , "event_ts")
self.assertEqual(actual_tsdf.partitionCols ,["partition_b"])
expected_schema = StructType(
[
StructField("partition_a", StringType()),
StructField("partition_b", StringType()),
StructField("event_ts", StringType(), False),
StructField("value_a", DoubleType()),
StructField("is_ts_interpolated", BooleanType(), False),
StructField("is_interpolated_value_a", BooleanType(), False),
]
)

expected_df: DataFrame = self.buildTestDF(expected_schema, expected_data)

actual_df: DataFrame = (
self.simple_input_tsdf.resample(freq="30 seconds", func="mean", fill=None)
.interpolate(
method="linear", target_cols=["value_a"], show_interpolated=True
)
.df
)

assert_df_equality(expected_df, actual_df, ignore_nullable=True)

0 comments on commit f871c2f

Please sign in to comment.