Skip to content

Commit

Permalink
Changed to use job.as_id().
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen committed Jul 24, 2024
1 parent 14cf934 commit 6c4d9c8
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions examples/getting_started/pt/swarm_script_executor_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,23 @@
executor = ScriptExecutor(task_script_path=train_script)
job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])

client_controller = SwarmClientController()
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])

client_controller = CrossSiteEvalClientController()
job.to(client_controller, f"site-{i}", tasks=["cse_*"])

# In swarm learning, each client acts also as an aggregator
aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS)
job.to(aggregator, f"site-{i}", id="aggregator")
job.to(aggregator, f"site-{i}")

# In swarm learning, each client uses a model persistor and shareable_generator
job.to(PTFileModelPersistor(model=Net()), f"site-{i}", id="persistor")
persistor = PTFileModelPersistor(model=Net())
job.to(persistor, f"site-{i}")
job.to(SimpleModelShareableGenerator(), f"site-{i}", id="shareable_generator")

client_controller = SwarmClientController(
aggregator_id=job.as_id(aggregator),
persistor_id=job.as_id(persistor)
)
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])

client_controller = CrossSiteEvalClientController()
job.to(client_controller, f"site-{i}", tasks=["cse_*"])

# job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir")

0 comments on commit 6c4d9c8

Please sign in to comment.