Coder Social home page Coder Social logo

jingchao19 / bert_mrc_cls Goto Github PK

View Code? Open in Web Editor NEW

This project forked from aivolcano/bert_mrc_cls

0.0 0.0 0.0 326 KB

Build model frameworks base on NEZHA to include top conference papers in the field of nlp, because it will help me find ID features

Python 100.00%

bert_mrc_cls's Introduction

任务目标

  • 囊括NLP领域的最新paper中提到的模型, 作者开发该框架,对于新的paper,只需要看model.py就可以融入该框架,并且可以考虑迁移到推荐系统中使用
  • 推荐系统中最重要的特征是ID类特征,它组成的用户行为句子依靠NLP技术进行特征提取,NLP提取特征的技术路线为推荐系统提供了可参考的技术路线,2021年推荐系统领域内的模型逐步向预训练BERT靠拢,以求能提取更有效的提取ID组成的用户行为句子特征
  • 该框架已经包括语义匹配、文本分类、阅读理解3个框架,未来继续增加文本生产和命名实体识别2部分,这样,NLP的框架内容基本就已经囊括了

核心技术

BERT隐藏层加权融合,召回段落、召回关键句 等等

特征工程

  • 召回技术

    • TextRank长文本召回关键句,
    • BM25算法:跨文档召回top-k个含有正确答案的段落
  • Pooling:BERT 12层隐藏层的向量进行加权

Ganesh Jawahar等人[3]证明BERT每一层对文本的理解都不同,因此将BERT的十二层transformer生成的表示赋予一个权重,(下文称为BERT动态融合),初始化公式为:

α_i = Dense_(unit=1) (represent_i)

ouput = Dense_(unit=512) (∑_(i=1)^nα_i · represent_i)

BERT的动态融合作为模型embedding的成绩会优于BERT最后一层向量作为模型的embedding,且收敛速度更快,epochs=3-4个就可以了

image

权重学习易造成高时空复杂度,我们还可以使用SumPooling、MeanPooling、MaxPooling等方法进行融合,选择层数偏后面几层。

模型内部结构 Pooling

  • 使用残差网络解决BERT12个隐藏层喂给BiLSTM时,模型发生退化问题 我们将BERT的12个隐藏层喂给BiLSTM,即:BERT -> BiLSTM -> Average Pooling的技术路线。实际上,作者在训练的过程中发现,增加LSTM之后,模型发生退化问题。因此我们开发了残差网络,让模型自己选择要不要跳过BiLSTM。

image

作者还提供更为丰富的模型替换LSTM,比如:GRU、Transformer、CNN、RTransformer(4选1) 等结构。由于使用残差网络,残差网络是加法,时空复杂度为O(1),参数量不变。

image

出现退化原因可能是:BERT的12层向量融合完成很好的提取了特征,这种情况复杂的模型反而效果会减弱。这在推荐系统中很常见,特征工程之后用个逻辑回归LR就能解决问题,可能对于LR来说,它只需要发挥自己的记忆能力,把特征工程整理出来的情况都记录在自己的评分卡中。

  • 开发不受维度限制的残差模型 顺着这一逻辑往下走,作者开发了不受维度限制的残差模块,希望解决维度不一致影响残差网络使用率低的问题。笔者认为在推荐系统和自然语言处理的深度学习模型都可以使用,包括特征提取层(Feature extrusion)和 全连接层(MLP)。它可以作为模型内部特征前后复用的一种策略。此外,残差网络时空复杂度变化小,且能保证效果达到最优时不会变差。

LSTM的残差连接

class ResidualWrapper4RNN(nn.Module):
    def __init__(self, model):
        super().__init__() # super(ResidualWrapper, self).__init__()
        self.model = model
    def forward(self, inputs, *args, **kwargs):
        delta = self.model(inputs, *args, **kwargs)[0]  # params.model_type=='bigru' or 'bilstm'
        return inputs + delta

