ADD file via upload
This commit is contained in:
parent
b378fecc74
commit
d8700dbf4c
|
@ -0,0 +1,99 @@
|
|||
import numpy as np
|
||||
from models.Update import *
|
||||
from models.Fed import *
|
||||
from models.test import *
|
||||
from utils.utilis import *
|
||||
import copy
|
||||
|
||||
def test_with_loss(net_glob, dataset_test, args):
|
||||
|
||||
# testing
|
||||
acc_test, loss_test = test_img(net_glob, dataset_test, args)
|
||||
|
||||
print("Testing accuracy: {:.2f}".format(acc_test))
|
||||
|
||||
return acc_test.item(), loss_test
|
||||
|
||||
def split_list_into_365_parts(input_list):
|
||||
n = len(input_list) # 获取输入列表的长度
|
||||
# 计算每个子列表的基本长度
|
||||
base_size = n // 365
|
||||
# 计算需要额外添加一个元素的子列表的数量
|
||||
remainder = n % 365
|
||||
|
||||
result = []
|
||||
start = 0 # 开始索引
|
||||
|
||||
for i in range(365):
|
||||
# 如果索引小于余数,当前子列表多一个元素
|
||||
end = start + base_size + (1 if i < remainder else 0)
|
||||
# 切分列表并添加到结果中
|
||||
result.append(input_list[start:end])
|
||||
start = end # 更新下一个子列表的开始索引
|
||||
|
||||
return result
|
||||
|
||||
def Edge(args, net_glob, dataset_train, dataset_test, dict_users, hosp_Model_datavolume, center_assignments):
|
||||
net_glob.train()
|
||||
|
||||
print("Start Edge")
|
||||
|
||||
num_hosp = args.num_centers
|
||||
|
||||
isNext = True
|
||||
today = 1
|
||||
m = 1
|
||||
|
||||
ratio = [0.46, 0.31,0.12,0.11]
|
||||
|
||||
# training
|
||||
acc = []
|
||||
GB = []
|
||||
transf = sum(hosp_Model_datavolume)
|
||||
total_transf = 0
|
||||
|
||||
w_glob = net_glob.state_dict()
|
||||
|
||||
# dict_day = [split_list_into_365_parts(dict_hosp[idx]) for idx in range(len(dict_hosp))]
|
||||
|
||||
while isNext:
|
||||
if today >= args.physical_time:
|
||||
isNext = False
|
||||
# print("*" * 80)
|
||||
print("today{:3d}".format(today))
|
||||
# day7 = list(range(today,today+7))
|
||||
for idx_hosp in range(num_hosp):
|
||||
w_locals = []
|
||||
lens = []
|
||||
net_glob.load_state_dict(w_glob)
|
||||
l = 0
|
||||
client_train_list = random.sample(
|
||||
center_assignments[idx_hosp], m
|
||||
)
|
||||
# for day in day7:
|
||||
l = 0
|
||||
for idx_client in client_train_list:
|
||||
local = LocalUpdate(
|
||||
args=args, dataset=dataset_train, idxs=dict_users[idx_client]
|
||||
)
|
||||
for iter in range(args.epochs):
|
||||
w = local.train(net=copy.deepcopy(net_glob).to(args.device))
|
||||
net_glob.load_state_dict(w)
|
||||
l += len(dict_users[idx_client])
|
||||
w = net_glob.state_dict()
|
||||
w_locals.append(copy.deepcopy(w))
|
||||
l *= ratio[idx_hosp]
|
||||
lens.append(l)
|
||||
# lens.append(len(dict_hosp[idx_hosp]))
|
||||
today+=1
|
||||
# update global weights
|
||||
w_glob = Aggregation(w_locals, lens)
|
||||
total_transf += transf
|
||||
|
||||
item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
|
||||
ta, tl = test_with_loss(net_glob, dataset_train, args)
|
||||
acc.append(item_acc)
|
||||
GB.append(total_transf)
|
||||
|
||||
save_result(acc, "test_acc", args)
|
||||
save_result(GB, "test_GB", args)
|
Loading…
Reference in New Issue