ADD file via upload
This commit is contained in:
parent
5a0a1015bb
commit
6851c04e3b
|
@ -0,0 +1,27 @@
|
|||
import copy
|
||||
import torch
|
||||
|
||||
|
||||
def Aggregation(w, lens):
|
||||
w_avg = None
|
||||
if lens == None:
|
||||
total_count = len(w)
|
||||
lens = []
|
||||
for i in range(len(w)):
|
||||
lens.append(1.0)
|
||||
else:
|
||||
total_count = sum(lens)
|
||||
|
||||
for i in range(0, len(w)):
|
||||
if i == 0:
|
||||
w_avg = copy.deepcopy(w[0])
|
||||
for k in w_avg.keys():
|
||||
w_avg[k] = w[i][k] * lens[i]
|
||||
else:
|
||||
for k in w_avg.keys():
|
||||
w_avg[k] += w[i][k] * lens[i]
|
||||
|
||||
for k in w_avg.keys():
|
||||
w_avg[k] = torch.div(w_avg[k], total_count)
|
||||
|
||||
return w_avg
|
Loading…
Reference in New Issue