68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
import numpy as np
|
|
from models.Update import *
|
|
from models.Fed import *
|
|
from models.test import *
|
|
from utils.utilis import *
|
|
import copy
|
|
|
|
def test_with_loss(net_glob, dataset_test, args):
    """Evaluate ``net_glob`` on ``dataset_test`` and print its accuracy.

    The actual evaluation is delegated to the project helper ``test_img``.
    The accuracy it reports is printed with two decimals, then converted to
    a plain Python number via ``.item()`` before being returned.

    Returns:
        tuple: ``(accuracy, loss)``. ``accuracy`` is ``.item()`` of the value
        ``test_img`` returned; ``loss`` is passed through unchanged.
    """
    evaluation = test_img(net_glob, dataset_test, args)
    accuracy, loss = evaluation
    print("Testing accuracy: {:.2f}".format(accuracy))
    # presumably test_img returns a tensor/NumPy scalar; .item() unwraps it
    return accuracy.item(), loss
|
|
|
|
def _cloud_round(args, net_glob, dataset_train, dict_users, clients):
    """Train each client in ``clients`` locally and load the aggregate into ``net_glob``.

    Every client trains its own deep copy of ``net_glob`` (moved to
    ``args.device``) on its shard ``dict_users[client]`` via ``LocalUpdate``.
    The collected weights are combined by ``Aggregation``, which also receives
    each shard's size, and the result is loaded into ``net_glob`` in place.
    """
    w_locals = []
    lens = []
    for idx_client in clients:
        idxs = dict_users[idx_client]
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=idxs)
        w = local.train(net=copy.deepcopy(net_glob).to(args.device))
        # Keep an independent copy so later training cannot alias these weights.
        w_locals.append(copy.deepcopy(w))
        lens.append(len(idxs))

    w_glob = Aggregation(w_locals, lens)
    # copy the aggregated weights into the global model
    net_glob.load_state_dict(w_glob)


def Cloud(args, net_glob, dataset_train, dataset_test, dict_users, dict_public,
          num_clients=4, transf=161):
    """Run centralized ("cloud") federated training and save its metrics.

    Each simulated day:
      1. samples ``num_clients`` clients from ``dict_public``;
      2. runs ``args.epochs`` rounds of local training + ``Aggregation`` over
         them (see ``_cloud_round``);
      3. adds ``transf`` to a cumulative transfer counter;
      4. evaluates ``net_glob`` on ``dataset_test``.

    Days are numbered from 1, and the day on which ``today >= args.physical_time``
    is the last one. At least one day always runs.

    Args:
        args: experiment config; ``physical_time``, ``epochs`` and ``device``
            are read here (helpers may read more).
        net_glob: the global model; updated in place.
        dataset_train: dataset that the client shards index into.
        dataset_test: dataset used for the daily evaluation.
        dict_users: client id -> that client's sample indices.
        dict_public: pool of client ids to sample from. A sequence is sampled
            as-is. Any other iterable (set, dict keys, array) is first
            converted with ``tuple()``, as ``random.sample`` did for sets
            before Python 3.11.
        num_clients: clients sampled per day (default 4), clamped to the
            pool size.
        transf: amount added to the cumulative transfer counter per day
            (default 161; units are not established here).

    Raises:
        ValueError: if ``dict_public`` is empty.

    Side effects:
        Prints progress, mutates ``net_glob``, and stores the per-day test
        accuracy ("test_acc") and cumulative transfer ("test_GB") via
        ``save_result``.
    """
    import random
    from collections.abc import Sequence

    population = (dict_public if isinstance(dict_public, Sequence)
                  else tuple(dict_public))
    if not population:
        raise ValueError("Cloud: dict_public contains no clients to sample")
    sample_size = min(num_clients, len(population))

    net_glob.train()
    print("Start Cloud")

    acc = []
    GB = []
    total_transf = 0

    today = 1
    while True:
        # Decided before training, exactly like the original flag loop.
        last_day = today >= args.physical_time

        client_train_list = random.sample(population, sample_size)
        print("today: {:3d}".format(today))

        # NOTE(review): test_img may leave the model in eval mode; re-enable
        # training mode before each day's local updates (no-op otherwise).
        net_glob.train()
        for _ in range(args.epochs):
            _cloud_round(args, net_glob, dataset_train, dict_users,
                         client_train_list)

        total_transf += transf
        item_acc, _item_loss = test_with_loss(net_glob, dataset_test, args)
        acc.append(item_acc)
        GB.append(total_transf)

        if last_day:
            break
        today += 1

    save_result(acc, "test_acc", args)
    save_result(GB, "test_GB", args)