attention op & Sequential with OrderedDict

Gword 2020-08-28 20:48:10 +08:00
parent 7e1678c7ea
commit 7b41ab678f
5 changed files with 285 additions and 10 deletions

View File

@@ -512,11 +512,10 @@ class Module:
             end = 0
             for k in key_:
                 if isinstance(v, nn.Sequential):
-                    if ori_int(k) >= len(v.layers):
-                        end = 1
-                        break
-                    else:
-                        v = v[ori_int(k)]
+                    v = v[k]
+                    if v is None:
+                        end = 1
+                        break
                 else:
                     if hasattr(v, k):
                         v = getattr(v, k)

python/jittor/attention.py (new file, 175 lines)
View File

@@ -0,0 +1,175 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import init, Module, nn
import numpy as np
import math

class MultiheadAttention(Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        assert dropout == 0, "TODO: dropout>0"
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size")

        # TODO: quant_noise
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        assert not add_bias_kv, "TODO: add_bias_kv=True"
        self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.tpu = False

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            init.xavier_uniform_(self.k_proj.weight)
            init.xavier_uniform_(self.v_proj.weight)
            init.xavier_uniform_(self.q_proj.weight)

        # init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            init.xavier_normal_(self.bias_v)

    def execute(
        self,
        query,
        key = None,
        value = None,
        key_padding_mask = None,
        incremental_state = None,
        need_weights = True,
        static_kv = False,
        attn_mask = None,
        before_softmax = False,
        need_head_weights = False,
    ):
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]

        assert incremental_state is None, "TODO: incremental_state is not None"
        saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q = q * self.scaling

        assert self.bias_k is None, "TODO: self.bias_k is not None:"

        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
        if k is not None:
            k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
        if v is not None:
            v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)

        assert saved_state is None, "TODO: saved_state is not None"
        assert k is not None
        src_len = k.shape[1]

        assert key_padding_mask is None, "TODO: key_padding_mask is not None"
        assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"

        attn_weights = nn.bmm(q, k.transpose(0, 2, 1))
        assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        assert attn_mask is None, "TODO: attn_mask is not None"
        assert key_padding_mask is None, "TODO: key_padding_mask is not None"

        if before_softmax:
            return attn_weights, v

        attn_weights_float = nn.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        assert v is not None
        attn = nn.bmm(attn_weights, v)
        assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.shape[1] == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dims=[0])

        return attn, attn_weights
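
A minimal usage sketch of the op above (the shapes and sizes are illustrative, not from the commit): inputs follow the (tgt_len, batch, embed_dim) layout expected by execute, and with need_weights left at its default the returned weights are averaged over heads.

    import jittor as jt
    from jittor.attention import MultiheadAttention

    q = jt.random((4, 8, 16))                          # (tgt_len, bsz, embed_dim)
    attn = MultiheadAttention(16, 4, self_attention=True)
    out, weights = attn(q, q, q)                       # out: (4, 8, 16)
                                                       # weights: (bsz, tgt_len, src_len) = (8, 4, 4)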

View File

@@ -55,3 +55,30 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):

 def relu_invariant_gauss_(var, mode="fan_in"):
     var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
+
+#TODO: bound = gain * math.sqrt(6.0/fan) ??
+def xavier_uniform(shape, dtype, gain=1.0):
+    assert len(shape)>1
+    matsize=1
+    for i in shape[2:]:
+        matsize *= i
+    fan = (shape[1] * matsize) + (shape[0] * matsize)
+    bound = gain * math.sqrt(1.0/fan)
+    return uniform(shape, dtype, -bound, bound)
+
+def xavier_uniform_(var, gain=1.0):
+    var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))
+
+def xavier_gauss(shape, dtype, gain=1.0):
+    assert len(shape)>1
+    matsize=1
+    for i in shape[2:]:
+        matsize *= i
+    fan = (shape[1] * matsize) + (shape[0] * matsize)
+    std = gain * math.sqrt(2.0/fan)
+    return gauss(shape, dtype, 0, std)
+
+def xavier_gauss_(var, gain=1.0):
+    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
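
For reference on the TODO above: with fan = fan_in + fan_out, the usual Glorot/Xavier uniform bound is gain * sqrt(6/fan), while the code uses gain * sqrt(1/fan). A small worked example, assuming a Linear weight of shape (16, 16), where matsize = 1 and so fan = 32:

    import math

    gain = 1.0
    fan = 16 + 16                                 # fan_in + fan_out for a (16, 16) weight
    bound_here = gain * math.sqrt(1.0 / fan)      # ~0.177, the bound used above
    bound_glorot = gain * math.sqrt(6.0 / fan)    # ~0.433, the bound mentioned in the TODO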

View File

@@ -13,6 +13,7 @@
 import jittor as jt
 from jittor import init, Module
 import numpy as np
+import collections
 import math
 from jittor.pool import Pool, pool, AdaptiveAvgPool2d
 from jittor.optim import *
@@ -780,13 +781,17 @@ class Upsample(Module):

 class Sequential(Module):
     def __init__(self, *args):
-        self.layers = []
+        self.layers = collections.OrderedDict()
         for mod in args:
-            self.append(mod)
+            if isinstance(mod, collections.OrderedDict):
+                for k, m in mod.items():
+                    self.add_module(k, m)
+            else:
+                self.append(mod)
     def __getitem__(self, idx):
         return self.layers[idx]
     def execute(self, x):
-        for layer in self.layers:
+        for k, layer in self.layers.items():
             x = layer(x)
         return x
     def dfs(self, parents, k, callback, callback_leave):
@@ -794,7 +799,7 @@ class Sequential(Module):
         ret = callback(parents, k, self, n_children)
         if ret == False:
             return
-        for k,v in enumerate(self.layers):
+        for k,v in self.layers.items():
             parents.append(self)
             v.dfs(parents, k, callback, callback_leave)
             parents.pop()
@@ -803,6 +808,10 @@ class Sequential(Module):
     def append(self, mod):
         assert callable(mod), f"Module <{type(mod)}> is not callable"
         assert not isinstance(mod, type), f"Module is not a type"
-        self.layers.append(mod)
+        self.layers[len(self.layers)]=mod
+    def add_module(self, name, mod):
+        assert callable(mod), f"Module <{type(mod)}> is not callable"
+        assert not isinstance(mod, type), f"Module is not a type"
+        self.layers[name]=mod

 ModuleList = Sequential
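
With this change Sequential keeps its children in an OrderedDict, so it can be built from positional modules (keyed by insertion index) or from an OrderedDict of named modules, and add_module inserts a named child. A minimal sketch, with illustrative layer names and sizes:

    import collections
    import jittor as jt
    from jittor import nn

    seq = nn.Sequential(collections.OrderedDict([
        ("fc1", nn.Linear(16, 32)),
        ("act", nn.ReLU()),
        ("fc2", nn.Linear(32, 16)),
    ]))
    seq.add_module("out_act", nn.ReLU())   # named insertion via the new add_module
    y = seq(jt.random((8, 16)))            # execute() runs layers in insertion order
    fc1 = seq["fc1"]                       # __getitem__ now also indexes by name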

View File

@@ -0,0 +1,65 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import jittor.attention as jtatt
import numpy as np

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
    import fairseq
except:
    torch = None
    tnn = None
    skip_this_test = True

def check_equal(q, k, v, tatt, jatt):
    tq = torch.from_numpy(q)
    jq = jt.array(q)
    tk = torch.from_numpy(k)
    jk = jt.array(k)
    tv = torch.from_numpy(v)
    jv = jt.array(v)
    jatt.load_parameters(tatt.state_dict())
    ty, tw = tatt(tq, tk, tv)
    jy, jw = jatt(jq, jk, jv)
    assert np.allclose(ty.detach().numpy(), jy.numpy(), rtol=1e-3)
    assert np.allclose(tw.detach().numpy(), jw.numpy(), rtol=1e-3)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestAttention(unittest.TestCase):
    def test_attention(self):
        q = np.random.rand(4, 8, 16).astype(np.float32)
        k = np.random.rand(4, 8, 16).astype(np.float32)
        v = np.random.rand(4, 8, 16).astype(np.float32)

        tatt = fairseq.modules.multihead_attention.MultiheadAttention(16, 1)
        jatt = jt.attention.MultiheadAttention(16, 1)
        check_equal(q, k, v, tatt, jatt)

        tatt = fairseq.modules.multihead_attention.MultiheadAttention(16, 4)
        jatt = jt.attention.MultiheadAttention(16, 4)
        check_equal(q, k, v, tatt, jatt)

        tatt = fairseq.modules.multihead_attention.MultiheadAttention(16, 1, self_attention=True)
        jatt = jt.attention.MultiheadAttention(16, 1, self_attention=True)
        check_equal(q, q, q, tatt, jatt)

        tatt = fairseq.modules.multihead_attention.MultiheadAttention(16, 4, self_attention=True)
        jatt = jt.attention.MultiheadAttention(16, 4, self_attention=True)
        check_equal(q, q, q, tatt, jatt)

if __name__ == "__main__":
    unittest.main()