ADD file via upload
This commit is contained in:
parent
38516c1b46
commit
c8d30a694f
|
@ -0,0 +1,59 @@
|
|||
import torch
|
||||
from torch import nn, autograd
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
class DatasetSplit(Dataset):
|
||||
def __init__(self, dataset, idxs):
|
||||
self.dataset = dataset
|
||||
self.idxs = list(idxs)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.idxs)
|
||||
|
||||
def __getitem__(self, item):
|
||||
image, label = self.dataset[self.idxs[item]]
|
||||
return image, label
|
||||
|
||||
|
||||
class LocalUpdate(object):
|
||||
def __init__(self, args, dataset=None, idxs=None, verbose=False):
|
||||
self.args = args
|
||||
self.loss_func = nn.CrossEntropyLoss()
|
||||
self.selected_clients = []
|
||||
self.ldr_train = DataLoader(
|
||||
DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True
|
||||
)
|
||||
self.verbose = verbose
|
||||
|
||||
def train(self, net):
|
||||
net.train()
|
||||
# train and update
|
||||
optimizer = torch.optim.SGD(
|
||||
net.parameters(), lr=self.args.lr, momentum=self.args.momentum
|
||||
)
|
||||
|
||||
Predict_loss = 0
|
||||
for iter in range(self.args.local_ep):
|
||||
for batch_idx, (images, labels) in enumerate(self.ldr_train):
|
||||
images, labels = images.to(self.args.device), labels.to(
|
||||
self.args.device
|
||||
)
|
||||
net.zero_grad()
|
||||
log_probs = net(images)["output"]
|
||||
loss = self.loss_func(log_probs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
Predict_loss += loss.item()
|
||||
if self.verbose:
|
||||
info = "\nUser predict Loss={:.4f}".format(
|
||||
Predict_loss / (self.args.local_ep * len(self.ldr_train))
|
||||
)
|
||||
print(info)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return net.state_dict()
|
Loading…
Reference in New Issue