pytorch-rnn

A rewrite of torch-rnn using PyTorch. Training is still being worked on, and torch-rnn checkpoints can already be loaded and sampled from. An extensible and efficient HTTP sampling server is also included.

Installation

Install PyTorch by following the official installation guide.

Data preprocessing

At the moment you'll have to use the preprocessing scripts from torch-rnn. Only GridGRU models are supported at this time.

Training

Train the network using train.py.

python3 train.py --input-h5 my_data.h5 --input-json my_data.json --checkpoint-name my_model

This creates two files per epoch, my_checkpoint_N.json and my_checkpoint_N.0: the JSON file contains the architecture description and the .0 file contains the raw model parameters. Splitting them this way makes saving and loading models faster, more flexible, and more efficient.

To train on a GPU, pass --device cuda, but note that this is barely tested at this time. When training on a CPU, make sure to set a suitable number of threads using the OMP_NUM_THREADS environment variable; otherwise PyTorch defaults to using all cores, which seems to cause a huge slowdown. Also, when running on a NUMA system, try binding the process to one node using numactl:

OMP_NUM_THREADS=3 numactl -m 0 -N 0 python3 train.py ...
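If you launch training from a script, the same environment and numactl wrapping can be assembled in Python. The following is only a sketch: the helper name build_train_command and its parameters are made up for illustration and are not part of pytorch-rnn. It only builds the environment and argument list; it does not start training by itself.

```python
import os
import shlex


def build_train_command(threads, numa_node=None, extra_args=()):
    """Return (env, argv) for launching train.py with a fixed thread count.

    Hypothetical helper, not part of pytorch-rnn.
    numa_node=None skips numactl entirely (e.g. on non-NUMA machines).
    """
    env = dict(os.environ)
    env["OMP_NUM_THREADS"] = str(threads)  # caps the OpenMP worker threads

    argv = ["python3", "train.py", *extra_args]
    if numa_node is not None:
        # -m binds memory allocation, -N binds execution to the given node
        argv = ["numactl", "-m", str(numa_node), "-N", str(numa_node), *argv]
    return env, argv


env, argv = build_train_command(
    threads=3,
    numa_node=0,
    extra_args=["--input-h5", "my_data.h5", "--input-json", "my_data.json",
                "--checkpoint-name", "my_model"],
)
# Show the equivalent shell command line
print("OMP_NUM_THREADS=" + env["OMP_NUM_THREADS"], shlex.join(argv))
# To actually start training: subprocess.run(argv, env=env, check=True)
```

Since the best thread count depends on the machine, a helper like this makes it easy to try a few values of threads and compare the time per iteration.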

Using the model

sampling.py implements an extensible and efficient sampling module. You can sample output from the model using sample.py:

python3 sample.py --checkpoint my_model.json

A simple chat application, chatter.py, is also included, as is an efficient HTTP sampling server. To use the server, edit the example config file and start it:

python3 httpserver.py config.cfg

You can then send text to the model and generate responses through a simple HTTP interface, optionally specifying text-generation options:

curl -X POST -d '{"key":"test", "text":"User input"}' http://localhost:7880/put
curl -X POST -d '{"key":"test"}' http://localhost:7880/get
curl -X POST -d '{"key":"test", "bad_words":["stinky"], "temperature":0.7, "softlenght_max" : 50}' http://localhost:7880/get
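The same requests can be sent from Python with only the standard library. This is a rough sketch based on the curl calls above: the /put and /get endpoints, port 7880, and the JSON keys come from those examples, while the helper names (build_request, call, demo) are made up here. The response format is not documented in this README, so the body is returned as raw text.

```python
import json
import urllib.request

BASE_URL = "http://localhost:7880"  # host and port taken from the curl examples


def build_request(endpoint, payload, base_url=BASE_URL):
    """Build a POST request carrying `payload` as a JSON body (hypothetical helper)."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(f"{base_url}/{endpoint}", data=body, method="POST")


def call(endpoint, payload, base_url=BASE_URL, timeout=30):
    """Send the request and return the raw response body as text."""
    request = build_request(endpoint, payload, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read().decode("utf-8", errors="replace")


def demo():
    """Needs a running httpserver.py; not called automatically."""
    call("put", {"key": "test", "text": "User input"})
    print(call("get", {"key": "test", "temperature": 0.7}))
```

Call demo() once the server is running. Like curl -d, urllib sends the body with a form-urlencoded Content-Type by default, so the requests should look the same to the server as the curl examples.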

The server can handle multiple parallel requests by packing them into one batch, which allows efficient generation of dozens of text streams at the same time.
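To give a feel for what packing requests into one batch means, here is a toy sketch. It is not the server's actual code: collect_batch and fake_model_step are invented for illustration, and the "model" just upper-cases its input. The point is only that several waiting requests are pulled from a queue and processed in a single step.

```python
import queue


def collect_batch(pending, max_batch):
    """Drain up to max_batch waiting requests without blocking."""
    batch = []
    while len(batch) < max_batch:
        try:
            batch.append(pending.get_nowait())
        except queue.Empty:
            break
    return batch


def fake_model_step(prompts):
    """Stand-in for one batched forward pass: one output per input stream."""
    return [prompt.upper() for prompt in prompts]


# Three clients submit text at roughly the same time
pending = queue.Queue()
for key, text in [("alice", "hello"), ("bob", "hi there"), ("carol", "hey")]:
    pending.put((key, text))

# Process at most two streams in this step; the third waits for the next one
batch = collect_batch(pending, max_batch=2)
keys = [key for key, _ in batch]
outputs = fake_model_step([text for _, text in batch])
results = dict(zip(keys, outputs))
print(results)  # {'alice': 'HELLO', 'bob': 'HI THERE'}
```

With a real network, one batched step costs little more than a single-stream step, which is why serving many streams together is efficient.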

Differences from torch-rnn

  • Only GridGRU layers are implemented at this time, based on guillitte's implementation. More layer types will be implemented later.
  • The string decoder works at the byte level and is fully encoding-agnostic. Any tokenization scheme (bytes, Unicode, words...) should work, as long as it can be decoded by a greedy algorithm.
  • Training now gives expected results. For some reason PyTorch 1.0 was causing gradient issues, but updating to 1.1 fixed it.
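One way to picture byte-level, greedy handling of tokens is longest-match splitting over raw bytes. The sketch below is an assumption about what "greedy" means here, not code from the repository: the vocabulary, greedy_encode, and decode are all hypothetical. Tokens are plain byte strings, so single bytes, multi-byte UTF-8 characters, and whole words can coexist in one vocabulary.

```python
def greedy_encode(data, vocab):
    """Split `data` (bytes) into token ids, always taking the longest match."""
    by_bytes = {token: idx for idx, token in vocab.items()}
    max_len = max(len(token) for token in by_bytes)
    ids, pos = [], 0
    while pos < len(data):
        # Try the longest candidate first, then shrink one byte at a time
        for length in range(min(max_len, len(data) - pos), 0, -1):
            idx = by_bytes.get(data[pos:pos + length])
            if idx is not None:
                ids.append(idx)
                pos += length
                break
        else:
            raise ValueError(f"no token matches at byte offset {pos}")
    return ids


def decode(ids, vocab):
    """Concatenate the raw bytes of each token; no text encoding is assumed."""
    return b"".join(vocab[i] for i in ids)


# Hypothetical vocabulary: id -> raw bytes (single bytes, a word, a UTF-8 character)
vocab = {0: b"h", 1: b"e", 2: b"l", 3: b"o", 4: b" ",
         5: b"hello", 6: "\u00e9".encode("utf-8")}

text = "hello h\u00e9"
ids = greedy_encode(text.encode("utf-8"), vocab)
print(ids)  # [5, 4, 0, 6]
print(decode(ids, vocab).decode("utf-8"))
```

Because the functions only ever see bytes, the text encoding matters only at the edges, when a string is turned into bytes and back.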

Contributors

antihutka
