Paper: https://arxiv.org/pdf/1911.11423.pdf
Code: https://github.com/Smerity/sha-rnn
A SHA-RNN model is built from a few parts: a trainable embedding layer, one or more stacked single-headed attention RNN (SHA-RNN) layers, and a softmax classifier.
Each SHA-RNN layer has the structure shown in the figure below:
In contrast, the attention in the SHA-RNN model is simplified: only a single head is kept, and the only matrix multiplication is applied to the query (Q in the figure). A is scaled dot-product attention, which reduces to operations between vectors.
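To make that idea concrete, here is a minimal sketch of single-headed attention in which only the query passes through a learned matrix, while keys and values are merely gated elementwise before the scaled dot-product. This is not the repository's actual Attention class; the class name, gating details, and shapes are illustrative assumptions.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadAttentionSketch(nn.Module):
    """Illustrative only: one head, and the sole matrix multiply is on the query."""
    def __init__(self, nhid):
        super().__init__()
        self.q = nn.Linear(nhid, nhid)                    # the only matrix multiplication
        self.ks = nn.Parameter(torch.zeros(1, 1, nhid))   # elementwise key gate
        self.vs = nn.Parameter(torch.zeros(1, 1, nhid))   # elementwise value gate

    def forward(self, query, key, value, attn_mask=None):
        # query/key/value: [seq length, batch, nhid]
        q = self.q(query)
        k = torch.sigmoid(self.ks) * key
        v = torch.sigmoid(self.vs) * value
        # Scaled dot-product attention (A in the figure): vector operations only
        scores = torch.einsum('qbh,kbh->bqk', q, k) / math.sqrt(q.shape[-1])
        if attn_mask is not None:
            scores = scores + attn_mask                   # additive causal mask
        weights = F.softmax(scores, dim=-1)
        return torch.einsum('bqk,kbh->qbh', weights, v)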
The SHARNN module from the repository's model.py:
import numpy as np
import torch
import torch.nn as nn

class SHARNN(nn.Module):
    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, dropouth=0.5,
                 dropouti=0.5, dropoute=0.1, wdrop=0, tie_weights=False):
        super().__init__()
        embed_dim = ninp
        hidden_dim = nhid
        self.ninp, self.nhid = ninp, nhid
        self.nlayers = nlayers
        num_embeddings = ntoken
        self.num_max_positions = 5000   # maximum context length (memory + current sequence)
        self.num_heads = 1              # single-headed attention
        num_layers = nlayers
        self.causal = True
        self.drop = nn.Dropout(dropout)
        self.idrop = nn.Dropout(dropouti)
        self.hdrop = nn.Dropout(dropouth)

        # Stack of Block modules (defined elsewhere in model.py);
        # only the second-to-last block uses attention.
        self.blocks = nn.ModuleList()
        for idx in range(num_layers):
            rnn = True
            self.blocks.append(Block(embed_dim, hidden_dim, self.num_heads, dropout=dropouth,
                                     rnn=rnn, residual=False,
                                     use_attn=True if idx == num_layers - 2 else False))

        # Positional embeddings are effectively disabled (just a list of zeros)
        self.pos_emb = [0] * self.num_max_positions

        self.encoder = nn.Embedding(num_embeddings, embed_dim)
        self.decoder = nn.Linear(embed_dim, num_embeddings)
        if tie_weights:
            # Tie the input embedding and output projection weights
            self.decoder.weight = self.encoder.weight

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=0.1 / np.sqrt(self.ninp))
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, hidden=None, mems=None, padding_mask=None, return_h=True):
        """ Input has shape [seq length, batch] """
        e = self.encoder(x)
        e = self.idrop(e)

        # Trim memories so that memory + current sequence fits in num_max_positions
        if mems is not None:
            maxmem = self.num_max_positions - len(e)
            mems = [m[-maxmem:] for m in mems]
        total_length = len(x) + (len(mems[0]) if mems else 0)

        pe = self.pos_emb

        h = e
        new_hidden = []
        new_mems = []

        # Causal mask: -inf above the diagonal; memory positions are fully visible
        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)
            if mems:
                max_mems = max(len(m) for m in mems)
                happy = torch.zeros((len(x), max_mems), device=h.device, dtype=h.dtype)
                attn_mask = torch.cat([happy, attn_mask], dim=-1)

        # Run each block, collecting its new memory and hidden state
        for idx, block in enumerate(self.blocks):
            mem = mems[idx] if mems else None
            hid = hidden[idx] if hidden else None
            h, m, nh, f = block(h, pe, attn_mask=attn_mask, mem=mem, hidden=hid)
            new_hidden.append(nh)
            new_mems.append(m)

        h = self.drop(h)

        if return_h:
            return h, new_hidden, new_mems, None, None
        return h, new_hidden, new_mems
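As a rough usage sketch (not taken from the repository): it assumes the repo's model.py, which also defines the Block class used above, is importable, and the hyperparameters are only illustrative.

import torch
from model import SHARNN   # assumes the sha-rnn repo's model.py is on the path

model = SHARNN('lstm', ntoken=256, ninp=1024, nhid=4096, nlayers=4)
x = torch.randint(0, 256, (128, 8))        # [seq length, batch] of token ids
h, new_hidden, new_mems, _, _ = model(x)   # return_h defaults to True
print(h.shape)                             # expected: torch.Size([128, 8, 1024])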
Next, the feed-forward ("Boom") layer:
It takes a vector v ∈ ℝ^H and, via a matrix multiplication followed by a GeLU activation, produces another vector u ∈ ℝ^(N×H). The vector u is then split into N vectors of size H, which are summed to give w ∈ ℝ^H. Compared with a conventional down-projection layer, this reduces the amount of computation and removes an entire matrix of parameters.
The Boom implementation from the repository:
class Boom(nn.Module):
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1, shortcut=False):
        super(Boom, self).__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout) if dropout else None
        if not shortcut:
            self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.shortcut = shortcut
        # GELU is defined elsewhere in the repository; nn.GELU is an equivalent drop-in
        self.act = GELU()

    def forward(self, input):
        # Up-projection with GeLU activation: d_model -> dim_feedforward
        x = self.act(self.linear1(input))
        if self.dropout: x = self.dropout(x)
        if self.shortcut:
            # Trim the end off if the size is different
            ninp = input.shape[-1]
            x = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks
            x = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
            # Collapse the chunks through summation (replaces the down-projection)
            z = x.sum(dim=-2)
        else:
            z = self.linear2(x)
        return z
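A small, hedged usage sketch (assuming GELU resolves to something equivalent to nn.GELU): with shortcut=True, the dim_feedforward-sized activation is viewed as N = dim_feedforward / d_model chunks of size d_model and summed, so the (dim_feedforward × d_model) down-projection matrix is never created.

import torch

boom = Boom(d_model=1024, dim_feedforward=4096, dropout=0.1, shortcut=True)
x = torch.randn(128, 8, 1024)   # [seq length, batch, d_model]
z = boom(x)
print(z.shape)                  # torch.Size([128, 8, 1024]), with no 4096x1024 down-projection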