
Commit

Merge branch 'main' into fix_swarm_script_executor_cifar10
chesterxgchen authored Jul 1, 2024
2 parents 729e03d + 46048b3 commit 14cf934
Showing 8 changed files with 62 additions and 32 deletions.
3 changes: 3 additions & 0 deletions examples/getting_started/tf/nvflare_tf_getting_started.ipynb
@@ -314,9 +314,12 @@
"metadata": {},
"outputs": [],
"source": [
"from nvflare.client.config import ExchangeFormat\n",
"\n",
"for i in range(n_clients):\n",
" executor = ScriptExecutor(\n",
" task_script_path=\"src/cifar10_tf_fl.py\", task_script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n",
" params_exchange_format=ExchangeFormat.NUMPY,\n",
" )\n",
" job.to(executor, f\"site-{i}\", gpu=0)"
]
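
Note: the added params_exchange_format=ExchangeFormat.NUMPY argument makes the ScriptExecutor exchange model parameters with the training script as plain numpy arrays rather than TensorFlow-native tensors. As a rough illustration only (the actual src/cifar10_tf_fl.py is not part of this diff), a client script compatible with numpy exchange might look like the sketch below; the model, data slice, and parameter-key naming are placeholder assumptions.

# Hypothetical client-script sketch (not from this commit): exchange Keras
# weights as numpy arrays, matching ExchangeFormat.NUMPY on the executor.
import nvflare.client as flare
import tensorflow as tf
from nvflare.app_common.abstract.fl_model import FLModel

def build_model():
    # placeholder model; the real script trains a CIFAR-10 network
    model = tf.keras.Sequential(
        [tf.keras.layers.Flatten(input_shape=(32, 32, 3)), tf.keras.layers.Dense(10)]
    )
    model.compile(optimizer="sgd",
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    return model

def main():
    flare.init()  # handshake with the NVFlare client runtime
    model = build_model()
    (x, y), _ = tf.keras.datasets.cifar10.load_data()
    while flare.is_running():
        input_model = flare.receive()  # FLModel whose params are numpy arrays
        # illustrative mapping: assumes params are ordered like model.get_weights()
        model.set_weights(list(input_model.params.values()))
        model.fit(x[:512] / 255.0, y[:512], epochs=1, verbose=0)
        # send updated weights back as a dict of numpy arrays
        flare.send(FLModel(params={str(i): w for i, w in enumerate(model.get_weights())}))

if __name__ == "__main__":
    main()
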
@@ -15,6 +15,7 @@
from src.tf_net import TFNet

from nvflare import FedAvg, FedJob, ScriptExecutor
from nvflare.client.config import ExchangeFormat

if __name__ == "__main__":
n_clients = 2
@@ -36,7 +37,9 @@
# Add clients
for i in range(n_clients):
executor = ScriptExecutor(
task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
task_script_path=train_script,
task_script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}"
params_exchange_format=ExchangeFormat.NUMPY,
)
job.to(executor, f"site-{i}", gpu=0)
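
Note: the job script receives the same ExchangeFormat.NUMPY argument. For context, a FedJob configured this way is usually exported to a job folder or run directly in the FL simulator; the calls and paths below are a hedged sketch of typical usage, not lines from this file.

# Hedged sketch of typical FedJob usage after configuration (paths are placeholders)
job.export_job("/tmp/nvflare/jobs/job_config")   # write the job configuration to disk
job.simulator_run("/tmp/nvflare/jobs/workdir")   # or run the job locally in the simulator
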

12 changes: 8 additions & 4 deletions nvflare/app_common/workflows/base_model_controller.py
@@ -102,8 +102,9 @@ def broadcast_model(
task_name: str = AppConstants.TASK_TRAIN,
data: FLModel = None,
targets: Union[List[Client], List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
wait_time_after_min_received: int = 10,
wait_time_after_min_received: int = 0,
blocking: bool = True,
callback: Callable[[FLModel], None] = None,
) -> List:
@@ -113,9 +114,11 @@
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
timeout (int, optional): time to wait for clients to perform task. Defaults to 0, i.e., never time out.
wait_time_after_min_received (int, optional): time to wait after
minimum number of clients responses has been received. Defaults to 10.
minimum number of clients responses has been received. Defaults to 0.
blocking (bool, optional): whether to block to wait for task result. Defaults to True.
callback (Callable[[FLModel], None], optional): callback when a result is received, only called when blocking=False. Defaults to None.
@@ -127,6 +130,9 @@
raise TypeError("task_name must be a string but got {}".format(type(task_name)))
if data and not isinstance(data, FLModel):
raise TypeError("data must be a FLModel or None but got {}".format(type(data)))
if min_responses is None:
min_responses = 0 # this is internally used by controller's broadcast to represent all targets
check_non_negative_int("min_responses", min_responses)
check_non_negative_int("timeout", timeout)
check_non_negative_int("wait_time_after_min_received", wait_time_after_min_received)
if not blocking and not isinstance(callback, Callable):
@@ -140,10 +146,8 @@

if targets:
targets = [client.name if isinstance(client, Client) else client for client in targets]
min_responses = len(targets)
self.info(f"Sending task {task_name} to {targets}")
else:
min_responses = len(self.engine.get_clients())
self.info(f"Sending task {task_name} to all clients")

if blocking:
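
Note: broadcast_model no longer derives min_responses from the number of targets; the caller passes it explicitly, and None (mapped to 0 internally) means waiting for responses from every client the task was sent to. A hedged sketch of a ModelController subclass exercising the new parameter through send_model_and_wait; the class name, round count, and aggregation step are illustrative only.

# Hypothetical controller sketch (not from this commit): require at least 2 of
# the targeted clients to respond, then wait 30 more seconds for stragglers.
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows.model_controller import ModelController

class TolerantFedAvg(ModelController):
    def run(self):
        model = FLModel(params={})  # placeholder initial model
        for r in range(3):
            results = self.send_model_and_wait(
                task_name="train",
                data=model,
                targets=None,                    # None means all clients
                min_responses=2,                 # proceed once 2 responses arrive
                timeout=600,
                wait_time_after_min_received=30,
            )
            # placeholder aggregation; a real workflow would combine results here
            self.info(f"round {r}: received {len(results)} results")
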
14 changes: 8 additions & 6 deletions nvflare/app_common/workflows/model_controller.py
@@ -43,17 +43,18 @@ def send_model_and_wait(
task_name: str = "train",
data: FLModel = None,
targets: Union[List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
wait_time_after_min_received: int = 10,
) -> List[FLModel]:
"""Send a task with data to targets and wait for results.
Args:
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. Defaults to None.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
timeout (int, optional): time to wait for clients to perform task. Defaults to 0 (never time out).
wait_time_after_min_received (int, optional): time to wait after minimum number of client responses have been received. Defaults to 10.
Returns:
List[FLModel]
@@ -62,17 +63,17 @@ def send_model_and_wait(
task_name=task_name,
data=data,
targets=targets,
min_responses=min_responses,
timeout=timeout,
wait_time_after_min_received=wait_time_after_min_received,
)

def send_model(
self,
task_name: str = "train",
data: FLModel = None,
targets: Union[List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
wait_time_after_min_received: int = 10,
callback: Callable[[FLModel], None] = None,
) -> None:
"""Send a task with data to targets (non-blocking). Callback is called when a result is received.
@@ -81,8 +82,9 @@ def send_model(
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. Defaults to None.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
timeout (int, optional): time to wait for clients to perform task. Defaults to 0 (never time out).
wait_time_after_min_received (int, optional): time to wait after minimum number of client responses have been received. Defaults to 10.
callback (Callable[[FLModel], None], optional): callback when a result is received. Defaults to None.
Returns:
@@ -92,8 +94,8 @@ def send_model(
task_name=task_name,
data=data,
targets=targets,
min_responses=min_responses,
timeout=timeout,
wait_time_after_min_received=wait_time_after_min_received,
blocking=False,
callback=callback,
)
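
Note: send_model forwards the same min_responses argument but returns immediately; results are delivered through the callback. A hedged sketch follows (class and handler names are illustrative).

# Hypothetical non-blocking usage (not from this commit): results are pushed to
# a callback instead of blocking the workflow thread.
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.workflows.model_controller import ModelController

class AsyncBroadcast(ModelController):
    def _on_result(self, result: FLModel) -> None:
        # placeholder handling; a real controller would accumulate or aggregate here
        self.info(f"received a result with {len(result.params or {})} params")

    def run(self):
        self.send_model(
            task_name="train",
            data=FLModel(params={}),      # placeholder payload
            targets=["site-1", "site-2"],
            min_responses=1,              # at least one response expected
            timeout=300,
            callback=self._on_result,     # required because send_model is non-blocking
        )
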
11 changes: 6 additions & 5 deletions nvflare/lighter/impl/aws_template.yml
@@ -1,4 +1,5 @@
aws_start_sh: |
function find_ec2_gpu_instance_type() {
local gpucnt=0
local gpumem=0
@@ -98,10 +99,10 @@ aws_start_sh: |
then
while true
do
read -e -i ${REGION} -p "* Cloud EC2 region, press ENTER to accept default: " REGION
prompt REGION "* Cloud EC2 region, press ENTER to accept default" "${REGION}"
if [ ${container} = false ]
then
read -e -i ${AMI_NAME} -p "* Cloud AMI image name, press ENTER to accept default (use amd64 or arm64): " AMI_NAME
prompt AMI_NAME "* Cloud AMI image name (use amd64 or arm64), press ENTER to accept default" "${AMI_NAME}"
printf " retrieving AMI ID for ${AMI_NAME} ... "
IMAGES=$(aws ec2 describe-images --region ${REGION} --owners ${AMI_IMAGE_OWNER} --filters "Name=name,Values=*${AMI_NAME}*" --output json)
if [ "${#IMAGES}" -lt 30 ]
@@ -118,9 +119,9 @@ aws_start_sh: |
fi
find_ec2_gpu_instance_type
fi
prompt AMI_IMAGE "* Cloud AMI image, press ENTER to accept default ${AMI_IMAGE}: "
read -e -i ${EC2_TYPE} -p "* Cloud EC2 type, press ENTER to accept default: " EC2_TYPE
prompt ans "region = ${REGION}, ami image = ${AMI_IMAGE}, EC2 type = ${EC2_TYPE}, OK? (Y/n) "
prompt AMI_IMAGE "* Cloud AMI image, press ENTER to accept default"
prompt EC2_TYPE "* Cloud EC2 type, press ENTER to accept default" "${EC2_TYPE}"
prompt ans "region = ${REGION}, ami image = ${AMI_IMAGE}, EC2 type = ${EC2_TYPE}, OK? (Y/n)"
if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]
then
break
20 changes: 10 additions & 10 deletions nvflare/lighter/impl/azure_template.yml
@@ -66,14 +66,14 @@ azure_start_svr_header_sh: |
then
while true
do
prompt VM_IMAGE "Cloud VM image, press ENTER to accept default ${VM_IMAGE}: "
prompt VM_SIZE "Cloud VM size, press ENTER to accept default ${VM_SIZE}: "
prompt VM_IMAGE "Cloud VM image, press ENTER to accept default" "${VM_IMAGE}"
prompt VM_SIZE "Cloud VM size, press ENTER to accept default" "${VM_SIZE}"
if [ $self_dns == true ]
then
prompt LOCATION "Cloud location, press ENTER to accept default ${LOCATION}: "
prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, location = ${LOCATION}, OK? (Y/n) "
prompt LOCATION "Cloud location, press ENTER to accept default" "${LOCATION}"
prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, location = ${LOCATION}, OK? (Y/n)"
else
prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) "
prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n)"
fi
if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi
done
@@ -82,7 +82,7 @@ azure_start_svr_header_sh: |
if [ $container == false ]
then
echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}."
prompt ans "Press ENTER when it's done or no additional dependencies. "
prompt ans "Press ENTER when it's done or no additional dependencies."
fi
az login --use-device-code -o none
@@ -221,9 +221,9 @@ azure_start_cln_header_sh: |
then
while true
do
prompt LOCATION "Cloud location, press ENTER to accept default ${LOCATION}: "
prompt VM_IMAGE "Cloud VM image, press ENTER to accept default ${VM_IMAGE}: "
prompt VM_SIZE "Cloud VM size, press ENTER to accept default ${VM_SIZE}: "
prompt LOCATION "Cloud location, press ENTER to accept default" "${LOCATION}"
prompt VM_IMAGE "Cloud VM image, press ENTER to accept default" "${VM_IMAGE}"
prompt VM_SIZE "Cloud VM size, press ENTER to accept default" "${VM_SIZE}"
prompt ans "location = ${LOCATION}, VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) "
if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi
done
@@ -232,7 +232,7 @@ azure_start_cln_header_sh: |
if [ $container == false ]
then
echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}."
prompt ans "Press ENTER when it's done or no additional dependencies. "
prompt ans "Press ENTER when it's done or no additional dependencies."
fi
az login --use-device-code -o none
27 changes: 22 additions & 5 deletions nvflare/lighter/impl/master_template.yml
@@ -798,12 +798,29 @@ cloud_script_header: |
}
function prompt() {
local __default="$1"
read -p "$2" ans
if [[ ! -z "$ans" ]]
then
eval $__default="'$ans'"
# usage: prompt NEW_VAR "Prompt message" ["${PROMPT_VALUE}"]
local __resultvar=$1
local __prompt=$2
local __default=${3:-}
local __result
if [[ ${BASH_VERSINFO[0]} -ge 4 && -n "$__default" ]]
then
read -e -i "$__default" -p "$__prompt: " __result
else
__default=${3:-${!__resultvar:-}}
if [[ -n $__default ]]
then
printf "%s [%s]: " "$__prompt" "$__default"
else
printf "%s: " "$__prompt"
fi
IFS= read -r __result
if [[ -z "$__result" && -n "$__default" ]]
then
__result="$__default"
fi
fi
eval $__resultvar="'$__result'"
}
function get_resources_file() {
2 changes: 1 addition & 1 deletion setup.cfg
@@ -18,7 +18,7 @@ python_requires = >= 3.8
install_requires =
cryptography>=36.0.0
Flask==3.0.2
Werkzeug==3.0.1
Werkzeug==3.0.3
Flask-JWT-Extended==4.6.0
Flask-SQLAlchemy==3.1.1
SQLAlchemy==2.0.16
