12wang3 / rrl

The code of NeurIPS 2021 paper "Scalable Rule-Based Representation Learning for Interpretable Classification" and TPAMI paper "Learning Interpretable Rules for Scalable Data Representation and Classification"

License: MIT License

Python 100.00%
interpretable-classification rule-based-model representation-learning scalability interpretable-ai interpretable-ml explainable-ai explainable-ml neurips nips xai interpretable-machine-learning iml interpretability explainability neuro-symbolic-ai

rrl's Introduction

Rule-based Representation Learner

Updates

The following updates are summarized in the paper "Learning Interpretable Rules for Scalable Data Representation and Classification", which has been accepted by TPAMI. 🎉🎉🎉

Compared with the previous version, we make the following significant updates to enhance RRL:

Hierarchical Gradient Grafting

  • Single Gradient Grafting, the gradient-based discrete model training method proposed in the conference version, is more likely to fail as the RRL gets deeper.
  • To tackle this problem and further improve the performance of deep RRL, we propose Hierarchical Gradient Grafting that can avoid the side effects caused by the multiple layers during training.

Novel Logical Activation Functions (NLAF)

  • NLAFs not only can handle high-dimensional features that the original logical activation functions cannot handle but also are faster and require less GPU memory. Therefore, NLAFs are more scalable.
  • Unfortunately, NLAFs introduce three additional hyperparameters: alpha, beta, and gamma. We recommend trying (alpha, beta, gamma) in {(0.999, 8, 1), (0.999, 8, 3), (0.9, 3, 3)}.
  • To use NLAFs, set the "--nlaf" option and set the hyperparameters with "--alpha", "--beta", and "--gamma". For example:
# trained on the tic-tac-toe data set with NLAFs.
python3 experiment.py -d tic-tac-toe -bs 32 -s 1@64 -e401 -lrde 200 -lr 0.002 -ki 0 -i 0 -wd 0.001 --nlaf --alpha 0.9 --beta 3 --gamma 3 --temp 0.01 --print_rule &

Introduction

This is a PyTorch implementation of the Rule-based Representation Learner (RRL), as described in the NeurIPS 2021 paper Scalable Rule-Based Representation Learning for Interpretable Classification and the TPAMI paper Learning Interpretable Rules for Scalable Data Representation and Classification.

RRL aims to obtain both good scalability and interpretability, and it automatically learns interpretable non-fuzzy rules for data representation and classification. Moreover, RRL can be easily adjusted to obtain a trade-off between classification accuracy and model complexity for different scenarios.

Requirements

  • torch>=1.8.0
  • torchvision>=0.9.0
  • tensorboard>=1.15.0
  • scikit-learn>=0.23.2
  • numpy>=1.19.2
  • pandas>=1.1.3
  • matplotlib>=3.3.2
  • CUDA>=11.1

Tuning Suggestions

  1. Initially test an RRL with a single logical layer. If the loss converges, then consider increasing the number of layers.
  2. Start with a logical layer width of 1024 to check for loss convergence, then reduce width based on interpretability needs.
  3. Temperature (--temp) significantly affects performance. We suggest trying each of the following values: {1, 0.1, 0.01}.
  4. For NLAF, we suggest testing each of the following combinations: (alpha, beta, gamma) in {(0.999, 8, 1), (0.999, 8, 3), (0.9, 3, 3)}.
  5. Begin with learning rates of 0.002 and 0.0002, and then fine-tune as necessary.
  6. Don't forget to try the --save_best option.
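
The suggestions above amount to a small grid search. As a minimal sketch, the following script only prints one experiment.py command line per combination of the values listed in suggestions 3 to 5; it does not launch anything. The helper name `build_commands`, the default structure "1@1024" (one binarization node, one logical layer of width 1024), and the choice to always pass --nlaf and --save_best are assumptions made for illustration, not part of the repository.

```python
import itertools
import shlex

# Values taken from the tuning suggestions above.
TEMPERATURES = [1, 0.1, 0.01]
NLAF_SETTINGS = [(0.999, 8, 1), (0.999, 8, 3), (0.9, 3, 3)]
LEARNING_RATES = [0.002, 0.0002]


def build_commands(data_set="tic-tac-toe", structure="1@1024"):
    """Return one experiment.py command line per hyperparameter combination.

    Hypothetical helper: it only formats flags that appear in the --help
    output of experiment.py and does not run any training.
    """
    commands = []
    for temp, (alpha, beta, gamma), lr in itertools.product(
        TEMPERATURES, NLAF_SETTINGS, LEARNING_RATES
    ):
        argv = [
            "python3", "experiment.py",
            "-d", data_set,
            "-s", structure,
            "-lr", str(lr),
            "--temp", str(temp),
            "--nlaf",
            "--alpha", str(alpha),
            "--beta", str(beta),
            "--gamma", str(gamma),
            "--save_best",
        ]
        commands.append(shlex.join(argv))
    return commands


if __name__ == "__main__":
    for command in build_commands():
        print(command)
```

The grid has 3 x 3 x 2 = 18 entries; other flags such as -e, -bs, or -wd can be appended to `argv` in the same way.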

Run the demo

Put the data sets in the dataset folder. You can then specify one data set in the dataset folder and train the model as follows:

# trained on the tic-tac-toe data set with one GPU.
python3 experiment.py -d tic-tac-toe -bs 32 -s 1@16 -e401 -lrde 200 -lr 0.002 -ki 0 -i 0 -wd 0.0001 --print_rule &

The demo reads the data set and data set information first, then trains the RRL on the training set. During the training, you can check the training loss and the evaluation result on the validation set by:

tensorboard --logdir=log_folder

The training log file (log.txt) can be found in a folder created in log_folder. In this example, the folder path is

log_folder/tic-tac-toe/tic-tac-toe_e401_bs32_lr0.002_lrdr0.75_lrde200_wd0.0001_ki0_rc0_useNOTFalse_saveBestFalse_useNLAFFalse_estimatedGradFalse_useSkipFalse_alpha0.999_beta8_gamma1_temp1.0_L1@16
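
The folder name appears to encode the argument values of the run. The sketch below rebuilds the example path from those values, which can help when locating the logs of a particular run. The naming pattern is inferred from this single example only, and the function `run_folder` and its parameter names are hypothetical rather than taken from the repository.

```python
def run_folder(data_set, structure, epoch, batch_size, lr, lr_decay_rate,
               lr_decay_epoch, weight_decay, ith_kfold=0, round_count=0,
               use_not=False, save_best=False, nlaf=False,
               estimated_grad=False, skip=False, alpha=0.999, beta=8,
               gamma=1, temp=1.0, root="log_folder"):
    """Rebuild the log folder path; the pattern is inferred, not confirmed."""
    name = (
        f"{data_set}_e{epoch}_bs{batch_size}_lr{lr}_lrdr{lr_decay_rate}"
        f"_lrde{lr_decay_epoch}_wd{weight_decay}_ki{ith_kfold}_rc{round_count}"
        f"_useNOT{use_not}_saveBest{save_best}_useNLAF{nlaf}"
        f"_estimatedGrad{estimated_grad}_useSkip{skip}"
        f"_alpha{alpha}_beta{beta}_gamma{gamma}_temp{temp}_L{structure}"
    )
    return "/".join([root, data_set, name])


# Argument values from the demo command above; the rest are the defaults
# listed in the --help output.
path = run_folder("tic-tac-toe", "1@16", epoch=401, batch_size=32, lr=0.002,
                  lr_decay_rate=0.75, lr_decay_epoch=200, weight_decay=0.0001)
print(path)
```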

After training, the evaluation result on the test set is shown in the file test_res.txt:

[INFO] - On Test Set:
        Accuracy of RRL  Model: 1.0
        F1 Score of RRL  Model: 1.0

Moreover, the trained RRL model is saved in model.pth, and the discrete RRL is printed in rrl.txt:

RID class_negative(b=-0.3224) class_positive(b=-0.1306) Support Rule
(-1, 3) -0.7756 0.9354 0.0885 3_x & 6_x & 9_x
(-1, 0) -0.7257 0.8921 0.1146 1_x & 2_x & 3_x
(-1, 5) -0.6162 0.4967 0.0677 2_x & 5_x & 8_x
...... ...... ...... ...... ......
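
For further analysis it can be handy to load this table into Python. The following is a minimal sketch that parses the layout exactly as printed above: a header with one `name(b=bias)` column per class, then rows consisting of a rule ID, one weight per class, the support, and the rule text. It assumes whitespace-separated fields and class names without spaces; the actual file may use different separators, and `parse_rules` is a hypothetical helper, not part of the repository.

```python
import re

# Matches header columns such as "class_negative(b=-0.3224)".
HEADER_CLASS = re.compile(r"(\S+)\(b=(-?\d+(?:\.\d+)?)\)")
# Matches rows that start with a rule ID such as "(-1, 3)".
ROW = re.compile(r"^\((-?\d+),\s*(-?\d+)\)\s+(.*)$")


def parse_rules(text):
    """Parse the printed rule table into (biases, rules)."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    classes = HEADER_CLASS.findall(lines[0])
    names = [name for name, _ in classes]
    biases = {name: float(bias) for name, bias in classes}

    rules = []
    for line in lines[1:]:
        match = ROW.match(line)
        if match is None:
            continue  # skip separator lines such as "......" or "####"
        first, second, rest = match.groups()
        # One weight per class, then the support, then the rule text.
        fields = rest.split(None, len(names) + 1)
        if len(fields) < len(names) + 2:
            continue
        weights = {n: float(v) for n, v in zip(names, fields[:len(names)])}
        rules.append({
            "rid": (int(first), int(second)),
            "weights": weights,
            "support": float(fields[len(names)]),
            "rule": fields[len(names) + 1],
        })
    return biases, rules


SAMPLE = """\
RID class_negative(b=-0.3224) class_positive(b=-0.1306) Support Rule
(-1, 3) -0.7756 0.9354 0.0885 3_x & 6_x & 9_x
(-1, 0) -0.7257 0.8921 0.1146 1_x & 2_x & 3_x
(-1, 5) -0.6162 0.4967 0.0677 2_x & 5_x & 8_x
...... ...... ...... ...... ......
"""

biases, rules = parse_rules(SAMPLE)
for rule in rules:
    print(rule["rid"], rule["support"], rule["rule"])
```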

Your own data sets

You can use the demo to train RRL on your own data set by putting the data and data information files in the dataset folder. See DataSetDesc for detailed guidelines.

Available arguments

List all the available arguments and their default values by:

$ python3 experiment.py --help
usage: experiment.py [-h] [-d DATA_SET] [-i DEVICE_IDS] [-nr NR] [-e EPOCH] [-bs BATCH_SIZE] [-lr LEARNING_RATE] [-lrdr LR_DECAY_RATE]
                     [-lrde LR_DECAY_EPOCH] [-wd WEIGHT_DECAY] [-ki ITH_KFOLD] [-rc ROUND_COUNT] [-ma MASTER_ADDRESS] [-mp MASTER_PORT]
                     [-li LOG_ITER] [--nlaf] [--alpha ALPHA] [--beta BETA] [--gamma GAMMA] [--temp TEMP] [--use_not] [--save_best] [--skip]
                     [--estimated_grad] [--weighted] [--print_rule] [-s STRUCTURE]

optional arguments:
  -h, --help            show this help message and exit
  -d DATA_SET, --data_set DATA_SET
                        Set the data set for training. All the data sets in the dataset folder are available. (default: tic-tac-toe)
  -i DEVICE_IDS, --device_ids DEVICE_IDS
                        Set the device (GPU ids). Split by @. E.g., 0@2@3. (default: None)
  -nr NR, --nr NR       ranking within the nodes (default: 0)
  -e EPOCH, --epoch EPOCH
                        Set the total epoch. (default: 41)
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Set the batch size. (default: 64)
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        Set the initial learning rate. (default: 0.01)
  -lrdr LR_DECAY_RATE, --lr_decay_rate LR_DECAY_RATE
                        Set the learning rate decay rate. (default: 0.75)
  -lrde LR_DECAY_EPOCH, --lr_decay_epoch LR_DECAY_EPOCH
                        Set the learning rate decay epoch. (default: 10)
  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
                        Set the weight decay (L2 penalty). (default: 0.0)
  -ki ITH_KFOLD, --ith_kfold ITH_KFOLD
                        Do the i-th 5-fold validation, 0 <= ki < 5. (default: 0)
  -rc ROUND_COUNT, --round_count ROUND_COUNT
                        Count the round of experiments. (default: 0)
  -ma MASTER_ADDRESS, --master_address MASTER_ADDRESS
                        Set the master address. (default: 127.0.0.1)
  -mp MASTER_PORT, --master_port MASTER_PORT
                        Set the master port. (default: 0)
  -li LOG_ITER, --log_iter LOG_ITER
                        The number of iterations (batches) to log once. (default: 500)
  --nlaf                Use novel logical activation functions to take less time and GPU memory usage. We recommend trying (alpha, beta, gamma) in {(0.999, 8, 1), (0.999, 8, 3), (0.9, 3, 3)} (default: False)
  --alpha ALPHA         Set the alpha for NLAF. (default: 0.999)
  --beta BETA           Set the beta for NLAF. (default: 8)
  --gamma GAMMA         Set the gamma for NLAF. (default: 1)
  --temp TEMP           Set the temperature. (default: 1.0)
  --use_not             Use the NOT (~) operator in logical rules. It will enhance model capability but make the RRL more complex. (default: False)
  --save_best           Save the model with best performance on the validation set. (default: False)
  --skip                Use skip connections when the number of logical layers is greater than 2. (default: False)
  --estimated_grad      Use estimated gradient. (default: False)
  --weighted            Use weighted loss for imbalanced data. (default: False)
  --print_rule          Print the rules. (default: False)
  -s STRUCTURE, --structure STRUCTURE
                        Set the number of nodes in the binarization layer and logical layers. E.g., 10@64, 10@64@32@16. (default: 5@64)
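
Both -s and -i take values joined by "@". As a minimal sketch of how such strings can be read, following the help text above (first number for the binarization layer, remaining numbers for the logical layers), the two functions below split them into integers. `parse_structure` and `parse_device_ids` are hypothetical helpers for illustration; the repository may parse these arguments differently.

```python
def parse_structure(structure):
    """Split e.g. "10@64@32@16" into (binarization nodes, logical widths)."""
    parts = [int(part) for part in structure.split("@")]
    if len(parts) < 2:
        raise ValueError("expected at least one '@', e.g. '5@64'")
    return parts[0], parts[1:]


def parse_device_ids(device_ids):
    """Split e.g. "0@2@3" into a list of GPU ids; None means no ids given."""
    if device_ids is None:
        return []
    return [int(part) for part in device_ids.split("@")]


print(parse_structure("10@64@32@16"))
print(parse_device_ids("0@2@3"))
```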

Citation

If our work is helpful to you, please kindly cite our paper as:

@article{wang2021scalable,
  title={Scalable Rule-Based Representation Learning for Interpretable Classification},
  author={Wang, Zhuo and Zhang, Wei and Liu, Ning and Wang, Jianyong},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}
@article{wang2024learning,
  title={Learning Interpretable Rules for Scalable Data Representation and Classification},
  author={Wang, Zhuo and Zhang, Wei and Liu, Ning and Wang, Jianyong},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  volume={46},
  number={02},
  pages={1121--1133},
  year={2024},
  publisher={IEEE Computer Society}
}

License

MIT license

rrl's Issues

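A basic question

Could you explain what each field in rrl.txt means? I understand RID to be the rule ID, but what do its values mean? For the label columns, class_negative(b=-2.1733) | class_positive(b=1.9689), does the b in parentheses stand for a mean value? Is the value under each label the probability of that outcome? And what does Support mean?

A basic question about the experimental results

I would like to ask about the results in rrl.txt. My understanding is that the higher the support, the more credible the learned rule. How can I tell from the weight and bias whether a rule is more credible or more important? Also, what does an "activated" rule mean? Sorry for the beginner questions!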

Cannot run it on windows

Hi,

I was trying out this implementation after reading the paper. I installed all the dependencies in a Conda env on a Windows PC. However, I get the following error when I run the experiment:

$ python experiment.py -d tic-tac-toe -bs 32 -s 1@16 -e401 -lrde 200 -lr 0.002 -ki 0 -wd 0.0001 --print_rule -i 0
C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\distributed_c10d.py:608: UserWarning: Attempted 
to get default timeout for nccl backend, but NCCL support is not compiled
  warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled")
[W socket.cpp:697] [c10d] The client socket has failed to connect to [A2207000547.china.huawei.com]:47339 (system error: 10049 - The requested address is not valid in its context.).
Traceback (most recent call last):
  File "C:\Users\m00827298\Codes\RRL\experiment.py", line 174, in <module>
    train_main(rrl_args)
  File "C:\Users\m00827298\Codes\RRL\experiment.py", line 167, in train_main
    mp.spawn(train_model, nprocs=args.gpus, args=(args,))
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 241, in spawn       
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 197, in start_processes
    while not context.join():
              ^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 158, in join        
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\multiprocessing\spawn.py", line 68, in _wrap        
    fn(i, *args)
  File "C:\Users\m00827298\Codes\RRL\experiment.py", line 57, in train_model
    dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\c10d_logger.py", line 86, in wrapper    
    func_return = func(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\distributed_c10d.py", line 1177, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\rendezvous.py", line 246, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\m00827298\AppData\Local\miniconda3\envs\rrl\Lib\site-packages\torch\distributed\rendezvous.py", line 174, in _create_c10d_store
    return TCPStore(
           ^^^^^^^^^
torch.distributed.DistNetworkError: Unknown error

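A question about running efficiency

I can run your test data, but execution seems rather slow. Do you know what the problem might be and how to troubleshoot it?

  • On the tic-tac-toe data set (998 records), one epoch seems to take about 5 minutes.
  • The GPU does not appear to be used at all: utilization stays at 0, although GPU memory is occupied.
  • The loss appears to decrease normally, but training seems far too slow.

Is this speed normal, or could it be caused by an environment problem or a misconfiguration?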

Can RRL be used for regression?

Hi authors,

Thanks for this great work! I think it is very helpful for data analysis.
I wonder if it can be used for regression and how to perform that?

Best,

Mx

How to calculate the weights of the rule and class?

I'm a little confused by the output. My rrl.txt is as follows:

RID final_status_0(b=0.3015) final_status_1(b=-0.0915) Support Rule
(-1, 0) 1.6603 -1.6603 0.6216 a_1 & c > 0.002
(-1, 3) -1.4654 1.4654 0.0541 a_3 & b <= 2.324 & d <= 0.013
(-1, 2) -0.6511 0.6511 0.3357 b > 2.324 & d <= 0.013
(-1, 4) 0.6015 -0.6015 0.0356 a_3 & b > 2.324
(-1, 1) 0.5441 -0.5441 0.9325 c > 0.002
(-1, 5) -0.2566 0.2566 0.7992 d <= 0.013
############################################################

Why are the absolute values of the final_status_* columns the same? How should I calculate the weight of each rule under different classes?

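Parameter question

Hello, I am a senior undergraduate majoring in computer science and am reproducing this experiment. I could not get the expected results on the other data sets. Could you please share the parameter configurations for the other data sets? Thank you very much!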

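A few small questions that I hope you can answer

I have recently been studying this paper and its source code, and I have a few questions that I hope you can answer.
(1) I do not quite understand the purpose of the left and right parameters in the RRL class constructor.
[image]

(2) Also in the RRL constructor, the use_not parameter allows the extracted rule set to contain "~" rules, but the paper does not seem to mention this detail.

(3) Starting from the fourth layer, the input of each layer includes the outputs of the previous two layers. This detail appears in the figure in the paper, but there seems to be no text describing it. What is the reason for combining the inputs this way?

(4) The estimated_grad parameter selects the activation function used for the outputs of conjunction_layer and disjunction_layer. I found that EstimatedProduct and Product differ only in their backward() functions. Does this difference correspond to the following sentence in the paper? In its backward pass, EstimatedProduct wraps a custom activation function around the derivative. [image]
[image]

(5) mllp is the continuous-weight version, and rrl is the discrete version obtained by discretizing the mllp weights. During backpropagation, the first step computes the derivative of the rrl loss with respect to the rrl y_pred, and then the usual derivative of the mllp y_pred with respect to the parameters is taken, which implements gradient grafting:
[image]
In the code, [image], is the argument passed to backward the derivative of the rrl loss with respect to the rrl y_pred? And how is that derivative derived?

(6) mllp is the continuous-weight version of the model (weights lie in [0, 1]; used for training), and rrl is the binary-weight version obtained by discretizing the mllp weights (weights in {0, 1}, with 0.5 as the threshold; used for training, testing, and rule extraction). The discrete version takes only a small part in backpropagation, and the parameters are mostly adjusted according to the continuous mllp. However, in actual training I found that the mllp loss is hard to converge and is much higher than the rrl loss. Intuitively, the continuous mllp should perform better than the discretized rrl, so I do not understand this.

I am not sure whether I have understood the paper correctly, and my description may not be very clear. I hope you can find time to answer these questions. Thank you very much.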
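
The relationship between the continuous and discrete weights described in question (6) can be sketched in plain Python. This is an illustration of the idea only, not the repository's implementation: the function names are hypothetical, the 0.5 threshold is the one stated in the question, and whether the comparison is strict (>) or not (>=) is an assumption.

```python
def binarize(weights, threshold=0.5):
    """Map continuous weights in [0, 1] to binary weights in {0, 1}."""
    return [1 if w > threshold else 0 for w in weights]


def conjunction(inputs, binary_weights):
    """AND over the inputs selected by a weight of 1 (1 if none selected)."""
    selected = [x for x, w in zip(inputs, binary_weights) if w == 1]
    return int(all(x == 1 for x in selected))


def disjunction(inputs, binary_weights):
    """OR over the inputs selected by a weight of 1 (0 if none selected)."""
    selected = [x for x, w in zip(inputs, binary_weights) if w == 1]
    return int(any(x == 1 for x in selected))


continuous = [0.91, 0.12, 0.67, 0.49]   # hypothetical continuous weights
binary = binarize(continuous)           # selects the first and third inputs
features = [1, 0, 1, 1]                 # hypothetical binary input features

print(binary)
print(conjunction(features, binary))
print(disjunction(features, binary))
```

With binary weights, each node reads directly as a rule over the selected inputs, which is why the discrete model is the one used for rule extraction.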

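Hello, a question about reproducing the experiments

Hello, I am very interested in this work and have recently been reproducing the experiments. I have a few questions:

1) In the Parameter Settings section of Appendix C, the paper states:

The number of nodes in logical layers ranges from 16 to 4096 depending on the number of binary features of the data set and the model complexity we need.
How do you determine the number of nodes in the logical layers from the number of binary features? In other words, I find this parameter hard to tune because its range is quite large. Could you share some experience?

2) Section 4.3 of the paper shows the relationship between model complexity and performance, where model complexity is measured by log(#edges). I could not find the code that counts the edges. I may have overlooked it; could you please point it out? Thanks!

3) Also, could you share your parameter settings for the chess and bank-marketing data sets? When training on the chess data set, I tried many parameter combinations but the model still failed to converge. On bank-marketing, my results are close to those reported in the paper, but the learned rules differ greatly from those shown in Figure 4(b) (the rules I learned do not contain balance at all).

Thank you for your time :)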
