淘先锋技术网

首页 1 2 3 4 5 6 7

论文传送门: https://arxiv.org/pdf/1911.11423.pdf

代码传送门: https://github.com/Smerity/sha-rnn

SHA-RNN 是由几个部分组成的:一个可训练的嵌入层,一层或者多层堆叠的单头注意力RNN (SHA-RNN) ,再加一个softmax分类器。

其中,SHA-RNN的结构就是下图这样:

                                                  

与 Transformer 的多头注意力相比,SHA-RNN 模型的注意力是简化的:只保留一个头,唯一的矩阵乘法出现在 query (下图 Q) 那里;A 是缩放点乘注意力 (Scaled Dot-Product Attention),其余都是向量之间的运算。

                                                 

代码见下:

class SHARNN(nn.Module):
    """Single Headed Attention RNN (SHA-RNN) model body.

    Embeds token ids, applies input dropout, runs the result through a stack
    of ``Block`` layers and applies output dropout. A ``decoder`` projection
    is created (and optionally weight-tied to the embedding) but it is not
    applied inside ``forward``.
    """

    def __init__(
        self,
        rnn_type,
        ntoken,
        ninp,
        nhid,
        nlayers,
        dropout=0.5,
        dropouth=0.5,
        dropouti=0.5,
        dropoute=0.1,
        wdrop=0,
        tie_weights=False,
    ):
        """Build the embedding, the block stack and the decoder.

        Args:
            rnn_type: Accepted for interface compatibility; not used here.
            ntoken: Vocabulary size (embedding rows / decoder outputs).
            ninp: Embedding dimension.
            nhid: Hidden dimension handed to every ``Block``.
            nlayers: Number of ``Block`` layers to stack.
            dropout: Dropout probability applied to the final output.
            dropouth: Dropout probability handed to every ``Block``.
            dropouti: Dropout probability applied to the embedded input.
            dropoute: Accepted for interface compatibility; not used here.
            wdrop: Accepted for interface compatibility; not used here.
            tie_weights: If true, the decoder shares the embedding's weight.
        """
        super().__init__()
        self.ninp, self.nhid = ninp, nhid
        self.nlayers = nlayers
        # Budget for (memory + current segment) length; ``forward`` trims
        # the memories so that they fit inside it.
        self.num_max_positions = 5000
        self.num_heads = 1
        self.causal = True
        self.drop = nn.Dropout(dropout)
        self.idrop = nn.Dropout(dropouti)
        self.hdrop = nn.Dropout(dropouth)

        self.blocks = nn.ModuleList()
        for idx in range(nlayers):
            # Attention is switched on for the second-to-last layer only, so
            # a single-layer model ends up with no attention at all.
            self.blocks.append(
                Block(
                    ninp,
                    nhid,
                    self.num_heads,
                    dropout=dropouth,
                    rnn=True,
                    residual=False,
                    use_attn=idx == nlayers - 2,
                )
            )

        # Placeholder positional "embedding": a plain list of zeros that is
        # passed to every block as ``pe`` (no learned positions).
        self.pos_emb = [0] * self.num_max_positions

        self.encoder = nn.Embedding(ntoken, ninp)
        self.decoder = nn.Linear(ninp, ntoken)
        if tie_weights:
            self.decoder.weight = self.encoder.weight

        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialise ``module`` in place; used as the ``self.apply`` callback.

        Weights of Linear / Embedding / LayerNorm modules are drawn from a
        normal distribution with std ``0.1 / sqrt(ninp)``; biases of Linear /
        LayerNorm modules are zeroed when present.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=0.1 / np.sqrt(self.ninp))

        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, hidden=None, mems=None, padding_mask=None, return_h=True):
        """Run the model on a batch of token ids.

        Args:
            x: Token ids with shape [seq length, batch].
            hidden: Optional per-layer hidden states, indexed by layer.
            mems: Optional per-layer memories, indexed by layer.
            padding_mask: Accepted for interface compatibility; not used.
            return_h: Selects the length of the returned tuple.

        Returns:
            ``(h, new_hidden, new_mems, None, None)`` when ``return_h`` is
            true, otherwise ``(h, new_hidden, new_mems)``. ``h`` is the output
            of the last block after dropout; the two lists hold one entry per
            block, as returned by that block.
        """
        h = self.idrop(self.encoder(x))

        if mems is not None:
            # Room left for memory once the current segment is accounted for.
            maxmem = self.num_max_positions - len(h)
            # ``m[-0:]`` is the whole sequence, so the "no room left" case
            # must be handled explicitly instead of slicing with ``-maxmem``.
            mems = [m[-maxmem:] if maxmem > 0 else m[:0] for m in mems]

        pe = self.pos_emb

        attn_mask = None
        if self.causal:
            seq_len = len(x)
            # -inf strictly above the diagonal: a position cannot attend to
            # later positions of the current segment.
            attn_mask = torch.full((seq_len, seq_len), float('-inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)
            if mems:
                # Memory columns are fully visible (zeros) and are prepended.
                max_mems = max(len(m) for m in mems)
                visible = torch.zeros((seq_len, max_mems), device=h.device, dtype=h.dtype)
                attn_mask = torch.cat([visible, attn_mask], dim=-1)

        new_hidden = []
        new_mems = []
        for idx, block in enumerate(self.blocks):
            mem = mems[idx] if mems else None
            hid = hidden[idx] if hidden else None
            # The fourth value returned by the block is not used here.
            h, m, nh, _ = block(h, pe, attn_mask=attn_mask, mem=mem, hidden=hid)
            new_hidden.append(nh)
            new_mems.append(m)

        h = self.drop(h)

        if return_h:
            return h, new_hidden, new_mems, None, None
        return h, new_hidden, new_mems

接下来讲前馈层 (“Boom” Layer) :

先取一个向量 $v \in \mathbb{R}^{H}$,通过一次矩阵乘法 (配合 GeLU 激活) 得到另一个向量 $u \in \mathbb{R}^{N \times H}$。然后,把 $u$ 拆分成 $N$ 个长度为 $H$ 的向量,再逐元素求和,得到向量 $w \in \mathbb{R}^{H}$。与传统的下映射层 (Down-Projection Layers) 相比,这样做减少了运算量,还省掉了一整个矩阵的参数。

代码见下:

class Boom(nn.Module):
    """Feed-forward "Boom" layer: expand, activate, then come back down.

    The input is projected to ``dim_feedforward`` features and passed through
    GELU (and optional dropout). The way back to ``d_model`` features is
    either a second linear layer, or, with ``shortcut=True``, a parameter-free
    sum over equally sized chunks of the expanded vector.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1, shortcut=False):
        """Create the sub-layers.

        Args:
            d_model: Size of the last input dimension.
            dim_feedforward: Size of the expanded representation.
            dropout: Dropout probability; a falsy value disables dropout.
            shortcut: If true, no second linear layer is created and the
                output is produced by chunk summation instead.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        if not shortcut:
            self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.shortcut = shortcut
        self.act = GELU()

    def forward(self, input):
        """Apply the layer to ``input`` and return the result."""
        expanded = self.linear1(input)
        expanded = self.act(expanded)
        if self.dropout is not None:
            expanded = self.dropout(expanded)

        if not self.shortcut:
            return self.linear2(expanded)

        width = input.shape[-1]
        n_chunks = expanded.shape[-1] // width
        # Drop trailing features that do not fill a whole chunk.
        expanded = torch.narrow(expanded, -1, 0, n_chunks * width)
        # Split the last dimension into (n_chunks, width) ...
        expanded = expanded.view(*expanded.shape[:-1], n_chunks, width)
        # ... and collapse the chunks by summing them.
        return expanded.sum(dim=-2)