ADD file via upload

This commit is contained in:
Tuberrr 2024-08-03 00:38:35 +08:00
parent 7050d141a2
commit 50e90cd4ef
1 changed files with 45 additions and 0 deletions

45
main.py Normal file
View File

@ -0,0 +1,45 @@
import torch
from utils.set_seed import *
from utils.options import args_parser
from utils.get_dataset import *
from models.Nets import *
from Algorithm.CLoud import *
from Algorithm.Edge import *
def map_users_to_centers(hosp_datavolume):
cnt = 0
center_assignments = [[] for _ in range(len(hosp_datavolume))]
for idx_center in range(len(hosp_datavolume)):
center_assignments[idx_center] = list(range(cnt,cnt+hosp_datavolume[idx_center]))
cnt += hosp_datavolume[idx_center]
return center_assignments
if __name__ == "__main__":
# parse args解析命令行 #加一个数据中心数量
args = args_parser()
args.device = torch.device(
"cuda:{}".format(args.gpu)
if torch.cuda.is_available()
else "cpu"
)
set_random_seed(1)
hosp_datavolume = [75,50,20,16]
hosp_Model_datavolume = [5,2,0.7,1]
dataset_train, dataset_test, dict_users = get_dataset(args, hosp_datavolume)
center_assignments = map_users_to_centers(hosp_datavolume)
dict_public = get_public(center_assignments)
net_glob = CNNCifar10()
net_glob.to(args.device)
print(net_glob) # 初始化模型对象
if args.algorithm == "Edge":
Edge(args, net_glob, dataset_train, dataset_test, dict_users ,hosp_Model_datavolume, center_assignments)
elif args.algorithm == "Cloud":
Cloud(args, net_glob, dataset_train, dataset_test, dict_users, dict_public)