From 9b2cf8a4d8c4648a70631a5c34d0ea600752f06f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Mar 2023 06:36:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../stock_trading/stock_trading.py | 28 +++++++++---------- .../stock_trading_rolling_window.py | 14 +++++----- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/finrl/applications/stock_trading/stock_trading.py b/finrl/applications/stock_trading/stock_trading.py index 464c78f7c..a0e8bfc60 100644 --- a/finrl/applications/stock_trading/stock_trading.py +++ b/finrl/applications/stock_trading/stock_trading.py @@ -5,12 +5,12 @@ import pandas as pd from stable_baselines3.common.logger import configure + from finrl.agents.stablebaselines3.models import DRLAgent from finrl.config import DATA_SAVE_DIR from finrl.config import INDICATORS from finrl.config import RESULTS_DIR from finrl.config import TENSORBOARD_LOG_DIR - from finrl.config import TRAINED_MODEL_DIR from finrl.config_tickers import DOW_30_TICKER from finrl.main import check_and_make_directories @@ -26,20 +26,18 @@ def stock_trading( - train_start_date: str, - train_end_date: str, - trade_start_date: str, - trade_end_date: str, - if_store_actions: bool = True, - if_store_result: bool = True, - if_using_a2c: bool = True, - if_using_ddpg: bool = True, - if_using_ppo: bool = True, - if_using_sac: bool = True, - if_using_td3: bool = True, + train_start_date: str, + train_end_date: str, + trade_start_date: str, + trade_end_date: str, + if_store_actions: bool = True, + if_store_result: bool = True, + if_using_a2c: bool = True, + if_using_ddpg: bool = True, + if_using_ppo: bool = True, + if_using_sac: bool = True, + if_using_td3: bool = True, ): - - sys.path.append("../FinRL") check_and_make_directories( [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR] @@ -247,7 +245,7 @@ def stock_trading( # select the rows between trade_start and trade_end (not included), since some values may not in this region dji = dji.loc[ (dji[date_col] >= trade_start_date) & (dji[date_col] < trade_end_date) - ] + ] result = dji diff --git a/finrl/applications/stock_trading/stock_trading_rolling_window.py b/finrl/applications/stock_trading/stock_trading_rolling_window.py index c6f9dc76f..59a1d8a04 100644 --- a/finrl/applications/stock_trading/stock_trading_rolling_window.py +++ b/finrl/applications/stock_trading/stock_trading_rolling_window.py @@ -48,13 +48,13 @@ def stock_trading_rolling_window( trade_start_date: str, trade_end_date: str, rolling_window_length: int, - if_store_actions: bool=True, - if_store_result: bool=True, - if_using_a2c: bool=True, - if_using_ddpg: bool=True, - if_using_ppo: bool=True, - if_using_sac: bool=True, - if_using_td3: bool=True, + if_store_actions: bool = True, + if_store_result: bool = True, + if_using_a2c: bool = True, + if_using_ddpg: bool = True, + if_using_ppo: bool = True, + if_using_sac: bool = True, + if_using_td3: bool = True, ): # sys.path.append("../FinRL") check_and_make_directories(