attention op & Sequential with OrderedDict

Author: Gword
Date:   2020-08-28 20:48:10 +08:00
parent 7e1678c7ea
commit 7b41ab678f
5 changed files with 285 additions and 10 deletions

python/jittor/__init__.py

@@ -512,11 +512,10 @@ class Module:
             end = 0
             for k in key_:
                 if isinstance(v, nn.Sequential):
-                    if ori_int(k) >= len(v.layers):
-                        end = 1
-                    else:
-                        v = v[ori_int(k)]
+                    v = v[k]
+                    if v is None:
+                        end = 1
+                        break
                 else:
                     if hasattr(v, k):
                         v = getattr(v, k)
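This hunk appears to be part of the parameter-loading key traversal: together with the named Sequential below, dotted keys can now resolve a child by name instead of by integer position. A hedged sketch of what that enables (module name, shapes and values here are purely illustrative, not from the commit):

    import collections
    import jittor as jt
    from jittor import nn

    net = nn.Sequential(collections.OrderedDict([('fc', nn.Linear(4, 4))]))
    # keys such as 'fc.weight' index the Sequential by name rather than by position
    net.load_parameters({'fc.weight': jt.random((4, 4)), 'fc.bias': jt.random((4,))})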

python/jittor/attention.py (new file, 175 lines)

@@ -0,0 +1,175 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import init, Module, nn
import numpy as np
import math

class MultiheadAttention(Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        assert dropout==0, "TODO: dropout>0"

        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size")

        #TODO: quant_noise
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        assert not add_bias_kv, "TODO: add_bias_kv=True"
        self.bias_k = self.bias_v = None
        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.tpu = False

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            init.xavier_uniform_(self.k_proj.weight)
            init.xavier_uniform_(self.v_proj.weight)
            init.xavier_uniform_(self.q_proj.weight)
        # init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            init.xavier_normal_(self.bias_v)

    def execute(
        self,
        query,
        key = None,
        value = None,
        key_padding_mask = None,
        incremental_state = None,
        need_weights = True,
        static_kv = False,
        attn_mask = None,
        before_softmax = False,
        need_head_weights = False,
    ):
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]

        assert incremental_state is None, "TODO: incremental_state is not None"
        saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q = q*self.scaling

        assert self.bias_k is None, "TODO: self.bias_k is not None:"

        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
        if k is not None:
            k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)
        if v is not None:
            v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2)

        assert saved_state is None, "TODO: saved_state is not None"
        assert k is not None
        src_len = k.shape[1]

        assert key_padding_mask is None, "TODO: key_padding_mask is not None"
        assert not self.add_zero_attn, "TODO: self.add_zero_attn=True"

        attn_weights = nn.bmm(q, k.transpose(0, 2, 1))
        assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        assert attn_mask is None, "TODO: attn_mask is not None"
        assert key_padding_mask is None, "TODO: key_padding_mask is not None"

        if before_softmax:
            return attn_weights, v

        attn_weights_float = nn.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        assert v is not None
        attn = nn.bmm(attn_weights, v)
        assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.shape[1] == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dims=[0])

        return attn, attn_weights
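A minimal usage sketch for the new module (not part of the commit; the shapes follow the (seq_len, batch, embed_dim) layout assumed by execute, and the random inputs are purely illustrative):

    import jittor as jt
    from jittor.attention import MultiheadAttention

    # illustrative shapes: seq_len=4, batch=8, embed_dim=16
    q = jt.random((4, 8, 16))
    attn = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
    y, w = attn(q)   # y: (4, 8, 16); w: head-averaged weights, (8, 4, 4)

With self_attention=True, key and value default to the query, matching the self-attention branch above.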

python/jittor/init.py

@@ -55,3 +55,30 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):
 def relu_invariant_gauss_(var, mode="fan_in"):
     var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
+
+#TODO: bound = gain * math.sqrt(6.0/fan) ??
+def xavier_uniform(shape, dtype, gain=1.0):
+    assert len(shape)>1
+
+    matsize=1
+    for i in shape[2:]:
+        matsize *= i
+    fan = (shape[1] * matsize) + (shape[0] * matsize)
+    bound = gain * math.sqrt(1.0/fan)
+    return uniform(shape, dtype, -bound, bound)
+
+def xavier_uniform_(var, gain=1.0):
+    var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))
+
+def xavier_gauss(shape, dtype, gain=1.0):
+    assert len(shape)>1
+
+    matsize=1
+    for i in shape[2:]:
+        matsize *= i
+    fan = (shape[1] * matsize) + (shape[0] * matsize)
+    std = gain * math.sqrt(2.0/fan)
+    return gauss(shape, dtype, 0, std)
+
+def xavier_gauss_(var, gain=1.0):
+    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
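As a quick sketch (not part of the diff), the in-place variants re-initialize an existing parameter; the Linear layer below is just an illustration:

    import jittor as jt
    from jittor import init, nn

    linear = nn.Linear(16, 16)
    init.xavier_uniform_(linear.weight, gain=1.0)  # overwrite the weight in place
    init.constant_(linear.bias, 0.0)               # zero the bias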

python/jittor/nn.py

@@ -13,6 +13,7 @@
 import jittor as jt
 from jittor import init, Module
 import numpy as np
+import collections
 import math
 from jittor.pool import Pool, pool, AdaptiveAvgPool2d
 from jittor.optim import *
@@ -780,13 +781,17 @@ class Upsample(Module):
 class Sequential(Module):
     def __init__(self, *args):
-        self.layers = []
+        self.layers = collections.OrderedDict()
         for mod in args:
-            self.append(mod)
+            if isinstance(mod, collections.OrderedDict):
+                for k, m in mod.items():
+                    self.add_module(k, m)
+            else:
+                self.append(mod)
     def __getitem__(self, idx):
         return self.layers[idx]
     def execute(self, x):
-        for layer in self.layers:
+        for k, layer in self.layers.items():
             x = layer(x)
         return x
     def dfs(self, parents, k, callback, callback_leave):
@@ -794,7 +799,7 @@ class Sequential(Module):
         ret = callback(parents, k, self, n_children)
         if ret == False:
             return
-        for k,v in enumerate(self.layers):
+        for k,v in self.layers.items():
             parents.append(self)
             v.dfs(parents, k, callback, callback_leave)
             parents.pop()
@@ -803,6 +808,10 @@ class Sequential(Module):
     def append(self, mod):
         assert callable(mod), f"Module <{type(mod)}> is not callable"
         assert not isinstance(mod, type), f"Module is not a type"
-        self.layers.append(mod)
+        self.layers[len(self.layers)]=mod
+    def add_module(self, name, mod):
+        assert callable(mod), f"Module <{type(mod)}> is not callable"
+        assert not isinstance(mod, type), f"Module is not a type"
+        self.layers[name]=mod
ModuleList = Sequential
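A hedged usage sketch (not part of the diff) of the two construction styles the new __init__ accepts; layer names and shapes are made up for illustration:

    import collections
    import jittor as jt
    from jittor import nn

    # positional style: layers get integer keys 0, 1, ...
    m1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())

    # OrderedDict style: layers keep the given names as keys
    m2 = nn.Sequential(collections.OrderedDict([
        ('fc', nn.Linear(16, 32)),
        ('act', nn.ReLU()),
    ]))

    x = jt.random((4, 16))
    y = m2(x)        # execute runs layers in insertion order
    fc = m2['fc']    # named lookup via __getitem__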

python/jittor/test/test_attention.py (new file)

@@ -0,0 +1,65 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import jittor.attention as jtatt
import numpy as np

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
    import fairseq
except:
    torch = None
    tnn = None
    skip_this_test = True

def check_equal(q,k,v,tatt,jatt):
    tq=torch.from_numpy(q)
    jq=jt.array(q)
    tk=torch.from_numpy(k)
    jk=jt.array(k)
    tv=torch.from_numpy(v)
    jv=jt.array(v)
    jatt.load_parameters(tatt.state_dict())
    ty, tw = tatt(tq, tk, tv)
    jy, jw = jatt(jq, jk, jv)
    assert np.allclose(ty.detach().numpy(), jy.numpy(), rtol=1e-3)
    assert np.allclose(tw.detach().numpy(), jw.numpy(), rtol=1e-3)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestAttention(unittest.TestCase):
    def test_attention(self):
        q=np.random.rand(4,8,16).astype(np.float32)
        k=np.random.rand(4,8,16).astype(np.float32)
        v=np.random.rand(4,8,16).astype(np.float32)

        tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,1)
        jatt=jt.attention.MultiheadAttention(16,1)
        check_equal(q,k,v,tatt,jatt)

        tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,4)
        jatt=jt.attention.MultiheadAttention(16,4)
        check_equal(q,k,v,tatt,jatt)

        tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,1,self_attention=True)
        jatt=jt.attention.MultiheadAttention(16,1,self_attention=True)
        check_equal(q,q,q,tatt,jatt)

        tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,4,self_attention=True)
        jatt=jt.attention.MultiheadAttention(16,4,self_attention=True)
        check_equal(q,q,q,tatt,jatt)

if __name__ == "__main__":
    unittest.main()