ADD file via upload
This commit is contained in:
parent
7050d141a2
commit
50e90cd4ef
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
from utils.set_seed import *
|
||||
from utils.options import args_parser
|
||||
from utils.get_dataset import *
|
||||
from models.Nets import *
|
||||
from Algorithm.CLoud import *
|
||||
from Algorithm.Edge import *
|
||||
|
||||
def map_users_to_centers(hosp_datavolume):
|
||||
cnt = 0
|
||||
center_assignments = [[] for _ in range(len(hosp_datavolume))]
|
||||
for idx_center in range(len(hosp_datavolume)):
|
||||
center_assignments[idx_center] = list(range(cnt,cnt+hosp_datavolume[idx_center]))
|
||||
cnt += hosp_datavolume[idx_center]
|
||||
|
||||
return center_assignments
|
||||
|
||||
if __name__ == "__main__":
|
||||
# parse args解析命令行 #加一个数据中心数量
|
||||
args = args_parser()
|
||||
args.device = torch.device(
|
||||
"cuda:{}".format(args.gpu)
|
||||
if torch.cuda.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
set_random_seed(1)
|
||||
|
||||
hosp_datavolume = [75,50,20,16]
|
||||
hosp_Model_datavolume = [5,2,0.7,1]
|
||||
|
||||
dataset_train, dataset_test, dict_users = get_dataset(args, hosp_datavolume)
|
||||
|
||||
center_assignments = map_users_to_centers(hosp_datavolume)
|
||||
|
||||
dict_public = get_public(center_assignments)
|
||||
|
||||
net_glob = CNNCifar10()
|
||||
|
||||
net_glob.to(args.device)
|
||||
print(net_glob) # 初始化模型对象
|
||||
|
||||
if args.algorithm == "Edge":
|
||||
Edge(args, net_glob, dataset_train, dataset_test, dict_users ,hosp_Model_datavolume, center_assignments)
|
||||
elif args.algorithm == "Cloud":
|
||||
Cloud(args, net_glob, dataset_train, dataset_test, dict_users, dict_public)
|
Loading…
Reference in New Issue