From 7b41ab678ffd1e44708fcd972f3b43a3e1294ac5 Mon Sep 17 00:00:00 2001 From: Gword <471184555@qq.com> Date: Fri, 28 Aug 2020 20:48:10 +0800 Subject: [PATCH] attention op & Sequential with OrderedDict --- python/jittor/__init__.py | 7 +- python/jittor/attention.py | 175 +++++++++++++++++++++++++++ python/jittor/init.py | 29 ++++- python/jittor/nn.py | 19 ++- python/jittor/test/test_attention.py | 65 ++++++++++ 5 files changed, 285 insertions(+), 10 deletions(-) create mode 100644 python/jittor/attention.py create mode 100644 python/jittor/test/test_attention.py diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 7bc94c3b..4433d49e 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -512,11 +512,10 @@ class Module: end = 0 for k in key_: if isinstance(v, nn.Sequential): - if ori_int(k) >= len(v.layers): - end = 1 + v = v[k] + if v is None: + end=1 break - else: - v = v[ori_int(k)] else: if hasattr(v, k): v = getattr(v, k) diff --git a/python/jittor/attention.py b/python/jittor/attention.py new file mode 100644 index 00000000..f5aaede2 --- /dev/null +++ b/python/jittor/attention.py @@ -0,0 +1,175 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +from jittor import init, Module, nn +import numpy as np +import math + +class MultiheadAttention(Module): + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + assert dropout==0, "TODO: dropout>0" + + self.head_dim = embed_dim // num_heads + assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size") + + #TODO: quant_noise + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + assert not add_bias_kv, "TODO: add_bias_kv=True" + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + self.tpu = False + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + init.xavier_uniform_(self.k_proj.weight) + init.xavier_uniform_(self.v_proj.weight) + 
init.xavier_uniform_(self.q_proj.weight) + + # init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + init.xavier_normal_(self.bias_v) + + def execute( + self, + query, + key = None, + value = None, + key_padding_mask = None, + incremental_state = None, + need_weights = True, + static_kv = False, + attn_mask = None, + before_softmax = False, + need_head_weights = False, + ): + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + + assert incremental_state is None, "TODO: incremental_state is not None" + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q = q*self.scaling + + assert self.bias_k is None, "TODO: self.bias_k is not None:" + + q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) + if k is not None: + k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) + if v is not None: + v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) + + assert saved_state is None, "TODO: saved_state is not None" + assert k is not None + src_len = k.shape[1] + + assert key_padding_mask is None, "TODO: key_padding_mask is not None" + assert not self.add_zero_attn, "TODO: self.add_zero_attn=True" + + attn_weights = nn.bmm(q, k.transpose(0, 2, 1)) + + assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + assert attn_mask is None, "TODO: attn_mask is not None" + assert key_padding_mask is None, "TODO: key_padding_mask is not None" + + if before_softmax: + return attn_weights, v + + attn_weights_float = nn.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + + assert v is not None + attn = nn.bmm(attn_weights, v) + assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.shape[1] == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dims=[0]) + + return attn, attn_weights diff --git a/python/jittor/init.py b/python/jittor/init.py index b8256d64..b548311c 100644 --- a/python/jittor/init.py +++ b/python/jittor/init.py @@ -54,4 +54,31 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"): return gauss(shape, dtype, 0, std) def relu_invariant_gauss_(var, mode="fan_in"): - var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode)) \ No newline at end of file + var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode)) + +#TODO: bound = gain * math.sqrt(6.0/fan) ?? 
+def xavier_uniform(shape, dtype, gain=1.0): + assert len(shape)>1 + + matsize=1 + for i in shape[2:]: + matsize *= i + fan = (shape[1] * matsize) + (shape[0] * matsize) + bound = gain * math.sqrt(1.0/fan) + return uniform(shape, dtype, -bound, bound) + +def xavier_uniform_(var, gain=1.0): + var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain)) + +def xavier_gauss(shape, dtype, gain=1.0): + assert len(shape)>1 + + matsize=1 + for i in shape[2:]: + matsize *= i + fan = (shape[1] * matsize) + (shape[0] * matsize) + std = gain * math.sqrt(2.0/fan) + return gauss(shape, dtype, 0, std) + +def xavier_gauss_(var, gain=1.0): + var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain)) \ No newline at end of file diff --git a/python/jittor/nn.py b/python/jittor/nn.py index f12fdae7..efad9b4c 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -13,6 +13,7 @@ import jittor as jt from jittor import init, Module import numpy as np +import collections import math from jittor.pool import Pool, pool, AdaptiveAvgPool2d from jittor.optim import * @@ -780,13 +781,17 @@ class Upsample(Module): class Sequential(Module): def __init__(self, *args): - self.layers = [] + self.layers = collections.OrderedDict() for mod in args: - self.append(mod) + if isinstance(mod, collections.OrderedDict): + for k, m in mod.items(): + self.add_module(k, m) + else: + self.append(mod) def __getitem__(self, idx): return self.layers[idx] def execute(self, x): - for layer in self.layers: + for k, layer in self.layers.items(): x = layer(x) return x def dfs(self, parents, k, callback, callback_leave): @@ -794,7 +799,7 @@ class Sequential(Module): ret = callback(parents, k, self, n_children) if ret == False: return - for k,v in enumerate(self.layers): + for k,v in self.layers.items(): parents.append(self) v.dfs(parents, k, callback, callback_leave) parents.pop() @@ -803,6 +808,10 @@ class Sequential(Module): def append(self, mod): assert callable(mod), f"Module <{type(mod)}> is not callable" assert not isinstance(mod, type), f"Module is not a type" - self.layers.append(mod) + self.layers[len(self.layers)]=mod + def add_module(self, name, mod): + assert callable(mod), f"Module <{type(mod)}> is not callable" + assert not isinstance(mod, type), f"Module is not a type" + self.layers[name]=mod ModuleList = Sequential diff --git a/python/jittor/test/test_attention.py b/python/jittor/test/test_attention.py new file mode 100644 index 00000000..dba4b372 --- /dev/null +++ b/python/jittor/test/test_attention.py @@ -0,0 +1,65 @@ + +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. 
+# *************************************************************** +import unittest +import jittor as jt +import jittor.attention as jtatt +import numpy as np + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn + import fairseq +except: + torch = None + tnn = None + skip_this_test = True + +def check_equal(q,k,v,tatt,jatt): + tq=torch.from_numpy(q) + jq=jt.array(q) + tk=torch.from_numpy(k) + jk=jt.array(k) + tv=torch.from_numpy(v) + jv=jt.array(v) + + jatt.load_parameters(tatt.state_dict()) + ty, tw = tatt(tq, tk, tv) + jy, jw = jatt(jq, jk, jv) + assert np.allclose(ty.detach().numpy(), jy.numpy(), rtol=1e-3) + assert np.allclose(tw.detach().numpy(), jw.numpy(), rtol=1e-3) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestAttention(unittest.TestCase): + def test_attention(self): + q=np.random.rand(4,8,16).astype(np.float32) + k=np.random.rand(4,8,16).astype(np.float32) + v=np.random.rand(4,8,16).astype(np.float32) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,1) + jatt=jt.attention.MultiheadAttention(16,1) + check_equal(q,k,v,tatt,jatt) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,4) + jatt=jt.attention.MultiheadAttention(16,4) + check_equal(q,k,v,tatt,jatt) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,1,self_attention=True) + jatt=jt.attention.MultiheadAttention(16,1,self_attention=True) + check_equal(q,q,q,tatt,jatt) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,4,self_attention=True) + jatt=jt.attention.MultiheadAttention(16,4,self_attention=True) + check_equal(q,q,q,tatt,jatt) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file
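
For reference, a minimal usage sketch of the two features this patch introduces: the jt.attention.MultiheadAttention module and nn.Sequential constructed from an OrderedDict. This is an illustrative example, not part of the patch; the layer names, sizes, and shapes below are made up, and the attention input layout (seq_len, batch, embed_dim) follows the test file above.

    import collections
    import numpy as np
    import jittor as jt
    from jittor import nn
    from jittor.attention import MultiheadAttention

    # Sequential built from an OrderedDict: layers keep insertion order
    # and become addressable by name via __getitem__.
    model = nn.Sequential(collections.OrderedDict([
        ("fc1", nn.Linear(16, 32)),
        ("relu", nn.ReLU()),
        ("fc2", nn.Linear(32, 16)),
    ]))
    x = jt.array(np.random.rand(8, 16).astype(np.float32))
    y = model(x)          # runs fc1 -> relu -> fc2 in insertion order
    fc1 = model["fc1"]    # name-based lookup through the OrderedDict

    # Multi-head attention: query/key/value are (seq_len, batch, embed_dim).
    attn = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
    q = jt.array(np.random.rand(4, 8, 16).astype(np.float32))
    out, weights = attn(q, q, q)   # out: (4, 8, 16); weights averaged over heads

With a plain positional constructor (or append), Sequential keys layers by integer index instead, so model[0] retrieves the first layer; the OrderedDict form is what enables the name-based lookup used in Module.load_parameters.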