ADD file via upload

This commit is contained in:
Tuberrr 2024-08-03 00:45:23 +08:00
parent 38516c1b46
commit c8d30a694f
1 changed files with 59 additions and 0 deletions

59
Update.py Normal file
View File

@ -0,0 +1,59 @@
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import numpy as np
import random
class DatasetSplit(Dataset):
def __init__(self, dataset, idxs):
self.dataset = dataset
self.idxs = list(idxs)
def __len__(self):
return len(self.idxs)
def __getitem__(self, item):
image, label = self.dataset[self.idxs[item]]
return image, label
class LocalUpdate(object):
def __init__(self, args, dataset=None, idxs=None, verbose=False):
self.args = args
self.loss_func = nn.CrossEntropyLoss()
self.selected_clients = []
self.ldr_train = DataLoader(
DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True
)
self.verbose = verbose
def train(self, net):
net.train()
# train and update
optimizer = torch.optim.SGD(
net.parameters(), lr=self.args.lr, momentum=self.args.momentum
)
Predict_loss = 0
for iter in range(self.args.local_ep):
for batch_idx, (images, labels) in enumerate(self.ldr_train):
images, labels = images.to(self.args.device), labels.to(
self.args.device
)
net.zero_grad()
log_probs = net(images)["output"]
loss = self.loss_func(log_probs, labels)
loss.backward()
optimizer.step()
Predict_loss += loss.item()
if self.verbose:
info = "\nUser predict Loss={:.4f}".format(
Predict_loss / (self.args.local_ep * len(self.ldr_train))
)
print(info)
torch.cuda.empty_cache()
return net.state_dict()