
Learning Perceptual Inference by Contrasting

Home Page: http://wellyzhang.github.io/project/copinet.html

License: GNU General Public License v3.0

Language: Python (100%)
Topics: ravens-progressive-matrices, neurips-2019, visual-reasoning, abstract-reasoning


CoPINet

This repo contains code for our NeurIPS 2019 spotlight paper.

Learning Perceptual Inference by Contrasting
Chi Zhang*, Baoxiong Jia*, Feng Gao, Yixin Zhu, Hongjing Lu, Song-Chun Zhu
Proceedings of Advances in Neural Information Processing Systems (NeurIPS), 2019
Spotlight (2.43% acceptance rate)
(* indicates equal contribution.)

"Thinking in pictures," i.e., spatial-temporal reasoning, effortless and instantaneous for humans, is believed to be a significant ability to perform logical induction and a crucial factor in the intellectual history of technology development. Modern Artificial Intelligence (AI), fueled by massive datasets, deeper models, and mighty computation, has come to a stage where (super-)human-level performances are observed in certain specific tasks. However, current AI's ability in "thinking in pictures" is still lagging far behind. In this work, we study how to improve machines' reasoning ability on one challenging task of this kind: Raven's Progressive Matrices (RPM). Specifically, we borrow the very idea of "contrast effects" from the fields of psychology, cognition, and education to design and train a permutation-invariant model. Inspired by cognitive studies, we equip our model with a simple inference module that is jointly trained with the perception backbone. Combining all the elements, we propose the Contrastive Perceptual Inference network (CoPINet) and empirically demonstrate that CoPINet sets the new state-of-the-art for permutation-invariant models on two major datasets. We conclude that spatial-temporal reasoning depends on envisaging the possibilities consistent with the relations between objects and can be solved from pixel-level inputs.

(Figure: CoPINet model architecture)
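
The contrasting idea at the core of the model can be illustrated with a minimal PyTorch sketch (not the repo's exact module; the class name, linear projection, and feature sizes below are illustrative): every candidate answer is re-encoded relative to the mean over all candidates, so scoring depends on how each choice differs from the rest of the group.

    # Minimal sketch of the contrast idea (illustrative, not CoPINet's exact module).
    import torch
    import torch.nn as nn

    class ContrastSketch(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)  # stand-in for CoPINet's conv blocks

        def forward(self, candidates):
            # candidates: (batch, num_choices, dim), one feature row per answer choice
            center = candidates.mean(dim=1, keepdim=True)  # summary of the choice set
            return self.proj(candidates - center)          # contrast each choice against it

    feats = torch.randn(2, 8, 64)    # 2 problems, 8 answer choices, 64-d features
    out = ContrastSketch(64)(feats)  # same shape, now contrast-coded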

Performance

The following two tables show the performance of various methods on the RAVEN dataset and the PGM dataset. For details, please check our paper.

Performance on RAVEN:

| Method | Acc | Center | 2x2Grid | 3x3Grid | L-R | U-D | O-IC | O-IG |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| LSTM | 13.07% | 13.19% | 14.13% | 13.69% | 12.84% | 12.35% | 12.15% | 12.99% |
| WReN-NoTag-Aux | 17.62% | 17.66% | 29.02% | 34.67% | 7.69% | 7.89% | 12.30% | 13.94% |
| CNN | 36.97% | 33.58% | 30.30% | 33.53% | 39.43% | 41.26% | 43.20% | 37.54% |
| ResNet | 53.43% | 52.82% | 41.86% | 44.29% | 58.77% | 60.16% | 63.19% | 53.12% |
| ResNet+DRT | 59.56% | 58.08% | 46.53% | 50.40% | 65.82% | 67.11% | 69.09% | 60.11% |
| CoPINet | 91.42% | 95.05% | 77.45% | 78.85% | 99.10% | 99.65% | 98.50% | 91.35% |
| WReN-NoTag-NoAux | 15.07% | 12.30% | 28.62% | 29.22% | 7.20% | 6.55% | 8.33% | 13.10% |
| WReN-Tag-NoAux | 17.94% | 15.38% | 29.81% | 32.94% | 11.06% | 10.96% | 11.06% | 14.54% |
| WReN-Tag-Aux | 33.97% | 58.38% | 38.89% | 37.70% | 21.58% | 19.74% | 38.84% | 22.57% |
| CoPINet-Backbone-XE | 20.75% | 24.00% | 23.25% | 23.05% | 15.00% | 13.90% | 21.25% | 24.80% |
| CoPINet-Contrast-XE | 86.16% | 87.25% | 71.05% | 74.45% | 97.25% | 97.05% | 93.20% | 82.90% |
| CoPINet-Contrast-CL | 90.04% | 94.30% | 74.00% | 76.85% | 99.05% | 99.35% | 98.00% | 88.70% |
| Human | 84.41% | 95.45% | 81.82% | 79.55% | 86.36% | 81.81% | 86.36% | 81.81% |
| Solver | 100% | 100% | 100% | 100% | 100% | 100% | 100% | 100% |

Performance on PGM:

| Method | CNN | LSTM | ResNet | Wild-ResNet | WReN-NoTag-Aux | CoPINet |
| --- | --- | --- | --- | --- | --- | --- |
| Acc | 33.00% | 35.80% | 42.00% | 48.00% | 49.10% | 56.37% |

Note that for CoPINet, after cleaning up the code, we can obtain numbers slightly better than those reported in the paper. Here we only show the numbers we had at the time of submission.

Dependencies

Important

  • Python3 supported
  • PyTorch
  • CUDA and cuDNN expected

See requirements.txt for a full list of packages required.
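
Before training, a quick sanity check (a minimal sketch, not part of the repo) can confirm that PyTorch is installed and can see a CUDA device:

    # Minimal environment check: PyTorch version and CUDA visibility.
    import torch

    print("PyTorch:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("Device:", torch.cuda.get_device_name(0))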

Usage

To train CoPINet, run

python src/main.py train --dataset <path to dataset>

The default hyper-parameters should work. Check src/main.py for the full list of adjustable arguments.
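
For reference, the sketch below shows what a single RAVEN training sample looks like when loaded; the "image"/"target" field names and shapes come from the data-loading snippet quoted in the issues section below, and the file name is only an example.

    # Minimal sketch of inspecting one RAVEN sample (field names and shapes taken
    # from the loader snippet quoted in the issues; file name is an example).
    import numpy as np

    data = np.load("RAVEN_0_train.npz")
    image = data["image"].reshape(16, 160, 160)  # 8 context panels + 8 answer choices
    target = int(data["target"])                 # index of the correct choice
    print(image.shape, target)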

Performance of existing baselines is obtained from this repo.

Citation

If you find the paper and/or the code helpful, please cite us.

@inproceedings{zhang2019learning,
    title={Learning Perceptual Inference by Contrasting},
    author={Zhang, Chi and Jia, Baoxiong and Gao, Feng and Zhu, Yixin and Lu, Hongjing and Zhu, Song-Chun},
    booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
    year={2019}
}

Acknowledgement

We would like to express our gratitude to all the colleagues and anonymous reviewers who helped us improve the paper. This project would have been impossible to finish without the following open-source implementation.


copinet's Issues

ValueError: num_samples should be a positive integer value, but got num_samples=0

I am using CoPINet on Google Colaboratory. While running

!python /content/CoPINet/src/main.py train --dataset "/content/drive/MyDrive/sample2/center_single"

the following error occurs:
    Traceback (most recent call last):
      File "/content/CoPINet/src/main.py", line 262, in <module>
        main()
      File "/content/CoPINet/src/main.py", line 253, in main
        train(args, device)
      File "/content/CoPINet/src/main.py", line 122, in train
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
      File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 268, in __init__
        sampler = RandomSampler(dataset, generator=generator)
      File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py", line 103, in __init__
        "value, but got num_samples={}".format(self.num_samples))
    ValueError: num_samples should be a positive integer value, but got num_samples=0

I have mounted Google Drive on my Colab, and the folder /content/drive/MyDrive/sample2/center_single contains 10 sample training examples (RAVEN_0_train.npz, RAVEN_0_train.xml, ..., RAVEN_9_train.npz, RAVEN_9_train.xml).
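
A num_samples of 0 means the dataset object found no training files under the supplied path. A minimal diagnostic sketch (assuming the loader picks up files ending in _train.npz; adjust the pattern if the dataset class uses a different one):

    # Count the training files the loader could pick up under the given path.
    # Assumption: the loader matches files named *_train.npz.
    import glob
    import os

    dataset_dir = "/content/drive/MyDrive/sample2/center_single"
    files = glob.glob(os.path.join(dataset_dir, "*_train.npz"))
    print(len(files), "training files found")  # 0 reproduces the DataLoader error

If the count is non-zero here but the error persists, the --dataset argument may need to point at the dataset root rather than a single configuration folder; check the dataset class under src/ to confirm.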

A problem with the code running on the PGM dataset

The results of this code on the RAVEN dataset are roughly the same as those reported in the paper, but its performance on the PGM dataset is much lower than in the paper (we only got about 30% accuracy).
We noticed that the image data of the RAVEN dataset (16x160x160) has a different layout from that of the PGM dataset (160x160x16). Are there any other details to consider when running on the PGM dataset?

Code provided:

    data = np.load(data_path)
    image = data["image"].reshape(16, 160, 160)
    target = data["target"]
    
    if self.img_size != 160:
        resize_image = []
        for idx in range(16):
            resize_image.append(misc.imresize(image[idx, :, :], (self.img_size, self.img_size)))
        image = np.stack(resize_image)
    image = torch.tensor(image, dtype=torch.float)
    target = torch.tensor(target, dtype=torch.long)

After modification:

    data = np.load(data_path)
    image = data["image"].reshape(160, 160, 16)
    target = data["target"]
    
    if self.img_size != 160:
        resize_image = []
        for idx in range(16):
            resize_image.append(misc.imresize(image[:, :, idx], (self.img_size, self.img_size)))
        image = np.stack(resize_image)
    image = torch.tensor(image, dtype=torch.float)
    target = torch.tensor(target, dtype=torch.long)
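
If PGM's flat image array is indeed stored channel-last, the same fix can be expressed as a single reshape-and-transpose into the channel-first layout the rest of the code expects (a sketch; the file name is illustrative, and the channel-last assumption is the one raised in this issue):

    # Sketch: assume PGM stores panels channel-last (160, 160, 16) and convert to
    # the channel-first layout (16, 160, 160) used for RAVEN.
    import numpy as np

    flat = np.load("PGM_sample.npz")["image"]  # illustrative file name
    image = flat.reshape(160, 160, 16).transpose(2, 0, 1)
    assert image.shape == (16, 160, 160)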

Column-wise rules?

The RAVEN dataset paper says that only row-wise rules were applied, but the CoPINet code also takes column-wise rules into consideration.

Is this only for the PGM dataset, or are column-wise rules also applied to the RAVEN dataset?
