Coder Social home page Coder Social logo

srviest / char-cnn-text-classification-pytorch Goto Github PK

View Code? Open in Web Editor NEW
176.0 9.0 47.0 24.43 MB

Character-level Convolutional Neural Networks for text classification in PyTorch

License: Apache License 2.0

Python 100.00%
sentiment-analysis cnn-text-classification pytorch-implmention natural-language-processing

char-cnn-text-classification-pytorch's Introduction

Introduction

This is the implementation of Zhang's Character-level Convolutional Networks for Text Classification paper in PyTorch modified from Shawn1993/cnn-text-classification-pytorch.

Zhang's original implementation in Torch: https://github.com/zhangxiangxiao/Crepe

Requirement

  • python 2, 3
  • pytorch >= 0.5
  • numpy
  • termcolor

Dataset Format

Each sample looks like:

"class idx","sentence or text to be classified"  

Samples are separated by newline.

Example:

"3","Fears for T N pension after talks, Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
"4","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com)","SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket."

Train

python train.py -h

You will get:

Character-level CNN text classifier

optional arguments:
  -h, --help            show this help message and exit
  --train_path DIR      path to training data csv
  --val_path DIR        path to validation data csv

Learning options:
  --lr LR               initial learning rate [default: 0.0001]
  --epochs EPOCHS       number of epochs for train [default: 200]
  --batch_size BATCH_SIZE
                        batch size for training [default: 64]
  --max_norm MAX_NORM   Norm cutoff to prevent explosion of gradients
  --optimizer OPTIMIZER
                        Type of optimizer. SGD|Adam|ASGD are supported
                        [default: Adam]
  --class_weight        Weights should be a 1D Tensor assigning weight to each
                        of the classes.
  --dynamic_lr          Use dynamic learning schedule.
  --milestones MILESTONES [MILESTONES ...]
                        List of epoch indices. Must be increasing.
                        Default:[5,10,15]
  --decay_factor DECAY_FACTOR
                        Decay factor for reducing learning rate [default: 0.5]

Model options:
  --alphabet_path ALPHABET_PATH
                        Contains all characters for prediction
  --l0 L0               maximum length of input sequence to CNNs [default:
                        1014]
  --shuffle             shuffle the data every epoch
  --dropout DROPOUT     the probability for dropout [default: 0.5]
  -kernel_num KERNEL_NUM
                        number of each kind of kernel
  -kernel_sizes KERNEL_SIZES
                        comma-separated kernel size to use for convolution

Device options:
  --num_workers NUM_WORKERS
                        Number of workers used in data-loading
  --cuda                enable the gpu

Experiment options:
  --verbose             Turn on progress tracking per iteration for debugging
  --continue_from CONTINUE_FROM
                        Continue from checkpoint model
  --checkpoint          Enables checkpoint saving of model
  --checkpoint_per_batch CHECKPOINT_PER_BATCH
                        Save checkpoint per batch. 0 means never save
                        [default: 10000]
  --save_folder SAVE_FOLDER
                        Location to save epoch models, training configurations
                        and results.
  --log_config          Store experiment configuration
  --log_result          Store experiment result
  --log_interval LOG_INTERVAL
                        how many steps to wait before logging training status
                        [default: 1]
  --val_interval VAL_INTERVAL
                        how many steps to wait before vaidation [default: 200]
  --save_interval SAVE_INTERVAL
                        how many epochs to wait before saving [default:1]
python train.py

You will get:

Epoch[8] Batch[200] - loss: 0.237892  lr: 0.00050  acc: 93.7500%(120/128))
Evaluation - loss: 0.363364  acc: 89.1155%(6730/7552)
Label:   0      Prec:  93.2% (1636/1755)  Recall:  86.6% (1636/1890)  F-Score:  89.8%
Label:   1      Prec:  94.6% (1802/1905)  Recall:  95.6% (1802/1884)  F-Score:  95.1%
Label:   2      Prec:  85.6% (1587/1854)  Recall:  84.1% (1587/1888)  F-Score:  84.8%
Label:   3      Prec:  83.7% (1705/2038)  Recall:  90.2% (1705/1890)  F-Score:  86.8%

Test

If you has construct you test set, you make testing like:

python test.py --test-path='data/ag_news_csv/test.csv' --model-path='models_CharCNN/CharCNN_best.pth.tar'

The model-path option means where your model load from.

Reference

char-cnn-text-classification-pytorch's People

Contributors

astorfi avatar eriche2016 avatar jadore801120 avatar phonism avatar shawn1993 avatar srviest avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

char-cnn-text-classification-pytorch's Issues

Not converging!

I have been able to run the model as is but it simply not converging!
For how many epochs I should run the model?

No learning

I am trying the code, but precision is not improving. Model keeps on selecting only one of the classes:

image

image

Thanks

Error in Pytorch 1.0.0

Hi! I am using Pytorch 1.0.0 with Python 2.7 and I get the error below when running python train.py command:

Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
x = self.log_softmax(x)
train.py:120: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
Traceback (most recent call last):
File "train.py", line 282, in
main()
File "train.py", line 279, in main
train(train_loader, dev_loader, model, args)
File "train.py", line 137, in train
loss.data[0],
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

How to fix this? Thanks!

IndexError: list index out of range

When I was training, I met a problem.
File "train.py", line 278, in <module> main() File "train.py", line 275, in main train(train_loader, dev_loader, model, args) File "train.py", line 97, in train model = torch.nn.DataParallel(model).cuda() File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 47, in __init__ output_device = device_ids[0] IndexError: list index out of range
please help me!

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.