update embeding

This commit is contained in:
zwy 2020-05-15 15:36:12 +08:00
parent d79bd8e09c
commit 60efbfd461
1 changed files with 10 additions and 0 deletions

View File

@ -513,6 +513,16 @@ class ReplicationPad2d(Module):
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
])
class Embedding(Module):
def __init__(self, num, dim):
self.num = num
self.dim = dim
self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
def execute(self, x):
res = self.weight[x].reshape([x.shape[0],self.dim])
return res
class PixelShuffle(Module):
def __init__(self, upscale_factor):
self.upscale_factor = upscale_factor