Skip to content

Commit

Permalink
codestyle fix for hello-pt-resnet example.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen committed Sep 20, 2024
1 parent d5c210a commit e442d1b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from src.simple_network import Resnet18

from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
from nvflare.job_config.script_runner import ScriptRunner
from src.simple_network import Resnet18

if __name__ == "__main__":
n_clients = 2
num_rounds = 2
train_script = "src/hello-pt_cifar10_fl.py"

job = FedAvgJob(
name="hello-pt_cifar10_fedavg", n_clients=n_clients, num_rounds=num_rounds,
initial_model=Resnet18(num_classes=10)
name="hello-pt_cifar10_fedavg",
n_clients=n_clients,
num_rounds=num_rounds,
initial_model=Resnet18(num_classes=10),
)

# Add clients
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os

import torch
from simple_network import Resnet18
from torch import nn
from torch.optim import SGD
from torch.utils.data.dataloader import DataLoader
Expand All @@ -23,7 +24,6 @@

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter
from simple_network import Resnet18

DATASET_PATH = "/tmp/nvflare/data"

Expand Down
3 changes: 1 addition & 2 deletions examples/hello-world/hello-pt-resnet/src/simple_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any
from typing import Any, Optional

from torchvision.models import ResNet
from torchvision.models._utils import _ovewrite_named_param
from torchvision.models.resnet import BasicBlock, ResNet18_Weights


class Resnet18(ResNet):

def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
self.num_classes = num_classes

Expand Down

0 comments on commit e442d1b

Please sign in to comment.