diff --git a/examples/hello-world/hello-pt-resnet/README.md b/examples/hello-world/hello-pt-resnet/README.md index f023e81013..4b1a0f51a6 100644 --- a/examples/hello-world/hello-pt-resnet/README.md +++ b/examples/hello-world/hello-pt-resnet/README.md @@ -10,7 +10,9 @@ instead of the SimpleNetwork. The Job API only supports the object instance created directly out of the Python Class. It does not support the object instance created through using the Python function. Comparing with the hello-pt example, if we replace the SimpleNetwork() object with the resnet18(num_classes=10), -the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. The job API can +the "resnet18(num_classes=10)" creates an torchvision "ResNet" object instance out of the "resnet18" function. +As shown in the [torchvision reset](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L684-L705), +the resnet18 is a Python function, which creates and returns a ResNet object. The job API can only use the "ResNet" object instance for generating the job config. It can not detect the object creating function logic in the "resnet18". This example demonstrates how to wrap up the resnet18 Python function into a Resnet18 Python class. Then uses the Resnet18(num_classes=10) diff --git a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py index 7d1788014b..a4955ed5a0 100644 --- a/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py +++ b/examples/hello-world/hello-pt-resnet/fedavg_script_runner_pt.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from src.simple_network import Resnet18 +from src.resnet_18 import Resnet18 from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob from nvflare.job_config.script_runner import ScriptRunner diff --git a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py index 20d43fa574..cfd12ba947 100644 --- a/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt-resnet/src/hello-pt_cifar10_fl.py @@ -15,7 +15,7 @@ import os import torch -from simple_network import Resnet18 +from resnet_18 import Resnet18 from torch import nn from torch.optim import SGD from torch.utils.data.dataloader import DataLoader diff --git a/examples/hello-world/hello-pt-resnet/src/simple_network.py b/examples/hello-world/hello-pt-resnet/src/resnet_18.py similarity index 100% rename from examples/hello-world/hello-pt-resnet/src/simple_network.py rename to examples/hello-world/hello-pt-resnet/src/resnet_18.py