ADD file via upload

This commit is contained in:
Tuberrr 2024-08-03 00:42:05 +08:00
parent b378fecc74
commit d8700dbf4c
1 changed files with 99 additions and 0 deletions

99
Edge.py Normal file
View File

@ -0,0 +1,99 @@
import numpy as np
from models.Update import *
from models.Fed import *
from models.test import *
from utils.utilis import *
import copy
def test_with_loss(net_glob, dataset_test, args):
# testing
acc_test, loss_test = test_img(net_glob, dataset_test, args)
print("Testing accuracy: {:.2f}".format(acc_test))
return acc_test.item(), loss_test
def split_list_into_365_parts(input_list):
n = len(input_list) # 获取输入列表的长度
# 计算每个子列表的基本长度
base_size = n // 365
# 计算需要额外添加一个元素的子列表的数量
remainder = n % 365
result = []
start = 0 # 开始索引
for i in range(365):
# 如果索引小于余数,当前子列表多一个元素
end = start + base_size + (1 if i < remainder else 0)
# 切分列表并添加到结果中
result.append(input_list[start:end])
start = end # 更新下一个子列表的开始索引
return result
def Edge(args, net_glob, dataset_train, dataset_test, dict_users, hosp_Model_datavolume, center_assignments):
net_glob.train()
print("Start Edge")
num_hosp = args.num_centers
isNext = True
today = 1
m = 1
ratio = [0.46, 0.31,0.12,0.11]
# training
acc = []
GB = []
transf = sum(hosp_Model_datavolume)
total_transf = 0
w_glob = net_glob.state_dict()
# dict_day = [split_list_into_365_parts(dict_hosp[idx]) for idx in range(len(dict_hosp))]
while isNext:
if today >= args.physical_time:
isNext = False
# print("*" * 80)
print("today{:3d}".format(today))
# day7 = list(range(today,today+7))
for idx_hosp in range(num_hosp):
w_locals = []
lens = []
net_glob.load_state_dict(w_glob)
l = 0
client_train_list = random.sample(
center_assignments[idx_hosp], m
)
# for day in day7:
l = 0
for idx_client in client_train_list:
local = LocalUpdate(
args=args, dataset=dataset_train, idxs=dict_users[idx_client]
)
for iter in range(args.epochs):
w = local.train(net=copy.deepcopy(net_glob).to(args.device))
net_glob.load_state_dict(w)
l += len(dict_users[idx_client])
w = net_glob.state_dict()
w_locals.append(copy.deepcopy(w))
l *= ratio[idx_hosp]
lens.append(l)
# lens.append(len(dict_hosp[idx_hosp]))
today+=1
# update global weights
w_glob = Aggregation(w_locals, lens)
total_transf += transf
item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
ta, tl = test_with_loss(net_glob, dataset_train, args)
acc.append(item_acc)
GB.append(total_transf)
save_result(acc, "test_acc", args)
save_result(GB, "test_GB", args)