Skip to content

Commit

Permalink
[2.5] Update flower CLI (NVIDIA#2792)
Browse files Browse the repository at this point in the history
* update flower cli

* update flwr hello-world job (NVIDIA#9)

* update flwr hello-world job

* add license header

* update readme

---------

Co-authored-by: Holger Roth <6304754+holgerroth@users.noreply.github.com>
Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 35830dd commit f388c12
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 75 deletions.
24 changes: 16 additions & 8 deletions examples/hello-world/hello-flower/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@ In this example, we run 2 Flower clients and Flower Server in parallel using NVF

To run Flower code in NVFlare, we created a job, including an app with the following custom folder content
```bash
$ tree jobs/hello-flwr-pt
.
├── client.py # <-- contains `ClientApp`
├── server.py # <-- contains `ServerApp`
├── task.py # <-- task-specific code (model, data)
$ tree jobs/hello-flwr-pt/app/custom

├── flwr_pt
│ ├── client.py # <-- contains `ClientApp`
│ ├── __init__.py # <-- to register the python module
│ ├── server.py # <-- contains `ServerApp`
│ └── task.py # <-- task-specific code (model, data)
└── pyproject.toml # <-- Flower project file
```
Note, this code is directly copied from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example.
Note, this code is adapted from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example.

## Install dependencies
To run this job with NVFlare, we first need to install the dependencies.
If you haven't already, we recommend creating a virtual environment.
```bash
python3 -m venv nvflare_flwr
source nvflare_flwr/bin/activate
```
To run a job with NVFlare, we first need to install its dependencies.
```bash
pip install -r requirements.txt
pip install ./jobs/hello-flwr-pt/app/custom
```

## Run a simulation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
"tasks": ["*"],
"executor": {
"path": "nvflare.app_opt.flower.executor.FlowerExecutor",
"args": {
"client_app": "client:app"
}
"args": {}
}
}
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
{
"id": "ctl",
"path": "nvflare.app_opt.flower.controller.FlowerController",
"args": {
"server_app": "server:app"
}
"args": {}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""flwr_pt."""
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from flwr.client import ClientApp, NumPyClient
from task import DEVICE, Net, get_weights, load_data, set_weights, test, train
from flwr.common import Context

from .task import DEVICE, Net, get_weights, load_data, set_weights, test, train

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
Expand All @@ -32,7 +34,7 @@ def evaluate(self, parameters, config):
return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()

Expand All @@ -41,13 +43,3 @@ def client_fn(cid: str):
app = ClientApp(
client_fn=client_fn,
)


# Legacy mode
if __name__ == "__main__":
from flwr.client import start_client

start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.
from typing import List, Tuple

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig
from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from task import Net, get_weights

from .task import Net, get_weights


# Define metric aggregation function
Expand Down Expand Up @@ -53,23 +54,16 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
)


# Define config
config = ServerConfig(num_rounds=3)


# Flower ServerApp
app = ServerApp(
config=config,
strategy=strategy,
)
def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Define config
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)

# Legacy mode
if __name__ == "__main__":
from flwr.server import start_server

start_server(
server_address="0.0.0.0:8080",
config=config,
strategy=strategy,
)
# Create ServerApp
app = ServerApp(server_fn=server_fn)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "flwr_pt"
version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.11.0,<2.0",
"nvflare~=2.5.0rc",
"torch==2.2.1",
"torchvision==0.17.1",
]

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.flwr.app]
publisher = "nvidia"

[tool.flwr.app.components]
serverapp = "flwr_pt.server:app"
clientapp = "flwr_pt.client:app"

[tool.flwr.app.config]
num-server-rounds = 3

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 2
4 changes: 0 additions & 4 deletions examples/hello-world/hello-flower/requirements.txt

This file was deleted.

25 changes: 6 additions & 19 deletions nvflare/app_opt/flower/applet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,9 @@


class FlowerClientApplet(CLIApplet):
def __init__(
self,
client_app: str,
):
"""Constructor of FlowerClientApplet, which extends CLIApplet.
Args:
client_app: the client app specification of the Flower app
"""
def __init__(self):
"""Constructor of FlowerClientApplet, which extends CLIApplet."""
CLIApplet.__init__(self)
self.client_app = client_app

def get_command(self, ctx: dict) -> CommandDescriptor:
"""Implementation of the get_command method required by the super class CLIApplet.
Expand Down Expand Up @@ -64,7 +56,7 @@ def get_command(self, ctx: dict) -> CommandDescriptor:
job_id = fl_ctx.get_job_id()
custom_dir = ws.get_app_custom_dir(job_id)
app_dir = ws.get_app_dir(job_id)
cmd = f"flower-client-app --insecure --grpc-adapter --superlink {addr} --dir {custom_dir} {self.client_app}"
cmd = f"flower-supernode --insecure --grpc-adapter --superlink {addr} {custom_dir}"

# use app_dir as the cwd for flower's client app.
# this is necessary for client_api to be used with the flower client app for metrics logging
Expand All @@ -76,23 +68,20 @@ def get_command(self, ctx: dict) -> CommandDescriptor:
class FlowerServerApplet(Applet):
def __init__(
self,
server_app: str,
database: str,
superlink_ready_timeout: float,
server_app_args: list = None,
):
"""Constructor of FlowerServerApplet.
Args:
server_app: Flower's server app specification
database: database spec to be used by the server app
superlink_ready_timeout: how long to wait for the superlink process to become ready
server_app_args: an optional list that contains additional command args passed to flower server app
"""
Applet.__init__(self)
self._app_process_mgr = None
self._superlink_process_mgr = None
self.server_app = server_app
self.database = database
self.superlink_ready_timeout = superlink_ready_timeout
self.server_app_args = server_app_args
Expand Down Expand Up @@ -148,8 +137,8 @@ def start(self, app_ctx: dict):
db_arg = f"--database {self.database}"

superlink_cmd = (
f"flower-superlink --insecure {db_arg} "
f"--fleet-api-address {server_addr} --fleet-api-type grpc-adapter "
f"flower-superlink --insecure --fleet-api-type grpc-adapter {db_arg} "
f"--fleet-api-address {server_addr} "
f"--driver-api-address {driver_addr}"
)

Expand All @@ -175,9 +164,7 @@ def start(self, app_ctx: dict):
if self.server_app_args:
args_str = " ".join(self.server_app_args)

app_cmd = (
f"flower-server-app --insecure --superlink {driver_addr} --dir {custom_dir} {args_str} {self.server_app}"
)
app_cmd = f"flower-server-app --insecure --superlink {driver_addr} {args_str} {custom_dir}"
cmd_desc = CommandDescriptor(
cmd=app_cmd,
log_file_name="server_app_log.txt",
Expand Down
4 changes: 0 additions & 4 deletions nvflare/app_opt/flower/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class FlowerController(TieController):
def __init__(
self,
num_rounds=1,
server_app: str = "server:app",
database: str = "",
server_app_args: list = None,
superlink_ready_timeout: float = 10.0,
Expand All @@ -43,7 +42,6 @@ def __init__(
Args:
num_rounds: number of rounds. Not used in this version.
server_app: the server app specification for Flower server app
database: database name
server_app_args: additional server app CLI args
superlink_ready_timeout: how long to wait for the superlink to become ready before starting server app
Expand Down Expand Up @@ -73,7 +71,6 @@ def __init__(
check_object_type("server_app_args", server_app_args, list)

self.num_rounds = num_rounds
self.server_app = server_app
self.database = database
self.server_app_args = server_app_args
self.superlink_ready_timeout = superlink_ready_timeout
Expand All @@ -86,7 +83,6 @@ def get_connector(self, fl_ctx: FLContext):

def get_applet(self, fl_ctx: FLContext):
return FlowerServerApplet(
server_app=self.server_app,
database=self.database,
superlink_ready_timeout=self.superlink_ready_timeout,
server_app_args=self.server_app_args,
Expand Down
4 changes: 1 addition & 3 deletions nvflare/app_opt/flower/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
class FlowerExecutor(TieExecutor):
def __init__(
self,
client_app: str = "client:app",
start_task_name=Constant.START_TASK_NAME,
configure_task_name=Constant.CONFIG_TASK_NAME,
per_msg_timeout=10.0,
Expand All @@ -40,7 +39,6 @@ def __init__(
self.tx_timeout = tx_timeout
self.client_shutdown_timeout = client_shutdown_timeout
self.num_rounds = None
self.client_app = client_app

def get_connector(self, fl_ctx: FLContext):
return GrpcClientConnector(
Expand All @@ -50,7 +48,7 @@ def get_connector(self, fl_ctx: FLContext):
)

def get_applet(self, fl_ctx: FLContext):
return FlowerClientApplet(self.client_app)
return FlowerClientApplet()

def configure(self, config: dict, fl_ctx: FLContext):
self.num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS)
Expand Down

0 comments on commit f388c12

Please sign in to comment.