-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
71 lines (52 loc) · 1.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from args import args
import random
import numpy as np
import pathlib
import torch
import data
from FL_train import *
def main():
    """Entry point: seed all RNGs, create a fresh run directory, build the
    federated data loaders, and dispatch to the selected FL training routine.

    Reads configuration exclusively from the global ``args`` namespace
    (seed, log_dir, name, set, nClients, non_iid_degree, batch sizes,
    FL_type) and mutates ``args.name``, ``args.run_base_dir`` and
    ``args.device`` as side effects.
    """
    if args.seed is not None:
        # Seed every RNG source used by the project for reproducibility.
        # NOTE(fix): numpy is imported and used downstream, but was never
        # seeded before — np.random.seed is added so numpy draws are
        # reproducible too.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Make a directory corresponding to this run for saving results,
    # checkpoints etc.  Using mkdir() and catching FileExistsError (EAFP)
    # removes the check-then-create race the old exists()/makedirs pair had.
    i = 0
    while True:
        run_base_dir = pathlib.Path(f"{args.log_dir}/FRL~try={i}")
        try:
            run_base_dir.mkdir(parents=True)
        except FileExistsError:
            i += 1
            continue
        args.name = args.name + f"~try={i}"
        break

    # Record the full configuration of this run alongside its outputs.
    (run_base_dir / "output.txt").write_text(str(args))
    args.run_base_dir = run_base_dir
    print(f"=> Saving data in {run_base_dir}")

    # Distribute the dataset among the federated clients.
    print("dataset to use is: ", args.set)
    print("number of FL clients: ", args.nClients)
    print("non-iid degree data distribution: ", args.non_iid_degree)
    print("batch size is : ", args.batch_size)
    print("test batch size is: ", args.test_batch_size)

    # args.set names a class/factory in the project `data` module; it hands
    # out per-client train loaders and a shared test loader.
    data_distributer = getattr(data, args.set)()
    tr_loaders = data_distributer.get_tr_loaders()
    te_loader = data_distributer.get_te_loader()

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = torch.cuda.is_available()
    print("use_cuda: ", use_cuda)

    # Federated Learning: dispatch table keeps the mapping from FL_type to
    # trainer in one place; unknown types fall back to FedAVG, exactly as
    # the original if/elif chain did.
    print("type of FL: ", args.FL_type)
    trainers = {
        "FRL": FRL_train,
        "FedAVG": FedAVG,
        "trimmedMean": Tr_Mean,
        "Mkrum": Mkrum,
    }
    trainers.get(args.FL_type, FedAVG)(tr_loaders, te_loader)
# Run the FL pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()