
This project forked from zshicode/gnn-for-text-classification


Graph convolutional networks (GCN), GraphSAGE and graph attention networks (GAT) for text classification



Graph neural networks for text classification

Graph neural networks have been widely used in natural language processing. Yao et al. (2019) proposed TextGCN, which applies graph convolutional networks (GCN) (Kipf and Welling, 2017) to text classification on a heterogeneous graph. We implemented TextGCN with PyTorch and DGL. Furthermore, we suggest that inductive learning and attention mechanisms are crucial for text classification with graph neural networks, so we also adopt GraphSAGE (Hamilton et al., 2017) and graph attention networks (GAT) (Velickovic et al., 2018) for this classification task.
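For readers unfamiliar with the GCN layer that TextGCN builds on, the propagation rule from Kipf and Welling (2017), ReLU(D^-1/2 (A + I) D^-1/2 X W), can be sketched in plain Python. This is an illustrative toy and not this repository's DGL implementation; the names gcn_layer and matmul and the two-node graph are assumptions made for the example.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def gcn_layer(adj, feats, weight):
    """One GCN propagation step: ReLU(D^-1/2 (A + I) D^-1/2 X W)."""
    n = len(adj)
    # Add self-loops so that every node keeps its own features.
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
             for i in range(n)]
    # Symmetric normalisation by the (self-loop inclusive) node degrees.
    deg = [sum(row) for row in a_hat]
    norm = [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
            for i in range(n)]
    # Aggregate neighbour features, apply the linear map, then ReLU.
    out = matmul(matmul(norm, feats), weight)
    return [[max(0.0, v) for v in row] for row in out]


# Toy graph: two connected nodes with one-hot features and identity weights.
adj = [[0, 1], [1, 0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
print(gcn_layer(adj, identity, identity))  # [[0.5, 0.5], [0.5, 0.5]]
```

In the toy graph each node ends up with the average of its own and its neighbour's features, which is the smoothing effect that the layer stacks to propagate label information.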

We evaluated our models on the MR dataset (Tang et al., 2015). Results:

Metric      GCN (Yao et al., 2019)   GCN      GraphSAGE   GAT
Accuracy    0.7552                   0.7541   0.7642      0.7665
Macro-F1    0.7548                   0.7536   0.7637      0.7664
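As a reminder of what the two metrics measure, the sketch below computes accuracy and macro-averaged F1 from label lists using their standard definitions. The README does not state how the repository computes them, so the function names and the toy labels here are assumptions for illustration only.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); define it as 0 when the class is empty.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy binary example (hypothetical labels, not taken from the MR dataset).
y_true = [1, 1, 0, 0]
y_pred = [1, 0, 0, 0]
print(accuracy(y_true, y_pred))  # 0.75
print(round(macro_f1(y_true, y_pred), 4))  # 0.7333
```

Because macro-F1 weights every class equally, it stays close to accuracy on a balanced binary task, which is consistent with the near-identical rows in the table.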

Some of our code comes from the TextGCN and TextGCN in PyTorch repositories. More datasets can also be downloaded from these two repositories.

Requirements

The code has been tested under:

  • Python 3.7
  • PyTorch 1.3
  • dgl-cu101 0.4
  • CUDA 10.1

Running training and evaluation

First, preprocess the data, replacing <dataset> with 20ng, R8, R52, ohsumed or mr:

  1. cd ./preprocess
  2. Run python remove_words.py <dataset>
  3. Run python build_graph.py <dataset>
  4. cd ../

Then run:

python main.py --model GCN --cuda True
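The preprocessing and training steps above can also be scripted. The sketch below only assembles the command lines and prints them; the helper name pipeline_commands is hypothetical, and the script names, dataset names and flags are the ones listed in this README.

```python
DATASETS = ("20ng", "R8", "R52", "ohsumed", "mr")
MODELS = ("GCN", "SAGE", "GAT")


def pipeline_commands(dataset, model="GCN", cuda=True):
    """Return (working_dir, argv) pairs for preprocessing and training."""
    if dataset not in DATASETS:
        raise ValueError("unknown dataset: %r" % (dataset,))
    if model not in MODELS:
        raise ValueError("unknown model: %r" % (model,))
    train = ["python", "main.py", "--model", model, "--dataset", dataset]
    if cuda:
        train += ["--cuda", "True"]
    return [
        # Steps 1-3: clean the corpus, then build the graph.
        ("preprocess", ["python", "remove_words.py", dataset]),
        ("preprocess", ["python", "build_graph.py", dataset]),
        # Step 4 onwards: back in the repository root, train and evaluate.
        (".", train),
    ]


for cwd, argv in pipeline_commands("mr"):
    print(cwd, ":", " ".join(argv))
```

To execute the commands rather than print them, each pair could be passed to subprocess.run(argv, cwd=cwd, check=True) from the repository root.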

Parameters:

import argparse


def get_citation_args():
    # Build the command-line parser for the training options.
    parser = argparse.ArgumentParser()
    # Caveat: type=bool calls bool() on the raw string, so any non-empty
    # value (even "False") enables CUDA; omit the flag to train on CPU.
    parser.add_argument('--cuda', type=bool, default=False,
                        help='Use CUDA training.')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs to train.')
    parser.add_argument('--lr', type=float, default=0.02,
                        help='Initial learning rate.')
    parser.add_argument('--model', type=str, default="GCN",
                        choices=["GCN", "SAGE", "GAT"],
                        help='model to use.')
    parser.add_argument('--early_stopping', type=int, default=10,
                        help='require early stopping.')
    parser.add_argument('--dataset', type=str, default='mr',
                        choices=['20ng', 'R8', 'R52', 'ohsumed', 'mr'],
                        help='dataset to train')

    # parse_known_args() ignores unrecognised arguments instead of failing.
    args, _ = parser.parse_known_args()
    #args.cuda = not args.no_cuda and th.cuda.is_available()
    return args


args = get_citation_args()
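One detail of the parser is worth illustrating: with type=bool, argparse passes the raw string to bool(), so "--cuda False" still evaluates to True. The sketch below demonstrates this and shows one possible workaround; str2bool and parse_cuda are hypothetical helpers that are not part of the repository.

```python
import argparse


def str2bool(value):
    """Convert common true/false spellings to a bool (hypothetical helper)."""
    text = value.lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % (value,))


def parse_cuda(argv, converter):
    """Parse only the --cuda flag from argv using the given type converter."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=converter, default=False,
                        help="Use CUDA training.")
    args, _ = parser.parse_known_args(argv)
    return args.cuda


print(parse_cuda(["--cuda", "False"], bool))      # True: bool("False") is truthy
print(parse_cuda(["--cuda", "False"], str2bool))  # False
```

With the parser as published, the simplest way to train on CPU is to leave out --cuda entirely so that the default of False applies.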

References

William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs. In NeurIPS, 2017.

Thomas Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. In ICLR, 2017.

Jian Tang, Meng Qu, and Qiaozhu Mei. PTE: Predictive text embedding through large-scale heterogeneous text networks. In KDD, 2015.

Petar Velickovic, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. Graph attention networks. In ICLR, 2018.

Liang Yao, Chengsheng Mao, and Yuan Luo. Graph convolutional networks for text classification. Proceedings of the AAAI Conference on Artificial Intelligence, 33:7370–7377, 2019.
