diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index a4955ed5a0..948d64d519 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -32,7 +32,8 @@ # Add clients for i in range(n_clients): executor = ScriptRunner( - script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + script=train_script, + script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}" ) job.to(executor, f"site-{i+1}") diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 7395466c68..5dd0a49cc9 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -48,7 +48,10 @@ def main(): client_name = sys_info["site_name"] train_dataset = CIFAR10( - root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + root=os.path.join(DATASET_PATH, client_name), + transform=transforms, + download=True, + train=True, ) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) @@ -74,9 +77,19 @@ def main(): running_loss += cost.cpu().detach().numpy() / images.size()[0] if i % 3000 == 0: - print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}") - global_step = input_model.current_round * steps + epoch * len(train_loader) + i - summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + print( + f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss/3000}" + ) + global_step = ( + input_model.current_round * steps + + epoch * len(train_loader) + + i + ) + summary_writer.add_scalar( + tag="loss_for_each_batch", + scalar=running_loss, + global_step=global_step, + ) running_loss = 0.0 print("Finished Training")