Skip to content

Commit

Permalink
refactor,
Browse files Browse the repository at this point in the history
  • Loading branch information
zhumingpassional committed Mar 31, 2023
1 parent b7a41ac commit 14c1031
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
12 changes: 8 additions & 4 deletions finrl/applications/stock_trading/stock_trading.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def stock_trading(
initial_amount = 1000000
env_kwargs = {
"hmax": 100,
"initial_amount": 1000000,
"initial_amount": initial_amount,
"num_stock_shares": num_stock_shares,
"buy_cost_pct": buy_cost_list,
"sell_cost_pct": sell_cost_list,
Expand Down Expand Up @@ -243,7 +243,11 @@ def stock_trading(
dji_ = get_baseline(ticker="^DJI", start=trade_start_date, end=trade_end_date)
dji = pd.DataFrame()
dji[date_col] = dji_[date_col]
dji.rename(columns={"account_value": "DJI"}, inplace=True)
dji["DJI"] = dji_["close"]
# select the rows between trade_start and trade_end (not included), since some values may not in this region
dji = dji.loc[
(dji[date_col] >= trade_start_date) & (dji[date_col] < trade_end_date)
]

result = dji

Expand All @@ -266,7 +270,7 @@ def stock_trading(
# remove the rows with nan
result = result.dropna(axis=0, how="any")

# cal the column name of strategies, including DJI
# calc the column name of strategies, including DJI
col_strategies = []
for col in result.columns:
if col != date_col and col != "" and "Unnamed" not in col:
Expand All @@ -281,7 +285,7 @@ def stock_trading(
# stats
for col in col_strategies:
stats = backtest_stats(result, value_col_name=col)
print("stats of " + col + ": \n", stats)
print("\nstats of " + col + ": \n", stats)

# print and save result
print("result: ", result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def stock_trading_rolling_window(
actions_sac.to_csv("actions_sac.csv") if if_using_sac else None
actions_td3.to_csv("actions_td3.csv") if if_using_td3 else None

# cal the column name of strategies, including DJI
# calc the column name of strategies, including DJI
col_strategies = []
for col in result.columns:
if col != date_col and col != "" and "Unnamed" not in col:
Expand All @@ -415,7 +415,7 @@ def stock_trading_rolling_window(
# stats
for col in col_strategies:
stats = backtest_stats(result, value_col_name=col)
print("stats of " + col + ": \n", stats)
print("\nstats of " + col + ": \n", stats)

# print and save result
print("result: ", result)
Expand Down

0 comments on commit 14c1031

Please sign in to comment.