Comments (4)

airaria commented on May 27, 2024

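This isn't supported out of the box; you need to modify the BERT model.
You can follow these steps:

  1. Modify BERT's interface so that it returns a value_layer list whose elements are the value_layer of each layer.
  2. Give this feature any name you like, for example "value_layer", and add it to the features:
     textbrewer.presets.FEATURES.append("value_layer")
  3. Following the loss functions in textbrewer.losses (for example hid_mse_loss), implement your own intermediate-layer loss function. It takes the teacher's and student's value_layer, the distillation temperature, and the mask as parameters (you can skip the mask logic if that is too much trouble). Then register it in textbrewer.presets.MATCH_LOSS_MAP:
     textbrewer.presets.MATCH_LOSS_MAP['custom_loss'] = your_custom_loss_function
  4. After that, you can use your custom loss function like any other loss function. For example, the intermediate-layer match configuration can be written as:
     intermediate_matches = [{"layer_T":1, "layer_S":1, "feature":"value_layer", "loss":"custom_loss", "weight":1}, ... ]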
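
For more than one layer pair, the match list from step 4 can be generated rather than written by hand. This is only a sketch: the helper name build_value_layer_matches and the 12-layer-to-4-layer mapping are made up for illustration, and it assumes layer_T and layer_S are used as indices into the value_layer lists returned by the adaptor.

```python
def build_value_layer_matches(layer_pairs, loss="custom_loss", weight=1):
    """Return one intermediate-match entry per (teacher_layer, student_layer) pair."""
    return [
        {"layer_T": layer_T, "layer_S": layer_S,
         "feature": "value_layer", "loss": loss, "weight": weight}
        for layer_T, layer_S in layer_pairs
    ]


# Hypothetical mapping: every third layer of a 12-layer teacher is matched
# to one layer of a 4-layer student (indices into the value_layer lists).
intermediate_matches = build_value_layer_matches([(2, 0), (5, 1), (8, 2), (11, 3)])
```

The resulting list would then be passed as the intermediate_matches argument of the distillation config, in the same way as a hand-written one.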

Finally, don't forget that when you use the new loss, the dictionary returned by the adaptor must provide the value_layer that the loss needs:

def adaptor(batch, model_outputs):
    return {'logits': ...,
            'value_layer': ...,  # the value_layer list returned by BERT
            ... }
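
A rough sketch of what the custom loss from step 3 might look like, written in PyTorch. Everything here is an assumption rather than TextBrewer's own code: the name value_layer_mse_loss, the (student, teacher, mask) argument order, and the tensor shapes noted in the docstring. The temperature argument is left out because plain MSE does not use it; check how the losses in textbrewer.losses are called in your installed version before relying on this signature.

```python
import torch
import torch.nn.functional as F


def value_layer_mse_loss(state_S, state_T, mask=None):
    """MSE between student and teacher value_layer tensors.

    Assumed shapes (not taken from TextBrewer): state_S and state_T are
    (batch, num_heads, seq_len, head_dim); mask is (batch, seq_len) with
    1 for real tokens and 0 for padding.
    """
    if mask is None:
        return F.mse_loss(state_S, state_T)
    # Broadcast the mask over the head and feature dimensions.
    mask = mask.to(state_S).unsqueeze(1).unsqueeze(-1)  # (batch, 1, seq_len, 1)
    squared_error = F.mse_loss(state_S, state_T, reduction="none") * mask
    # Count unmasked elements: valid tokens * num_heads * head_dim.
    valid_count = mask.sum() * state_S.size(1) * state_S.size(-1)
    return squared_error.sum() / valid_count.clamp(min=1)


# Registration as described in steps 2 and 3 (requires textbrewer):
# import textbrewer
# textbrewer.presets.FEATURES.append("value_layer")
# textbrewer.presets.MATCH_LOSS_MAP["custom_loss"] = value_layer_mse_loss
```

The sketch also assumes the teacher and student value_layer tensors have identical shapes; if the head counts or head sizes differ, a projection would be needed first.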

Hope it helps.

from textbrewer.

airaria commented on May 27, 2024

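One more note:
If the form of the loss itself doesn't change (it is still MSE, cross-entropy, or similar) and only the input changes, you can skip step 3 and directly use an existing loss function such as hid_mse_loss or att_mse_loss to compute the value_layer loss. Just change the feature to "value_layer".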
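
As a sketch of that shortcut, the match entry only needs a built-in loss name in place of a custom one. The key "hidden_mse" is an assumption about how hid_mse_loss is registered; check the keys of textbrewer.presets.MATCH_LOSS_MAP in your installed version.

```python
# Reuse a built-in loss instead of registering a custom one.
# Assumption: "hidden_mse" is the MATCH_LOSS_MAP key for hid_mse_loss.
intermediate_matches = [
    {"layer_T": 1, "layer_S": 1, "feature": "value_layer",
     "loss": "hidden_mse", "weight": 1},
]
```

The built-in losses were written for hidden states and attention maps, so value_layer may need to be reshaped to whatever layout the chosen loss expects, especially when a mask is passed.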

YYangZiXin commented on May 27, 2024

Thank you for the patient explanation!

catqaq commented on May 27, 2024

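Great tool!
@airaria You explained above how to define a custom feature and add it to the intermediate-layer matches. We ran into a small problem while using it: if the new feature is not per-layer, like logits, how should we add the mapping? Should we use CustomMatch? We haven't found a concise example so far. Our current approach is to pass the new feature as logits (when the size requirements are met), which reuses your preset logits field but carries our own feature. Is there a better way to handle this?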
