Merge branch 'main' into more_app_opt_example_changes
YuanTingHsieh authored Aug 14, 2024
2 parents ebd0075 + 82b2874 commit c129ba0
Showing 5 changed files with 17 additions and 16 deletions.
12 changes: 6 additions & 6 deletions research/fed-bn/README.md
@@ -39,7 +39,10 @@ Download the necessary datasets by running:
```

# Run FedBN on different data splits

We first set the job template path
```commandline
nvflare config -jt ../../job_templates
```
We use the in-process Client API with the sag_pt job template. Run the following command to create the job:
```
./create_job.sh
@@ -53,12 +56,9 @@ Run the FedBN simulation with the following command:
```
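The exact run command lives in the repository's scripts, which are collapsed in this view. For reference, a typical NVFlare simulator invocation has the following shape (the job and workspace paths here are illustrative assumptions, not taken from this commit):
```commandline
# illustrative paths; substitute the job folder produced by create_job.sh
nvflare simulator ./jobs/fedbn -w /tmp/nvflare/fedbn -n 2 -t 2
```
Here `-n` is the number of simulated clients and `-t` the number of threads.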

## Visualizing Results
To visualize training losses, we use [Comet ML](https://www.comet.com/site/).
Below is an example of the loss visualization output:
![FedBN Loss Results](./figs/loss.jpeg)
With TensorBoard, below is an example of the loss curves for the two sites:
![FedBN Loss Results](./figs/loss.png)

> **Note**: To use the Comet ML experiment tracking system, you need a Comet API key for access.
> Alternatively, you can use TensorBoard or MLflow.

## Citation
If you find the code and dataset useful, please cite our paper.
Binary file removed research/fed-bn/figs/loss.jpeg
Binary file not shown.
Binary file added research/fed-bn/figs/loss.png
2 changes: 1 addition & 1 deletion research/fed-bn/requirements.txt
@@ -1,4 +1,4 @@
nvflare~=2.4.0rc
torch
torchvision
comet_ml
tensorboard
19 changes: 10 additions & 9 deletions research/fed-bn/src/fedbn_cifar10.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# (optional) metrics
import comet_ml
import torch
import torch.nn as nn
import torch.optim as optim
@@ -24,13 +22,16 @@
# (1) import nvflare client API
import nvflare.client as flare

# (optional) set a fixed place so we don't need to download every time
# (optional) metrics
from nvflare.client.tracking import SummaryWriter

# (optional) set a fixed place for data storage
# so we don't need to download every time
DATASET_PATH = "/tmp/nvflare/data"

# (optional) use GPU to speed things up
# to use CPU instead, set DEVICE = "cpu"
DEVICE = "cuda:0"
# input your own comet ml account API key
COMET_API_KEY = ""


# key function for FedBN
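The FedBN-specific weight loading is collapsed in this view. Its core idea, keeping each site's BatchNorm parameters local while accepting all other global weights, can be sketched in isolation. This is a hedged illustration: the key-matching convention `"bn" in name` is an assumption about layer naming, not the repository's exact code.

```python
def filter_batchnorm(global_params):
    """Drop BatchNorm entries from the received global weights so each
    site keeps its own BN statistics and affine parameters (FedBN)."""
    return {k: v for k, v in global_params.items() if "bn" not in k}


# usage sketch with toy state dicts: only non-BN entries are overwritten
local_state = {"conv1.weight": 0.1, "bn1.weight": 0.5, "fc.bias": 0.0}
incoming = {"conv1.weight": 0.9, "bn1.weight": 0.7, "fc.bias": 0.3}
local_state.update(filter_batchnorm(incoming))
```

In the real client this filtering would be applied before `net.load_state_dict(...)`, so the local BN entries survive every round.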
@@ -56,9 +57,7 @@ def main():
# (2) initializes NVFlare client API
flare.init()

comet_ml.init()
exp = comet_ml.Experiment(project_name="fedbn_cifar10", api_key=COMET_API_KEY)

summary_writer = SummaryWriter()
while flare.is_running():
# (3) receives FLModel from NVFlare
input_model = flare.receive()
@@ -75,6 +74,7 @@
# (optional) calculate total steps
steps = epochs * len(trainloader)
for epoch in range(epochs): # loop over the dataset multiple times

running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
@@ -96,7 +96,7 @@ def main():
print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
global_step = input_model.current_round * steps + epoch * len(trainloader) + i

exp.log_metrics({"loss": running_loss}, step=global_step)
summary_writer.add_scalar(tag="loss", scalar=running_loss, global_step=global_step)
running_loss = 0.0

print("Finished Training")
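The global-step bookkeeping in the loop above folds round, epoch, and batch indices into one monotonically increasing counter for the tracking backend. Written out as a standalone helper (names are illustrative, not from the repository):

```python
def global_step(current_round, epochs, steps_per_epoch, epoch, batch_idx):
    """Steps completed in earlier rounds, plus steps within this round."""
    steps_per_round = epochs * steps_per_epoch
    return current_round * steps_per_round + epoch * steps_per_epoch + batch_idx


# e.g. round 2, 3 epochs of 100 batches per round, now at epoch 1, batch 40
step = global_step(2, 3, 100, 1, 40)
```

Because the counter never repeats across rounds, loss points from successive FL rounds line up on one continuous x-axis in TensorBoard.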
@@ -131,6 +131,7 @@ def evaluate(input_weights):

# (6) evaluate on received model for model selection
accuracy = evaluate(input_model.params)
summary_writer.add_scalar(tag="global_model_accuracy", scalar=accuracy, global_step=input_model.current_round)
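The `evaluate()` call above returns the value logged as `global_model_accuracy`, one scalar per round. The core computation reduces to a percentage of correct predictions; this is an illustrative sketch, not the repository's torch-based evaluation loop:

```python
def accuracy_percent(predictions, labels):
    """Fraction of matching predictions, as a percentage."""
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return 100.0 * correct / len(labels)


# toy example: 3 of 4 predictions match the labels
acc = accuracy_percent([1, 0, 2, 1], [1, 0, 1, 1])
```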
# (7) construct trained FL model
output_model = flare.FLModel(
params=net.cpu().state_dict(),
