ADD file via upload

This commit is contained in:
Tuberrr 2024-08-03 00:45:38 +08:00
parent c8d30a694f
commit 297c94eff1
1 changed file with 96 additions and 0 deletions

96
get_dataset.py Normal file
View File

@ -0,0 +1,96 @@
from torchvision import datasets, transforms
import numpy as np
import random
from collections import Counter
def separate_data(train_data, num_clients, num_classes, beta=0.4):
    """Partition dataset indices across clients with a Dirichlet distribution.

    For each class, sample per-client proportions from Dirichlet(beta) and
    split that class's (shuffled) indices accordingly. The whole partition is
    retried until every client holds at least ``min_require_size`` samples.

    Args:
        train_data: dataset object exposing a ``targets`` sequence of labels.
        num_clients: number of client shards to produce.
        num_classes: number of distinct label values in ``targets``.
        beta: Dirichlet concentration; smaller values -> more skewed shards.

    Returns:
        dict mapping client id (0..num_clients-1) to a shuffled list of
        dataset indices; together the lists cover every index exactly once.
    """
    y_train = np.array(train_data.targets)
    min_size_train = 0
    min_require_size = 10  # every client must end up with at least this many samples
    n_train = len(y_train)
    dict_users_train = {}
    while min_size_train < min_require_size:
        idx_batch_train = [[] for _ in range(num_clients)]
        for k in range(num_classes):
            idx_k_train = np.where(y_train == k)[0]
            np.random.shuffle(idx_k_train)
            proportions = np.random.dirichlet(np.repeat(beta, num_clients))
            # Zero the proportion of any client already at/above the average
            # share, which nudges the partition toward balanced sizes.
            proportions_train = np.array([
                p * (len(idx_j) < n_train / num_clients)
                for p, idx_j in zip(proportions, idx_batch_train)
            ])
            proportions_train = proportions_train / proportions_train.sum()
            # Cumulative cut points into idx_k_train (drop the final 100% mark).
            split_points = (np.cumsum(proportions_train) * len(idx_k_train)).astype(int)[:-1]
            idx_batch_train = [
                idx_j + idx.tolist()
                for idx_j, idx in zip(idx_batch_train, np.split(idx_k_train, split_points))
            ]
        min_size_train = min(len(idx_j) for idx_j in idx_batch_train)
    for j in range(num_clients):
        np.random.shuffle(idx_batch_train[j])
        dict_users_train[j] = idx_batch_train[j]
    return dict_users_train
def get_public(center_assignments):
    """Build a public index pool by sampling 60% of each center's indices.

    Args:
        center_assignments: sequence of per-center index lists.

    Returns:
        A flat list with a random 60% subset (without replacement, truncated
        to int) drawn from every center, in center order.
    """
    # Centers are visited in order, one random.sample call each, so the
    # global RNG stream is consumed exactly as before.
    return [
        idx
        for group in center_assignments
        for idx in random.sample(group, int(len(group) * 0.6))
    ]
def get_hosp(dataset, hosp_datavolume):
    """Group Dirichlet client shards into hospital-level index lists.

    Creates ``sum(hosp_datavolume)`` client shards via ``separate_data``
    (10 classes, beta=1.0), then assigns them to hospitals in order:
    hospital ``i`` receives ``hosp_datavolume[i]`` consecutive shards.

    Args:
        dataset: dataset object exposing a ``targets`` sequence of labels.
        hosp_datavolume: per-hospital shard counts.

    Returns:
        List of per-hospital index lists, one entry per hospital.
    """
    total_clients = sum(hosp_datavolume)
    # NOTE: 10 classes hard-coded to match CIFAR-10 usage in this file.
    dict_users = separate_data(dataset, total_clients, 10, 1.0)
    dict_hosp = [[] for _ in range(len(hosp_datavolume))]
    idx_hosp = 0
    cnt = 0
    # dict preserves insertion order, so shards are consumed as client 0..N-1.
    for shard in dict_users.values():
        dict_hosp[idx_hosp] += shard
        cnt += 1
        if cnt == hosp_datavolume[idx_hosp]:
            cnt = 0
            idx_hosp += 1
    return dict_hosp
def get_dataset(args, hosp_datavolume):
    """Load CIFAR-10 and Dirichlet-partition the training set across users.

    Downloads CIFAR-10 into ``./data/cifar10`` if not present, normalizes
    both splits identically, and partitions the train set with
    ``separate_data`` using ``args.num_users``, ``args.num_classes`` and
    ``args.data_beta``.

    Args:
        args: namespace providing ``num_users``, ``num_classes``, ``data_beta``.
        hosp_datavolume: currently unused; kept for interface compatibility
            with callers (hospital grouping is handled by ``get_hosp``).

    Returns:
        (dataset_train, dataset_test, dict_users) where ``dict_users`` maps
        user id to a list of training-set indices.
    """
    # Train and eval pipelines are identical, so define the transform once.
    trans_cifar10 = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    dataset_train = datasets.CIFAR10(
        "./data/cifar10", train=True, download=True, transform=trans_cifar10
    )
    dataset_test = datasets.CIFAR10(
        "./data/cifar10", train=False, download=True, transform=trans_cifar10
    )
    dict_users = separate_data(dataset_train, args.num_users, args.num_classes, args.data_beta)
    return dataset_train, dataset_test, dict_users