gt-ripl / l2c

Learning to Cluster. A deep clustering strategy.

License: MIT License

Python 96.05% Shell 3.95%
clustering deep-learning deep-neural-networks artificial-neural-networks artificial-intelligence unsupervised-learning supervised-learning similarity-metric semi-supervised-learning transfer-learning

l2c's Introduction

L2C: Learning to Cluster

A clustering strategy with deep neural networks. This blog article provides a general overview.

Introduction

This repository provides the PyTorch implementation of the transfer-learning scheme (L2C) and two learning criteria (KCL and MCL) useful for deep clustering.

*This repository was renamed from CCL.

This repository covers the following references:

@inproceedings{Hsu19_MCL,
	title =	    {Multi-class classification without multi-class labels},
	author =    {Yen-Chang Hsu and Zhaoyang Lv and Joel Schlosser and Phillip Odom and Zsolt Kira},
	booktitle = {International Conference on Learning Representations (ICLR)},
	year =      {2019},
	url =       {https://openreview.net/forum?id=SJzR2iRcK7}
}

@inproceedings{Hsu18_L2C,
	title =     {Learning to cluster in order to transfer across domains and tasks},
	author =    {Yen-Chang Hsu and Zhaoyang Lv and Zsolt Kira},
	booktitle = {International Conference on Learning Representations (ICLR)},
	year =      {2018},
	url =       {https://openreview.net/forum?id=ByRWCqvT-}
}

@inproceedings{Hsu16_KCL,
	title =	    {Neural network-based clustering using pairwise constraints},
	author =    {Yen-Chang Hsu and Zsolt Kira},
	booktitle = {ICLR workshop},
	year =      {2016},
	url =       {https://arxiv.org/abs/1511.06321}
}

Preparation

This repository supports PyTorch 1.0 with Python 2.7, 3.6, and 3.7.

pip install -r requirements.txt

Demo

Supervised Classification/Clustering with only pairwise similarity

# A quick trial:
python demo.py  # Default Dataset:MNIST, Network:LeNet, Loss:MCL
python demo.py --loss KCL

# Lookup available options:
python demo.py -h

# For more examples:
./scripts/exp_supervised_MCL_vs_KCL.sh
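
In this setting the class labels are only used to derive binary pairwise similarities; the loss never sees the multi-class labels directly. Below is a minimal sketch of that conversion by pair enumeration (the function name is mine, not a module from this repository):

# Minimal sketch: turn a batch's class labels into pairwise similarity targets
# by pair enumeration (illustrative; not the repository's exact implementation).
import torch

def enumerate_pairwise_similarity(labels):
    # 1 if two samples share a class label, else 0; shape (N, N)
    labels = labels.view(-1, 1)
    return (labels == labels.t()).float()

batch_labels = torch.tensor([0, 0, 1, 1, 2, 2])
print(enumerate_pairwise_similarity(batch_labels))   # block-diagonal pattern of similar pairs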

Unsupervised Clustering (Cross-task Transfer Learning)

# Learn the Similarity Prediction Network (SPN) with Omniglot_background and then transfer to the 20 alphabets in Omniglot_evaluation.
# Default loss is MCL with an unknown number of clusters (set a large cluster number, e.g., k=100).
# It takes about half an hour to finish.
python demo_omniglot_transfer.py

# An example using KCL with k set to the ground-truth number of clusters (k=gt_#cluster)
python demo_omniglot_transfer.py --loss KCL --num_cluster -1

# Lookup available options:
python demo_omniglot_transfer.py -h

# Other examples:
./scripts/exp_unsupervised_transfer_Omniglot.sh
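
Conceptually, the transfer has two stages: the SPN learned on Omniglot_background predicts pairwise similarities for batches drawn from a target alphabet, and those predictions become the binary targets for a clustering head trained with KCL or MCL. A rough sketch of one such step under those assumptions (module and variable names are illustrative, and random tensors stand in for the actual data and SPN outputs; the repository's real modules, loaders, and training loop differ):

# Rough sketch of one cross-task transfer step (illustrative names and data).
import torch
import torch.nn as nn

K = 100                                                   # large k for "unknown number of clusters"
cluster_head = nn.Sequential(nn.Linear(64, K), nn.Softmax(dim=1))
optimizer = torch.optim.Adam(cluster_head.parameters(), lr=1e-3)

features = torch.rand(16, 64)                             # stand-in for target-task features
spn_similarity = (torch.rand(16, 16) > 0.9).float()       # stand-in for the SPN's pairwise predictions

prob = cluster_head(features)                             # per-sample cluster distributions
pair_prob = (prob @ prob.t()).clamp(1e-7, 1 - 1e-7)       # probability each pair shares a cluster
loss = nn.functional.binary_cross_entropy(pair_prob, spn_similarity)   # MCL-style criterion
optimizer.zero_grad()
loss.backward()
optimizer.step()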

Notes

  • The clustering results are highly dependent on the performance of the Similarity Prediction Network (SPN). To make a fair comparison, the SPN must be kept the same. Our script trains an SPN with random initialization and random data sampling; once the SPN model is trained, the script reuses the saved SPN instead of training a new one.
  • The table below presents the clustering performance with the reference SPN [download]. Put the model file into the /outputs folder and run demo_omniglot_transfer.py directly to generate the "MCL (k=100)" column.
  • The performance metric is clustering accuracy (for details, please see the L2C paper; a sketch of the metric is given after the table). Each value in the table is the average of 3 clustering runs. This repository reuses most of the utilities in PyTorch and differs from the Lua-based implementation used in the reference papers. The results (the "--Average--" row) show the same trend as the papers, although the absolute values differ slightly; the MCL results here are better than those in the paper.
Dataset | gt #class | KCL (k=100) | MCL (k=100) | KCL (k=gt) | MCL (k=gt)
Angelic | 20 | 73.2% | 82.2% | 89.0% | 91.7%
Atemayar_Qelisayer | 26 | 73.3% | 89.2% | 82.5% | 86.0%
Atlantean | 26 | 65.5% | 83.3% | 89.4% | 93.5%
Aurek_Besh | 26 | 88.4% | 92.8% | 91.5% | 92.4%
Avesta | 26 | 79.0% | 85.8% | 85.4% | 86.1%
Ge_ez | 26 | 77.1% | 84.0% | 85.4% | 86.6%
Glagolitic | 45 | 83.9% | 85.3% | 84.9% | 87.4%
Gurmukhi | 45 | 78.8% | 78.7% | 77.0% | 78.0%
Kannada | 41 | 64.6% | 81.1% | 73.3% | 81.2%
Keble | 26 | 91.4% | 95.1% | 94.7% | 94.3%
Malayalam | 47 | 73.5% | 75.0% | 72.7% | 73.0%
Manipuri | 40 | 82.8% | 81.2% | 85.8% | 81.5%
Mongolian | 30 | 84.7% | 89.0% | 88.3% | 90.2%
Old_Church_Slavonic_Cyrillic | 45 | 89.9% | 90.7% | 88.7% | 89.8%
Oriya | 46 | 56.5% | 73.4% | 63.2% | 75.3%
Sylheti | 28 | 61.8% | 68.2% | 69.8% | 80.6%
Syriac_Serto | 23 | 72.1% | 82.0% | 85.8% | 89.8%
Tengwar | 25 | 67.7% | 76.4% | 82.5% | 85.5%
Tibetan | 42 | 81.8% | 80.2% | 84.3% | 81.9%
ULOG | 26 | 53.3% | 77.1% | 73.0% | 89.1%
--Average-- | | 75.0% | 82.5% | 82.4% | 85.7%
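
Clustering accuracy, as used in the table above, is conventionally computed by finding the best one-to-one matching between predicted cluster IDs and ground-truth classes. A sketch of the standard Hungarian-matching formulation (not necessarily the exact code in utils/metric.py):

# Clustering accuracy via Hungarian matching (standard formulation; sketch only).
import numpy as np
from scipy.optimize import linear_sum_assignment

def clustering_accuracy(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    d = max(y_true.max(), y_pred.max()) + 1
    count = np.zeros((d, d), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        count[p, t] += 1                              # co-occurrence of cluster p and class t
    row, col = linear_sum_assignment(-count)          # maximize matched counts
    return count[row, col].sum() / float(y_true.size)

print(clustering_accuracy([0, 0, 1, 1, 2], [2, 2, 0, 0, 1]))   # 1.0 (perfect up to relabeling)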

Compare MCL and KCL

The loss surface of MCL is more similar to that of cross-entropy (CE) than KCL's is. Empirically, MCL converged faster than KCL. For details, please refer to the ICLR 2019 paper.
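
For readers who want the two criteria side by side, here is a hedged sketch of the pairwise losses as described in the papers: MCL is a binary cross-entropy on the inner product of the two samples' cluster distributions, while KCL uses a KL-divergence-based contrastive term that pulls similar pairs together and hinges dissimilar pairs apart. The function names and the margin value are illustrative, not the repository's modules:

# Hedged sketches of the two pairwise criteria (illustrative only; the
# repository's own loss modules may differ in details).
import torch
import torch.nn.functional as F

def mcl_loss(p1, p2, sim):
    # Probability that the pair shares a cluster = inner product of distributions.
    same_prob = (p1 * p2).sum(dim=1).clamp(1e-7, 1 - 1e-7)
    return F.binary_cross_entropy(same_prob, sim)

def kcl_loss(p1, p2, sim, margin=2.0):
    # Symmetric KL between the two cluster distributions; hinge for dissimilar pairs.
    kl12 = (p1 * (p1.clamp_min(1e-7).log() - p2.clamp_min(1e-7).log())).sum(dim=1)
    kl21 = (p2 * (p2.clamp_min(1e-7).log() - p1.clamp_min(1e-7).log())).sum(dim=1)
    sym_kl = kl12 + kl21
    return (sim * sym_kl + (1 - sim) * F.relu(margin - sym_kl)).mean()

p1 = torch.softmax(torch.randn(8, 10), dim=1)   # cluster distributions of paired samples
p2 = torch.softmax(torch.randn(8, 10), dim=1)
sim = torch.randint(0, 2, (8,)).float()          # binary similarity targets
print(mcl_loss(p1, p2, sim).item(), kcl_loss(p1, p2, sim).item())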

Related Applications

Lane detection for autonomous driving / Instance segmentation

@inproceedings{Hsu18_InsSeg,
	title =     {Learning to Cluster for Proposal-Free Instance Segmentation},
	author =    {Yen-Chang Hsu and Zheng Xu and Zsolt Kira and Jiawei Huang},
	booktitle = {International Joint Conference on Neural Networks (IJCNN)},
	year =      {2018},
	url =       {https://arxiv.org/abs/1803.06459}
}

Acknowledgments

This work was supported by the National Science Foundation and National Robotics Initiative (grant # IIS-1426998) and DARPA’s Lifelong Learning Machines (L2M) program, under Cooperative Agreement HR0011-18-2-001.

l2c's People

Contributors

yenchanghsu


l2c's Issues

Training the G function on ImageNet

Hi, I have a question about training the function G. Because of the pair-enumeration layer, there are always many more dissimilar pairs per batch than similar pairs. I can see that it works on Omniglot, but on more complicated datasets like ImageNet, wouldn't this imbalance be a problem? Are different hyperparameters used to train G for ImageNet?

BatchKLdivergence.lua

Hello there

Have you also implemented BatchKLdivCriterion.lua from the original Lua code in this repo?

ImageNet Model Requirement

Thanks for your excellent work, but could you please share your similarity model for ImageNet?
In my experiments, the two-linear-layer SimilarityNet converges quickly on Omniglot but very slowly on large-scale datasets.

For example, the F1-score of "similar" for the val and test categories saturates at around 50% after 1000 epochs:
(figure: vis_similarity_lr0.001_b100_h2_04021311_acc52.16_e999)

The confusion matrices for the train, val, and test sets are:

Train Matrix:
[ Dissimilarity       :  525795   95205	(621000 in all.)]
[ Similarity          :    1728   67272	( 69000 in all.)]

Dissimilarity       :	PR: 99.7%,	RR: 84.7%,	 F1: 91.6%.
Similarity          :	PR: 41.4%,	RR: 97.5%,	 F1: 58.1%.

Val Matrix:
[ Dissimilarity       :  207682   26318	(234000 in all.)]
[ Similarity          :    5590   20410	( 26000 in all.)]

Dissimilarity       :	PR: 97.4%,	RR: 88.8%,	 F1: 92.9%.
Similarity          :	PR: 43.7%,	RR: 78.5%,	 F1: 56.1%.

Test Matrix:
[ Dissimilarity       :   80454    9546	( 90000 in all.)]
[ Similarity          :    3535    6465	( 10000 in all.)]

Dissimilarity       :	PR: 95.8%,	RR: 89.4%,	 F1: 92.5%.
Similarity          :	PR: 40.4%,	RR: 64.6%,	 F1: 49.7%.

Therefore, a more powerful model may be needed to fit the similarity on large-scale datasets.
It would be helpful if you could share your SimilarityNet for ImageNet, or some tricks for obtaining a good SimilarityNet on large-scale datasets.

Thank you.
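
For reference, the precision (PR), recall (RR), and F1 figures above follow directly from the confusion-matrix counts; for example, the "Similarity" row of the train matrix checks out as:

# Quick arithmetic check of PR/RR/F1 from the train-matrix counts above
# (rows are the true class; the second column is "predicted similar").
def prf(tp, fn, fp):
    precision = tp / float(tp + fp)
    recall = tp / float(tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

print(prf(tp=67272, fn=1728, fp=95205))   # ~(0.414, 0.975, 0.581) -> 41.4%, 97.5%, 58.1%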

Question about MCL Eq. 1

In the ICLR 2019 paper, in the paragraph below Eq. 1, you say that marginalizing over Y is intractable and that the Y_i depend on each other, so the additional independence assumption is introduced.

Why is the computation intractable, and how do the Y_i depend on each other?

Eq. 2 is presented only intuitively, so I cannot figure out why it is an approximation.
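
For context, a sketch of the structure being discussed, written under my reading of the paper's notation (the exact symbols may differ): the pairwise labels S couple all of the Y_i, so the exact likelihood sums over every joint assignment, while the approximation keeps only per-pair terms.

% Sketch (my reading; notation may differ from the paper): the exact likelihood
% marginalizes over K^n joint assignments, which couples the Y_i through S:
\[
P(S \mid X;\theta) \;=\; \sum_{Y_1,\dots,Y_n} P(S \mid Y)\,\prod_{i} P(Y_i \mid x_i;\theta),
\]
% whereas the independence approximation keeps only per-pair terms such as
\[
P(s_{ij}=1 \mid x_i, x_j;\theta) \;\approx\; \sum_{k=1}^{K} P(Y_i=k \mid x_i;\theta)\,P(Y_j=k \mid x_j;\theta).
\]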

two questions about the function G

Hi, I have two questions about the function G:

  1. What is the accuracy (the ability to predict the similarity between two images) of the function G?
  2. The "sampling in the Omniglot dataloader" makes sure that each cluster samples an equal number of images. However, assuming there are 5 classes per batch and each class has 10 images, there are 2500 pairs in total (500 similar pairs and 2000 dissimilar pairs), so the imbalance problem still exists (a quick check is sketched below).

Torch Error

Hi, running python demo.py gives me the following error. Any idea how to fix it?

==== Epoch:0 ====
/projects/anaconda3/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:82: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule.See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
LR: 0.001
Itr            |Batch time     |Data Time      |Loss
Traceback (most recent call last):
  File "demo.py", line 295, in <module>
    run(get_args(sys.argv[1:]))
  File "demo.py", line 236, in run
    train(epoch, train_loader, learner, args)
  File "demo.py", line 76, in train
    confusion.add(output, eval_target)
  File "/projects/unsupervised/05_test_l2c/L2C/utils/metric.py", line 69, in add
    output = output.squeeze_()
RuntimeError: set_storage_offset is not allowed on a Tensor created from .data or .detach().
If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
For example, change:
    x.data.set_(y)
to:
    with torch.no_grad():
        x.set_(y)
> pip freeze | grep torch
torch==1.2.0
torchvision==0.4.0a0+6b959ee
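
For what it's worth, here is a minimal sketch of the adjustment the warning in the traceback itself suggests; the surrounding utils/metric.py context is assumed, not reproduced:

# Minimal sketch of the change the error message suggests (context assumed).
import torch

output = torch.randn(1, 5, requires_grad=True)

# Failing pattern from the traceback: an in-place squeeze on a tensor that came
# from .data / .detach():
#     out = output.detach()
#     out.squeeze_()
# Per the message, drop the .data / .detach() call and make the change untracked:
with torch.no_grad():
    squeezed = output.squeeze()        # the out-of-place squeeze() is also sufficient
print(squeezed.shape)                  # torch.Size([5])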

Theoretical Issue

Can you please tell me how you converted the $s_{ij}=0$ term, $P(Y_i \mid x_i,\theta)\,P(Y_j \mid x_j,\theta)$, into $(1-s_{ij})\log\big(1 - f(x_i;\theta)^{\top} f(x_j;\theta)\big)$?
It seems like $P(Y_i \mid x_i,\theta)\,P(Y_j \mid x_j,\theta) \rightarrow 1 - f(x_i;\theta)^{\top} f(x_j;\theta)$ for no reason.
It also conflicts with the theory when the two classes are different: your loss function marks a pair as positive ($s_{ij}=1$) when the two samples have the same prediction, meaning the higher the overlap, the lower the loss should be, but in your case the lower the overlap, the lower the loss.
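
For context, the step is usually justified as follows (my reading of the derivation, not a quote from the paper): under the pairwise-independence assumption, the probability that two samples share a cluster is the inner product of their class-posterior vectors, its complement gives the dissimilar-pair probability, and the log appears when writing the negative log-likelihood.

% Sketch of the step (my reading; notation may differ from the paper):
\[
P(s_{ij}=1 \mid x_i, x_j;\theta)
  = \sum_{k} P(Y_i=k \mid x_i;\theta)\, P(Y_j=k \mid x_j;\theta)
  = f(x_i;\theta)^{\top} f(x_j;\theta),
\]
\[
P(s_{ij}=0 \mid x_i, x_j;\theta) = 1 - f(x_i;\theta)^{\top} f(x_j;\theta),
\]
% so each dissimilar pair contributes -(1 - s_{ij}) log(1 - f(x_i;theta)^T f(x_j;theta))
% to the negative log-likelihood.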