self.bilstm = ResidualWrapper4RNN(nn.Sequential(
            nn.BiLSTM(config.hidden_size, params.lstm_hidden,
                             num_layers=params.num_layers, dropout=params.drop_prob, batch_first=True, bidirectional=True)
                             ))

result = self.bilstm(bert_ouput)

原理是:回到ResNet的核心,非线性激活函数的存在导致特征变化不可逆,因此造成模型退化的根本原因是非线性激活函数。因此F(x)= f(x) + x 可以理解为f(x)为非线性特征,x为线性特征。

image

作者开发的不受维度限制的残差网络数学公式是:

F(x)= f(x) + wx

该残差模块不受维度相等的条件限制,w的作用是维度变换,经过w的变换后,特征依然是线性的。

时空复杂度为0的写法:遇到维度不相等,可以直接用nn.Linear(), tf.keras.layers.Dense()让维度一致。然后再对位相加即可。

有时空复杂度的写法:向量不对位相加,直接拼接torch.cat([vector1, vector2],dim=-1), tf.concat([vector1, vector2], axis=-1) tf.keras.layers.concatation()

image

BERT作为特征提取器可能存在的问题

上述模型中,BERT的参数是不参与模型更新的,虽然我们设置了学习率,但是我们只使用了BERT的输出向量(batch_size, seq_len, embedding_dim),梯度无法传导到BERT中,原因是BERT 是基于Transformer Encoder完成的,在Transformer中 Decoder 模块的梯度传导可以有交叉信息熵算出来,由于 Decoder 中涉及 Encoder-Decoder Self-Attention,该注意力机制中 q 来自encoder ,k 和 v 来自decoder,随着decoder的更新,decoder中的注意力矩阵完成参数更新,该注意力矩阵中有来自encoder的q,q的更新带动encoder中的注意力矩阵更新,最终实现整个transformer的参数更新。

解决方案:为BERT增加辅助损失函数。

BERT的下游结构我们下接了BiLSTM

可改进的点

BERT模型动态融合需要BERT预训练模型已经很完美,因此可以使用我们该任务的语料喂给开源的预训练模型再训练20个epoch。

BERT Fine_tune微调依靠BERT自身已经能很好的提取特征,实际上,使用BERT时,经常遇到loss很小,但是总感觉模型没有学明白的问题,因此,作者考虑为BERT增加有监督学习的loss: [代码跑通,没有gpu还没跑出结果]BERT是无监督学习方式,该任务是文本分类,因此我们可以为BERT单独增加损失函数,相当于期中考试。具体做法是BERT最后一层的输出(也就是原始BERT)使用LR计算得到pre_label,与true_label计算得到损失

  • 修改loss:为bert增加辅助损失函数
# BERT增加损失函数:原始BERT输出直接做分类后计算一次损失(也由于BERT的重要性高于fine_tune部分,其loss权重可以高于fine_tune部分的权重)
ori_pooled_output = self.bert_cls(ori_pooled_output) #(none, 768) -> (none, 10)
bert_cls = F.softmax(ori_pooled_output, dim=-1)
bert_loss = nn.CrossEntropyLoss()(bert_cls, cls_label)

# 或者参考不均衡样本处理方法进行下采样,下采样是通用的思路 使用FocalLoss
class_loss = nn.CrossEntropyLoss()(classifier_logits, cls_label)# weight中设置不均衡的标签
class_loss = 0.8 * bert_loss + 0.2 * class_loss
outputs = class_loss
  • 为Attention矩阵增加残差,在MultiAttention那里增加残差网络

image

增加残差网络后:

image

这也就是Google新论文RealFormer的核心,论文地址:https://arxiv.org/abs/2012.11747v1

image

运行方案顺序

  • preprocess/preprocess.py
  • utils.py
  • train_fine_tune.py
  • predict.py

bert_mrc_cls's People

Contributors

aivolcano avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.