Clean modular torch training scripts
AhmetZamanis committed Mar 9, 2023
1 parent 51d0dd8 commit abfaa34
Showing 7 changed files with 223 additions and 229 deletions.
2 changes: 0 additions & 2 deletions DLinearTrain.py → DLinearTrainDisagg.py
@@ -17,7 +17,6 @@
model_summary = RichModelSummary(max_depth = -1)



# Specify D-Linear model
model_dlinear = DLinear(
input_chunk_length = 30,
@@ -42,7 +41,6 @@

# All covariates, future & past
dlinear2_futcovars = ['tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday', 'day_sin', 'day_cos', "month_sin", "month_cos", 'oil', 'oil_ma28', 'onpromotion', 'onp_ma28', 'local_holiday', 'regional_holiday', 'national_holiday', 'ny1', 'ny2', 'ny_eve31', 'ny_eve30', 'xmas_before', 'xmas_after', 'quake_after', 'dia_madre', 'futbol', 'black_friday', 'cyber_monday']

dlinear2_pastcovars = ["sales_ema7", "transactions", "trns_ma7"]


2 changes: 2 additions & 0 deletions DataPrepDisagg.py
@@ -1,3 +1,5 @@


# Initialize list of disagg covariates
disagg_covars = []

51 changes: 19 additions & 32 deletions DataPrepMain.py
@@ -1,12 +1,28 @@
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import tensorboard
from tqdm import tqdm

# Import transformers
from sklearn.preprocessing import StandardScaler
from darts.dataprocessing.transformers import Scaler
from darts.dataprocessing.transformers import MissingValuesFiller
from sktime.transformations.series.difference import Differencer

# Import forecasting models
from darts.models.forecasting.dlinear import DLinearModel as DLinear
from darts.models.forecasting.rnn_model import RNNModel as RNN
from darts.models.forecasting.tft_model import TFTModel

# Import utils
from darts import TimeSeries
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from itertools import product
from functools import reduce


# Set printing options
np.set_printoptions(suppress=True, precision=4)
pd.options.display.float_format = '{:.4f}'.format
@@ -230,17 +246,12 @@
category_store_nbr = category_store_nbr.pivot(columns="category_store_nbr", values="sales")

# Merge all wide dataframes
from functools import reduce
wide_frames = [total, store_nbr, category_store_nbr]
df_sales = reduce(lambda left, right: pd.merge(
left, right, how="left", on="date"), wide_frames)
df_sales = df_sales.rename(columns = {"sales":"TOTAL"})
del total, store_nbr, wide_frames, category_store_nbr


from darts import TimeSeries
from itertools import product

# Create multivariate time series with sales components
ts_sales = TimeSeries.from_dataframe(df_sales, freq="D")

@@ -268,13 +279,12 @@


# Fill gaps
from darts.dataprocessing.transformers import MissingValuesFiller
na_filler = MissingValuesFiller()
ts_sales = na_filler.transform(ts_sales)


# Create differencer
from sktime.transformations.series.difference import Differencer

diff = Differencer(lags = 1)


@@ -310,9 +320,6 @@
total_covars1 = ts_totalcovars1.pd_dataframe()


from darts.utils.timeseries_generation import datetime_attribute_timeseries


# Retrieve copy of total_covars1, drop Fourier terms, trend knot (leaving daily predictors common to all categories).
common_covars = total_covars1[total_covars1.columns[2:21].values.tolist()]

@@ -332,23 +339,3 @@



# Import transformers
from sklearn.preprocessing import StandardScaler
from darts.dataprocessing.transformers import Scaler

# Import baseline models
from darts.models.forecasting.baselines import NaiveDrift, NaiveSeasonal
from darts.models.forecasting.sf_ets import StatsForecastETS as ETS

# Import forecasting models
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.auto_arima import AutoARIMA
from darts.models.forecasting.random_forest import RandomForest
from darts.models.forecasting.xgboost import XGBModel
from darts.models.forecasting.dlinear import DLinearModel as DLinear
from darts.models.forecasting.rnn_model import RNNModel as RNN

# Import time decomposition functions
from darts.utils.statistics import extract_trend_and_seasonality as decomposition
from darts.utils.statistics import remove_from_series
from darts.utils.utils import ModelMode, SeasonalityMode
2 changes: 0 additions & 2 deletions DataPrepStore.py
@@ -1,7 +1,5 @@




# Initialize list of store covariates
store_covars = []

4 changes: 4 additions & 0 deletions JunkCode.py
@@ -1,3 +1,7 @@
exec(open("test2.py").read())



Sys.setenv(QUARTO_PYTHON="./venv/Scripts/python.exe")

print(np.isnan(series1.values()).sum())
198 changes: 198 additions & 0 deletions TFTTrainStore.py
@@ -0,0 +1,198 @@

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import RichProgressBar, RichModelSummary

# Create early stopper
early_stopper = EarlyStopping(
monitor = "val_loss",
min_delta = 5000, # 1% of min. MSE of best model so far
patience = 10
)
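# (since min_delta is ~1% of the best validation MSE so far, that MSE is
# roughly 500,000; improvements smaller than 5000 count as no improvement
# and therefore do not reset the patience counter)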

# Progress bar
progress_bar = RichProgressBar()

# Rich model summary
model_summary = RichModelSummary(max_depth = -1)
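
# A minimal sketch, assuming all three callbacks are meant to be surfaced
# during training: they could be passed together to the model through the
# Lightning trainer kwargs instead of early_stopper alone, e.g.
# pl_trainer_kwargs = {
#     "callbacks": [early_stopper, progress_bar, model_summary],
#     "accelerator": "gpu",
#     "devices": [0]
# }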


# # Specify TFT model 2.0 (TFT specific params all default)
# model_tft = TFTModel(
# input_chunk_length = 30,
# output_chunk_length = 15,
# hidden_size = 16,
# lstm_layers = 1,
# num_attention_heads = 4,
# dropout = 0.1,
# hidden_continuous_size = 8,
# batch_size = 32,
# n_epochs = 500,
# likelihood = None,
# loss_fn = torch.nn.MSELoss(),
# model_name = "TFTStoreX",
# log_tensorboard = True,
# save_checkpoints = True,
# show_warnings = True,
# optimizer_kwargs = {"lr": 0.002},
# lr_scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau,
# lr_scheduler_kwargs = {"patience": 5},
# pl_trainer_kwargs = {
# "callbacks": [early_stopper],
# "accelerator": "gpu",
# "devices": [0]
# }
# )


# # Specify TFT model 2.1 (TFT specific params all default, local-regional fix,
# # higher initial LR)
# model_tft = TFTModel(
# input_chunk_length = 30,
# output_chunk_length = 15,
# hidden_size = 16,
# lstm_layers = 1,
# num_attention_heads = 4,
# dropout = 0.1,
# hidden_continuous_size = 8,
# batch_size = 32,
# n_epochs = 500,
# likelihood = None,
# loss_fn = torch.nn.MSELoss(),
# model_name = "TFTStore2.1",
# log_tensorboard = True,
# save_checkpoints = True,
# show_warnings = True,
# optimizer_kwargs = {"lr": 0.005},
# lr_scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau,
# lr_scheduler_kwargs = {"patience": 5},
# pl_trainer_kwargs = {
# "callbacks": [early_stopper],
# "accelerator": "gpu",
# "devices": [0]
# }
# )


# # Specify TFT model 2.2 (TFT specific params all default, local-regional fix,
# # initial LR 0.003)
# model_tft = TFTModel(
# input_chunk_length = 30,
# output_chunk_length = 15,
# hidden_size = 16,
# lstm_layers = 1,
# num_attention_heads = 4,
# dropout = 0.1,
# hidden_continuous_size = 8,
# batch_size = 32,
# n_epochs = 500,
# likelihood = None,
# loss_fn = torch.nn.MSELoss(),
# model_name = "TFTStore2.2",
# log_tensorboard = True,
# save_checkpoints = True,
# show_warnings = True,
# optimizer_kwargs = {"lr": 0.003},
# lr_scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau,
# lr_scheduler_kwargs = {"patience": 5},
# pl_trainer_kwargs = {
# "callbacks": [early_stopper],
# "accelerator": "gpu",
# "devices": [0]
# }
# )


# Specify TFT model 2.3 (TFT specific params all default, local-regional fix,
# initial LR 0.002)
model_tft = TFTModel(
input_chunk_length = 30,
output_chunk_length = 15,
hidden_size = 16,
lstm_layers = 1,
num_attention_heads = 4,
dropout = 0.1,
hidden_continuous_size = 8,
batch_size = 32,
n_epochs = 500,
likelihood = None,
loss_fn = torch.nn.MSELoss(),
model_name = "TFTStoreX",
log_tensorboard = True,
save_checkpoints = True,
show_warnings = True,
optimizer_kwargs = {"lr": 0.002},
lr_scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau,
lr_scheduler_kwargs = {"patience": 5},
pl_trainer_kwargs = {
"callbacks": [early_stopper],
"accelerator": "gpu",
"devices": [0]
}
)


# All covariates, future & past
tft_futcovars = [
"trend", 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday',
'day_sin', 'day_cos', "month_sin", "month_cos", 'oil', 'oil_ma28', 'onpromotion',
'onp_ma28', 'local_holiday', 'regional_holiday', 'national_holiday', 'ny1',
'ny2', 'ny_eve31', 'ny_eve30', 'xmas_before', 'xmas_after', 'quake_after',
'dia_madre', 'futbol', 'black_friday', 'cyber_monday']

tft_pastcovars = ["sales_ema7", "transactions", "trns_ma7"]


# Fit TFT model
model_tft.fit(
series = [y[:-45] for y in y_train_store],
future_covariates = [x[tft_futcovars] for x in x_store],
past_covariates = [x[tft_pastcovars] for x in x_store],
val_series = [y[-45:] for y in y_train_store],
val_future_covariates = [x[tft_futcovars] for x in x_store],
val_past_covariates = [x[tft_pastcovars] for x in x_store],
verbose = True
)


# # Load best checkpoint
# model_tft = TFTModel.load_from_checkpoint("TFTStore2.0", best = True)
#
#
#
#
# # First predict the first store to initialize the multivariate series
# pred_tft_store = model_tft.predict(
# n=15,
# series = y_train_store[0],
# future_covariates = x_store[0][tft_futcovars],
# past_covariates = x_store[0][tft_pastcovars]
# )
#
# # Then loop over all stores except the first
# for i in tqdm(range(1, len(y_train_store))):
#
# # Predict validation data
# pred = model_tft.predict(
# n=15,
# series = y_train_store[i],
# future_covariates = x_store[i][tft_futcovars],
# past_covariates = x_store[i][tft_pastcovars]
# )
#
# # Stack predictions to multivariate series
# pred_tft_store = pred_tft_store.stack(pred)
#
# del pred, i
#
#
#
# # Score TFT
# scores_hierarchy(
# ts_sales[stores][-15:],
# trafo_zeroclip(pred_tft_store),
# stores,
# "TFT (global, all features)"
# )


