Coder Social home page Coder Social logo

pytorch_bert_japanese's Introduction

PytorchでBERTの日本語学習済みモデルを利用する

これはPytorchで日本語の学習済みBERTモデルを読み込み、文章ベクトル(Sentence Embedding)を計算するためのコードです。

詳細は下記ブログを参考ください。

PytorchでBERTの日本語学習済みモデルを利用する - 文章埋め込み編

環境

準備

日本語の学習済みBERTモデル

京都大学の黒橋・河原研究室が公開している「BERT日本語Pretrainedモデル」を利用します。下記ウェブページからモデルファイルをダウンロードして解凍してください。

BERT日本語Pretrainedモデル - KUROHASHI-KAWAHARA LAB

Juman++

Juman++をインストールします。インストール方法については、下記の公式レポジトリを参照ください。

ku-nlp/jumanpp: Juman++ (a Morphological Analyzer Toolkit)

なお、macOSならばHomebrewを使って下記のように簡単にインストールできます。

$ brew install jumanpp

Pythonパッケージ

pytorch-pretrained-bertおよびpyknpをインストールします。

$ pip install pytorch-pretrained-bert
$ pip install pyknp

なお、ここではPytorchをBERT実装に利用するので、Pytorchはインストールされているものとします。

PyTorch

実行する

本レポジトリのbert_juman.pyからBertWithJumanModelクラスをインポートします。クラスの引数には、ダウンロードした日本語の学習済みBERTモデルのディレクトリを指定します。必要なファイルはpytorch_model.binvocab.txtのみです。

In []: from bert_juman import BertWithJumanModel

In []: bert = BertWithJumanModel("/path/to/Japanese_L-12_H-768_A-12_E-30_BPE")

In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([ 2.22642735e-01, -2.40221739e-01,  1.09303640e-02, -1.02307117e+00,
        1.78834641e+00, -2.73566216e-01, -1.57942638e-01, -7.98571169e-01,
       -2.77438164e-02, -8.05811465e-01,  3.46736580e-01, -7.20409870e-01,
        1.03382647e-01, -5.33944130e-01, -3.25344890e-01, -1.02880754e-01,
        2.26500735e-01, -8.97880018e-01,  2.52314955e-01, -7.09809303e-01,
[...]        

またget_sentence_embedding()の引数には、文章ベクトルを計算するのに利用するBERTの隠れ層の位置pooling_layerと、プーリングの方法pooling_strategyが指定できます。pooling_layer-1で最終層、-2で最終層の手前の層となります。また、pooling_strategyには

  • REDUCE_MEAN: 要素ごとにaverage-pooling
  • REDUCE_MAX: 要素ごとにmax-pooling
  • REDUCE_MEAN_MAX: REDUCE_MEANREDUCE_MAXを結合したもの
  • CLS_TOKEN: [CLS]トークンのベクトルをそのまま利用

が選択できます。

In []: bert.get_sentence_embedding("吾輩は猫である。",
   ...:                             pooling_layer=-1,
   ...:                             pooling_strategy="REDUCE_MAX")
   ...:
Out[]:
array([ 1.2089624 ,  0.6267309 ,  0.7243419 , -0.12712255,  1.8050476 ,
        0.43929055,  0.605848  ,  0.5058241 ,  0.8335829 , -0.26000524,
[...]        

これらのパラメータはhanxiao/bert-as-serviceを参考にしています。

GPU Option

In []: bert = BertWithJumanModel("../Japanese_L-12_H-768_A-12_E-30_BPE", use_cuda=True)

In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([-4.25627649e-01, -3.42006773e-01, -7.15175271e-02, -1.09820020e+00,
        1.08186746e+00, -2.35576674e-01, -1.89862609e-01, -5.50959229e-01,

pytorch_bert_japanese's People

Contributors

yagays avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

pytorch_bert_japanese's Issues

実行時にsegmentation fault

日本語で失礼します。本レポジトリとあまり関係ない質問かもしれませんが、質問させてください。

当方環境
CentOS7
pyenvでのanaconda3-5.2.0
pytorch CPUバージョン(No CUDA)
256GBのRAM搭載

pip install torch==1.2.0+cpu torchvision==0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

この環境で動かなくなり調査しました結果、

 all_encoder_layers, _ = model(tokens_tensor)

でSegmentation fault発生して死んでしまうことがわかったのですが(つまりpytorch_pretrained_bertのBertModelでの問題っぽい)
いくつか質問させてください。

  1. コードを拝見させていただいたところBertModelのfrom_pretrainedでbinファイルを直接読まず、フォルダを読んでいますが、これはどうしてなのでしょう?

  2. このコードが動いた環境について詳しく情報が欲しいです。

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.