Skip to content

Commit

Permalink
DLinear 2.2 retrained and rescored with local-regional holiday feats …
Browse files Browse the repository at this point in the history
…fixed
  • Loading branch information
AhmetZamanis committed Mar 8, 2023
1 parent 61327c0 commit 39722c1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
Binary file modified ModelScores/ModelScoresStore.docx
Binary file not shown.
2 changes: 1 addition & 1 deletion ReportPart2.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,7 @@ model_dlinear = DLinear(
kernel_size = 25,
batch_size = 64,
n_epochs = 500,
model_name = "DLinearStore2.2",
model_name = "DLinearStoreX",
log_tensorboard = True,
save_checkpoints = True,
show_warnings = True,
Expand Down
32 changes: 31 additions & 1 deletion TFTStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,36 @@



# Specify TFT model 2.1 (TFT specific params all default, local-regional fix,
# higher initial LR)
model_tft = TFTModel(
input_chunk_length = 30,
output_chunk_length = 15,
hidden_size = 16,
lstm_layers = 1,
num_attention_heads = 4,
dropout = 0.1,
hidden_continuous_size = 8,
batch_size = 32,
n_epochs = 500,
likelihood = None,
loss_fn = torch.nn.MSELoss(),
model_name = "TFTStore2.1",
log_tensorboard = True,
save_checkpoints = True,
show_warnings = True,
optimizer_kwargs = {"lr": 0.005},
lr_scheduler_cls = torch.optim.lr_scheduler.ReduceLROnPlateau,
lr_scheduler_kwargs = {"patience": 5},
pl_trainer_kwargs = {
"callbacks": [early_stopper],
"accelerator": "gpu",
"devices": [0]
}
)




# All covariates, future & past
tft_futcovars = [
Expand Down Expand Up @@ -88,7 +118,7 @@


# Load best checkpoint
model_tft = TFTModel.load_from_checkpoint("TFTStore2.0", best = True)
model_tft = TFTModel.load_from_checkpoint("TFTStore2.1", best = True)



Expand Down

0 comments on commit 39722c1

Please sign in to comment.