Skip to content

Commit

Permalink
[Benchmark] Fix amp level bug in some gpt tests (#9116)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Sep 13, 2024
1 parent 5c1779c commit 089a3c3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ function _set_params(){
sharding_degree=${10:-"1"} # (可选)
sharding_stage=${11:-"1"} # (可选)sharding case
level=${12:-"o1"} # o1|o2|o3

if [[ $FLAGS_enable_pir_api == "1" || $FLAGS_enable_pir_api == "True" ]]; then
if [ ${level} == "o3" ]; then
level="o2"
echo "amp level changed to o2 in pir mode"
else
echo "amp level is o3"
fi
else
echo "FLAGS_enable_pir_api = 0"
fi

local_batch_size=${13:-"8"} # (可选)本地batch size
schedule_mode=${14:-"1F1B"} # (可选)schedule mode
base_batch_size=$global_batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ function _set_params(){
sharding_degree=${10:-"1"} # (可选)
sharding_stage=${11:-"1"} # (可选)sharding case
level=${12:-"o1"} # o1|o2|o3

if [[ $FLAGS_enable_pir_api == "1" || $FLAGS_enable_pir_api == "True" ]]; then
if [ ${level} == "o3" ]; then
level="o2"
echo "amp level changed to o2 in pir mode"
else
echo "amp level is o3"
fi
else
echo "FLAGS_enable_pir_api = 0"
fi

local_batch_size=${13:-"8"} # (可选)本地batch size
schedule_mode=${14:-"1F1B"} # (可选)schedule mode
base_batch_size=$global_batch_size
Expand Down

0 comments on commit 089a3c3

Please sign in to comment.