# NOTE: scrape-residue header ("59 lines / 1.8 KiB / Python") converted to this comment.
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import numpy as np
import random
class DatasetSplit(Dataset):
    """Expose only a chosen subset of an underlying dataset.

    Wraps ``dataset`` so that indexing goes through the index list
    ``idxs``: item ``k`` of this view is item ``idxs[k]`` of the
    wrapped dataset.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialize as a list so arbitrary iterables of indices work
        # and repeated indexing is cheap.
        self.idxs = list(idxs)

    def __len__(self):
        # The view is exactly as long as its index list.
        return len(self.idxs)

    def __getitem__(self, item):
        original_index = self.idxs[item]
        sample, target = self.dataset[original_index]
        return sample, target
class LocalUpdate(object):
    """Local SGD training for a single federated-learning client.

    Args:
        args: configuration namespace; must provide ``local_bs``, ``lr``,
            ``momentum``, ``local_ep`` and ``device``.
        dataset: the full training dataset shared by all clients.
        idxs: indices of the samples belonging to this client.
        verbose: when True, print the mean prediction loss after training.
    """

    def __init__(self, args, dataset=None, idxs=None, verbose=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        # Each client only iterates over its own shard of the shared dataset.
        self.ldr_train = DataLoader(
            DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True
        )
        self.verbose = verbose

    def train(self, net):
        """Train ``net`` on this client's data and return its weights.

        Args:
            net: model whose forward pass returns a dict with an
                ``"output"`` entry holding the class logits.

        Returns:
            ``net.state_dict()`` after ``args.local_ep`` local epochs of SGD.
        """
        net.train()

        optimizer = torch.optim.SGD(
            net.parameters(), lr=self.args.lr, momentum=self.args.momentum
        )

        predict_loss = 0.0
        # `_epoch` rather than `iter`: the original shadowed the builtin.
        for _epoch in range(self.args.local_ep):
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device)
                # Equivalent to the original net.zero_grad(): the optimizer
                # was built over all of net's parameters.
                optimizer.zero_grad()
                log_probs = net(images)["output"]
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                predict_loss += loss.item()

        if self.verbose:
            mean_loss = predict_loss / (self.args.local_ep * len(self.ldr_train))
            print(f"\nUser predict Loss={mean_loss:.4f}")

        # Best-effort release of cached GPU memory between clients;
        # a no-op on CPU-only runs.
        torch.cuda.empty_cache()

        return net.state_dict()