Skip to content

Commit

Permalink
Add Ditto helper to app_common, update prostate example learners (NVI…
Browse files Browse the repository at this point in the history
…DIA#437)

* Add Ditto helper to app_common, update prostate example learners

* Correct a bug in local valid

* refactor DittoHelper

* remove a redundant lr in DittoHelper

* remove a redundant lr in DittoHelper

* remove a redundant lr in DittoHelper

* format adjustment

* refactor Ditto helper with abstract method

* further touchups on ditto helper class and supervised learner

* add type hint and check with default
  • Loading branch information
ZiyueXu77 authored Apr 27, 2022
1 parent a6a6664 commit 9f2c8b9
Show file tree
Hide file tree
Showing 20 changed files with 687 additions and 850 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"executors": [
{
"tasks": [
"train", "submit_model", "validate"
"train", "validate"
],
"executor": {
"id": "Executor",
Expand All @@ -24,10 +24,10 @@
"components": [
{
"id": "prostate-learner",
"path": "pt.learners.supervised_prostate_learner.SupervisedProstateLearner",
"path": "pt.learners.supervised_monai_prostate_learner.SupervisedMonaiProstateLearner",
"args": {
"aggregation_epochs": 10,
"train_config_filename": "config_train.json"
"train_config_filename": "config_train.json",
"aggregation_epochs": 1
}
}
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"format_version": 2,
"min_clients": 1,
"num_rounds": 100,
"num_rounds": 150,
"server": {
"heart_beat_timeout": 600
},
Expand All @@ -15,7 +15,7 @@
"model": {
"path": "monai.networks.nets.unet.UNet",
"args": {
"dimensions": 3,
"dimensions": 2,
"in_channels": 1,
"out_channels": 1,
"channels": [16, 32, 64, 128, 256],
Expand All @@ -33,31 +33,14 @@
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {
"aggregation_weights": {
"client_MSD": 1.0,
"client_I2CVB": 1.0,
"client_NCI_ISBI_3T": 1.0,
"client_NCI_ISBI_Dx": 1.0
}
}
"args": {}
},
{
"id": "model_selector",
"name": "IntimeModelSelectionHandler",
"args": {}
},
{
"id": "model_locator",
"name": "PTFileModelLocator",
"args": {
"pt_persistor_id": "persistor"
"weigh_by_local_iter": true
}
},
{
"id": "json_generator",
"name": "ValidationJsonGenerator",
"args": {}
}
],
"workflows": [
Expand All @@ -75,15 +58,6 @@
"train_task_name": "train",
"train_timeout": 0
}
},
{
"id": "global_model_eval",
"name": "GlobalModelEval",
"args": {
"model_locator_id": "model_locator",
"validation_timeout": 6000,
"cleanup_models": true
}
}
]
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"learning_rate": 1e-2,
"learning_rate": 1e-3,
"fedproxloss_mu": 0.0,
"cache_dataset": 0.0,
"dataset_base_dir": "PWD/data_preparation/dataset",
"datalist_json_path": "PWD/data_preparation/datalists/client_All.json"
"cache_dataset": 1.0,
"dataset_base_dir": "PWD/data_preparation/dataset_2D",
"datalist_json_path": "PWD/data_preparation/datalist/client_All.json"
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"executors": [
{
"tasks": [
"train", "submit_model", "validate"
"train", "validate"
],
"executor": {
"id": "Executor",
Expand All @@ -24,11 +24,11 @@
"components": [
{
"id": "prostate-learner",
"path": "pt.learners.prostate_ditto_learner.ProstateDittoLearner",
"path": "pt.learners.supervised_monai_prostate_ditto_learner.SupervisedMonaiProstateDittoLearner",
"args": {
"aggregation_epochs": 10,
"local_model_epochs": 10,
"train_config_filename": "config_train.json"
"train_config_filename": "config_train.json",
"aggregation_epochs": 1,
"ditto_model_epochs": 1
}
}
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"format_version": 2,
"min_clients": 4,
"num_rounds": 100,
"min_clients": 6,
"num_rounds": 150,
"server": {
"heart_beat_timeout": 600
},
Expand All @@ -15,7 +15,7 @@
"model": {
"path": "monai.networks.nets.unet.UNet",
"args": {
"dimensions": 3,
"dimensions": 2,
"in_channels": 1,
"out_channels": 1,
"channels": [16, 32, 64, 128, 256],
Expand All @@ -33,31 +33,14 @@
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {
"aggregation_weights": {
"client_MSD": 1.0,
"client_I2CVB": 1.0,
"client_NCI_ISBI_3T": 1.0,
"client_NCI_ISBI_Dx": 1.0
}
}
"args": {}
},
{
"id": "model_selector",
"name": "IntimeModelSelectionHandler",
"args": {}
},
{
"id": "model_locator",
"name": "PTFileModelLocator",
"args": {
"pt_persistor_id": "persistor"
"weigh_by_local_iter": true
}
},
{
"id": "json_generator",
"name": "ValidationJsonGenerator",
"args": {}
}
],
"workflows": [
Expand All @@ -75,15 +58,6 @@
"train_task_name": "train",
"train_timeout": 0
}
},
{
"id": "global_model_eval",
"name": "GlobalModelEval",
"args": {
"model_locator_id": "model_locator",
"validation_timeout": 6000,
"cleanup_models": true
}
}
]
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"ref_learning_rate": 1e-2,
"learning_rate": 1e-2,
"learning_rate": 1e-3,
"ditto_learning_rate": 1e-3,
"ditto_lambda": 1.0,
"fedproxloss_mu": 0.0,
"cache_dataset": 0.0,
"dataset_base_dir": "PWD/data_preparation/dataset",
"datalist_json_path": "PWD/data_preparation/datalists/client_All.json"
"cache_dataset": 1.0,
"dataset_base_dir": "PWD/data_preparation/dataset_2D",
"datalist_json_path": "PWD/data_preparation/datalist/client_All.json"
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"executors": [
{
"tasks": [
"train", "submit_model", "validate"
"train", "validate"
],
"executor": {
"id": "Executor",
Expand All @@ -24,10 +24,10 @@
"components": [
{
"id": "prostate-learner",
"path": "pt.learners.supervised_prostate_learner.SupervisedProstateLearner",
"path": "pt.learners.supervised_monai_prostate_learner.SupervisedMonaiProstateLearner",
"args": {
"aggregation_epochs": 10,
"train_config_filename": "config_train.json"
"train_config_filename": "config_train.json",
"aggregation_epochs": 1
}
}
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"format_version": 2,
"min_clients": 4,
"num_rounds": 100,
"min_clients": 6,
"num_rounds": 150,
"server": {
"heart_beat_timeout": 600
},
Expand All @@ -15,7 +15,7 @@
"model": {
"path": "monai.networks.nets.unet.UNet",
"args": {
"dimensions": 3,
"dimensions": 2,
"in_channels": 1,
"out_channels": 1,
"channels": [16, 32, 64, 128, 256],
Expand All @@ -33,31 +33,14 @@
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {
"aggregation_weights": {
"client_MSD": 1.0,
"client_I2CVB": 1.0,
"client_NCI_ISBI_3T": 1.0,
"client_NCI_ISBI_Dx": 1.0
}
}
"args": {}
},
{
"id": "model_selector",
"name": "IntimeModelSelectionHandler",
"args": {}
},
{
"id": "model_locator",
"name": "PTFileModelLocator",
"args": {
"pt_persistor_id": "persistor"
"weigh_by_local_iter": true
}
},
{
"id": "json_generator",
"name": "ValidationJsonGenerator",
"args": {}
}
],
"workflows": [
Expand All @@ -75,15 +58,6 @@
"train_task_name": "train",
"train_timeout": 0
}
},
{
"id": "global_model_eval",
"name": "GlobalModelEval",
"args": {
"model_locator_id": "model_locator",
"validation_timeout": 6000,
"cleanup_models": true
}
}
]
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"learning_rate": 1e-2,
"learning_rate": 1e-3,
"fedproxloss_mu": 0.0,
"cache_dataset": 0.0,
"dataset_base_dir": "PWD/data_preparation/dataset",
"datalist_json_path": "PWD/data_preparation/datalists/client_All.json"
"cache_dataset": 1.0,
"dataset_base_dir": "PWD/data_preparation/dataset_2D",
"datalist_json_path": "PWD/data_preparation/datalist/client_All.json"
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"executors": [
{
"tasks": [
"train", "submit_model", "validate"
"train", "validate"
],
"executor": {
"id": "Executor",
Expand All @@ -24,10 +24,10 @@
"components": [
{
"id": "prostate-learner",
"path": "pt.learners.supervised_prostate_learner.SupervisedProstateLearner",
"path": "pt.learners.supervised_monai_prostate_learner.SupervisedMonaiProstateLearner",
"args": {
"aggregation_epochs": 10,
"train_config_filename": "config_train.json"
"train_config_filename": "config_train.json",
"aggregation_epochs": 1
}
}
]
Expand Down
Loading

0 comments on commit 9f2c8b9

Please sign in to comment.