Coder Social home page Coder Social logo

tantailong / textgen Goto Github PK

View Code? Open in Web Editor NEW

This project forked from shibing624/textgen

0.0 1.0 0.0 4.47 MB

textgen, Text Generation models. 文本生成,包括:UDA,Seq2Seq,ERNIE-GEN,BERT,XLNet,GPT-2等模型实现,开箱即用。

License: Apache License 2.0

Python 100.00%

textgen's Introduction

PyPI version Downloads Contributions welcome GitHub contributors License Apache 2.0 python_version GitHub issues Wechat Group

TextGen

🌈 Implementation of Text Generation models.

textgen实现了多种文本生成模型,包括:UDA、Seq2Seq、BERT、XLNet、GPT2、ERNIE-GEN等模型,开箱即用。

Guide

Question

文本生成,文本数据增强怎么做?

Solution

  1. UDA,非核心词替换
  2. EDA,简单数据增强技术:相似词、同义词替换,随机词插入、删除、替换
  3. 回译(BT, Back Translate),中文-英文-中文
  4. 生成模型,Seq2Seq,GPT2等

Feature

UDA(非核心词替换)

基于Google提出的UDA(非核心词替换)算法,将文本中一定比例的不重要词替换为同义词,从而产生新的文本。

BT(回译)

基于百度翻译API,把中文句子翻译为英文,再把英文翻译为新的中文。

Seq2Seq

基于Encoder-Decoder结构,序列到序列生成新的文本。

GPT2

基于Transformer的decode结果的自回归生成模型。

TGLS

无监督文本生成模型:本文提出TGLS——一种基于“先搜索后学习”的无监督文本生成方法, 模型反复迭代,最终能生成较高质量的文本。

Demo

Demo: https://huggingface.co/spaces/shibing624/textgen

run example: examples/gradio_demo.py to see the demo:

python examples/gradio_demo.py

Install

pip3 install torch # conda install pytorch
pip3 install -U textgen

or

git clone https://github.com/shibing624/textgen.git
cd textgen
python3 setup.py install

Usage

1. 文本数据增强

example: examples/text_augmentation_demo.py

import sys

sys.path.append('..')
from textgen.augment import TextAugment

if __name__ == '__main__':
    docs = ['主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容',
            '晚上肚子好难受',
            '你会武功吗,我不会',
            '组装标题质量受限于广告主自提物料的片段质量,且表达丰富度有限',
            '晚上一个人好孤单,想:找附近的人陪陪我.',
            ]
    m = TextAugment(sentence_list=docs)
    a = docs[0]
    print(a)

    b = m.augment(a, aug_ops='random-0.1')
    print('random-0.1:', b)

    b = m.augment(a, aug_ops='insert-0.1')
    print('insert-0.1:', b)

    # tfidf
    b = m.augment(a, aug_ops='tfidf-0.2')
    print('tfidf-0.2:', b)

    b = m.augment(a, aug_ops='mix-0.1', similar_prob=0.1,
                  random_prob=0.4, delete_prob=0.3, insert_prob=0.2)
    print('mix-0.1:', b)

    b = m.augment(a, aug_ops='bt')
    print('bt:', b)

output:

主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容
random-0.1: ('主要的机器学习、深度学习吗计算机视觉、好孤单对话系统相关内容', [('研究', '', 2, 3), ('', '', 12, 13), ('智能', '好孤单', 19, 22)])
insert-0.1: ('主要研究机器机器学习、深度学习、计算机视觉、智能对话对话系统系统相关内容', [('机器', '机器机器', 4, 8), ('对话', '对话对话', 24, 28), ('系统', '系统系统', 28, 32)])
tfidf-0.2: ('主要原因研究机器学习、深度学习、计算机硬件视觉、智能化对话系统相关内容', [('主要', '主要原因', 0, 4), ('计算机', '计算机硬件', 16, 21), ('智能', '智能化', 24, 27)])
mix-0.1: ('主要受限于机器学习、深度学习、计算机视觉、智能对话系统相关内容', [('研究', '受限于', 2, 5)])
bt: ('主要研究机器学习、深度学习、计算机视觉和智能对话系统', [])

2. Seq2Seq 模型

训练并预测ConvSeq2Seq模型:

example: examples/seq2sesq/training_convseq2seq_model_demo.py

import argparse
from loguru import logger
import sys

sys.path.append('../..')
from textgen.seq2seq.conv_seq2seq_model import ConvSeq2SeqModel


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='zh_dialog.tsv', type=str, help='Training data file')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
    parser.add_argument('--output_dir', default='./outputs/convseq2seq_zh/', type=str, help='Model output directory')
    parser.add_argument('--max_seq_length', default=50, type=int, help='Max sequence length')
    parser.add_argument('--num_epochs', default=200, type=int, help='Number of training epochs')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    args = parser.parse_args()
    logger.info(args)

    if args.do_train:
        logger.info('Loading data...')
        model = ConvSeq2SeqModel(epochs=args.num_epochs, batch_size=args.batch_size,
                                 model_dir=args.output_dir, max_length=args.max_seq_length)
        model.train_model(args.train_file)
        print(model.eval_model(args.train_file))

    if args.do_predict:
        model = ConvSeq2SeqModel(epochs=args.num_epochs, batch_size=args.batch_size,
                                 model_dir=args.output_dir, max_length=args.max_seq_length)
        sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
        print("inputs:", sentences)
        print(model.predict(sentences))


if __name__ == '__main__':
    main()

output:

["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
['人工智能是工程和科学的分支,致力于构建思维的机器。', '我的程序运行在python,所以我在任何运脑上工作!', '我不能错热是一个疯狂的人工智能"200年。']

3. GPT2 模型

example: examples/language_generation/gpt2_demo.py

TODO

  • ernie-gen
  • xlnet

Contact

  • Issue(建议):GitHub issues
  • 邮件我:xuming: [email protected]
  • 微信我: 加我微信号:xuming624, 备注:姓名-公司名-NLP 进NLP交流群。

License

授权协议为 The Apache License 2.0,可免费用做商业用途。请在产品说明中附加textgen的链接和授权协议。

Contribute

项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:

  • tests添加相应的单元测试
  • 使用python -m pytest来运行所有单元测试,确保所有单测都是通过的

之后即可提交PR。

Reference

textgen's People

Contributors

shibing624 avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.