mirror of https://github.com/Jittor/Jittor
update embeding
This commit is contained in:
parent
d79bd8e09c
commit
60efbfd461
|
@ -513,6 +513,16 @@ class ReplicationPad2d(Module):
|
|||
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
|
||||
])
|
||||
|
||||
class Embedding(Module):
|
||||
def __init__(self, num, dim):
|
||||
self.num = num
|
||||
self.dim = dim
|
||||
self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
|
||||
|
||||
def execute(self, x):
|
||||
res = self.weight[x].reshape([x.shape[0],self.dim])
|
||||
return res
|
||||
|
||||
class PixelShuffle(Module):
|
||||
def __init__(self, upscale_factor):
|
||||
self.upscale_factor = upscale_factor
|
||||
|
|
Loading…
Reference in New Issue