
club's Introduction

Hi there 👋

I am Pengyu Cheng, a researcher in NLP and ML. Here are some facts about me:

  • I am currently at Tencent AI Lab, primarily working on LLM training, AI agents, and dialogue systems.
  • I have experience with research and projects on controllable generation, interpretability, and fairness in NLP.
  • I am also interested in probabilistic and information-theoretic machine learning methods.
  • I received my Ph.D. degree from Duke University in 2021, advised by Dr. Lawrence Carin.
  • I graduated from the Department of Mathematical Sciences at Tsinghua University in 2017, advised by Dr. Jiwen Lu.

club's People

Contributors

jiachangliu, linear95


club's Issues

DA experiments : CIFAR and STL

Hi,

Thanks for the work and for releasing the code. I am currently trying to reproduce the DA experiments, and by reading the imageloader.py file, I noticed you are dropping samples from class 6 for both CIFAR10 and STL (unless I'm wrong). I was wondering whether there's a particular reason to do so? Thanks in advance!

EDIT: I hadn't noticed that the ordering of classes is different and that some classes only exist in a single dataset.
Best,
Malik

About parameters updating of CLUB-net and other modules

Hi, thanks for releasing this wonderful code! I have several questions on how the parameters are updated for the CLUB-net and the other modules during training. Based on your main_DANN.py, I think the update order is: (1) the CLUB-net parameters are updated first to maximize the log-likelihood of q(y|x); (2) then the content-encoder/classifier parameters are updated to minimize the content classification loss and the CLUB estimate of mutual information; (3) the domain-encoder/classifier parameters are updated to minimize the domain classification loss. My questions are:
(a) At step (1), will the content-encoder/classifier and domain-encoder/classifier parameters be updated?
(b) Will the CLUB-net parameters be updated at step (2)?
(c) Will the domain-encoder/classifier parameters be updated at step (2)?
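
A minimal sketch of how such an alternating scheme is typically wired up in PyTorch (an illustration with stand-in modules, not the repo's actual main_DANN.py): each optimizer only holds the parameters of its own module, and features passed to the other steps are detached, so only the intended module is updated at each step.

import torch
import torch.nn as nn

dim = 16
content_enc = nn.Linear(32, dim)   # stand-in content encoder
domain_enc  = nn.Linear(32, dim)   # stand-in domain encoder
classifier  = nn.Linear(dim, 10)   # stand-in content classifier

class CLUBNet(nn.Module):
    # variational net q(d|c): predicts mean / log-variance of the domain feature d from the content feature c
    def __init__(self, dim):
        super().__init__()
        self.mu = nn.Linear(dim, dim)
        self.logvar = nn.Linear(dim, dim)

    def loglikeli(self, c, d):
        mu, logvar = self.mu(c), self.logvar(c)
        return (-(mu - d) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0)

    def forward(self, c, d):  # CLUB upper-bound estimate of I(c; d)
        mu, logvar = self.mu(c), self.logvar(c)
        positive = -((mu - d) ** 2) / 2. / logvar.exp()
        negative = -((mu.unsqueeze(1) - d.unsqueeze(0)) ** 2) / 2. / logvar.exp().unsqueeze(1)
        return (positive.sum(-1) - negative.mean(dim=1).sum(-1)).mean()

club_net = CLUBNet(dim)
opt_club    = torch.optim.Adam(club_net.parameters(), lr=1e-4)
opt_content = torch.optim.Adam(list(content_enc.parameters()) + list(classifier.parameters()), lr=1e-4)

x = torch.randn(64, 32)
y = torch.randint(0, 10, (64,))

# step (1): only the CLUB-net parameters live in opt_club, and the features are
# detached, so neither encoder nor the classifiers receive updates here
c, d = content_enc(x).detach(), domain_enc(x).detach()
club_loss = -club_net.loglikeli(c, d)
opt_club.zero_grad(); club_loss.backward(); opt_club.step()

# step (2): only the content encoder / classifier are in opt_content; the
# CLUB-net still receives gradients from the MI term but is never stepped here
c = content_enc(x)
loss = nn.functional.cross_entropy(classifier(c), y) + club_net(c, domain_enc(x).detach())
opt_content.zero_grad(); loss.backward(); opt_content.step()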

Equation issue in mi_minimization.ipynb

Thanks for your great work in MI estimation!
I was trying to understand and run the mi_minimization.ipynb, but had some doubts about the given equation right at the beginning.
By matrix multiplication, with S = [equation image], shouldn't [equation image] produce
[[1, 0],
 [0, 1]],
which is non-parametric? How exactly is [equation image] being produced by this?
I might have misunderstood this equation.
Could you provide any clues?

confused about logvar

Hello, in your source code, when computing the log-likelihood, I think the likelihood to be maximized should be
-(mu - y_samples)**2 / logvar.exp()
but what you actually use is
(-(mu - y_samples)**2 / logvar.exp() - logvar)
May I ask the reason for the extra logvar term? Since logvar is very likely to be negative during training, the whole log-likelihood keeps staying negative.
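
For reference, the log-density of a Gaussian with mean $\mu$ and log-variance $\mathrm{logvar} = \log\sigma^2$ is
$$\log \mathcal{N}(y;\mu,\sigma^2) = -\frac{(y-\mu)^2}{2\sigma^2} - \frac{1}{2}\log\sigma^2 - \frac{1}{2}\log 2\pi,$$
so when the likelihood is maximized over both $\mu$ and $\sigma^2$, a $-\mathrm{logvar}$ term appears alongside the squared-error term; the code presumably drops the $1/2$ factors and the additive constant.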

code requests

Hi, Pengyu~
I'm interested in mutual information minimization and its application. Thanks a lot for your excellent work in this area.
As for the application of mutual information minimization, I came across your ACL 2020 work 'Improving Disentangled Text Representation Learning with Information-Theoretic Guidance' and found it really fascinating!
Could you please share the code with me so that I can better understand the implementation?

Thank you!

Hi, thanks for the good work. I have a general question: according to your code, the positive term in the PyTorch version subtracts a logvar term, but the TensorFlow version does not. Are there any tips regarding the difference between the two versions? I also encounter a problem in MI minimization: the MI in the earlier training epochs is always < 0. Is this reasonable, and are there any tips to solve it?

Originally posted by @bonehan in #12 (comment)

can I use it for pytorch loss function?

Hi,

output1 = model1(input)
output2 = model2(input)

I want to calculate the MI loss between output1 and output2.
Can I use your toolbox to achieve this?
If yes, could you please provide an example code snippet?
Thanks!
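
Not an official answer, but a sketch of how this is usually wired up with the CLUB estimator; the module name mi_estimators and the constructor signature CLUB(x_dim, y_dim, hidden_size) are assumptions about this repo, and model1/model2 are stand-ins:

import torch
import torch.nn as nn
from mi_estimators import CLUB   # assumed: CLUB(x_dim, y_dim, hidden_size)

model1 = nn.Linear(32, 16)       # stand-ins for the asker's model1 / model2
model2 = nn.Linear(32, 16)
inputs = torch.randn(64, 32)

club = CLUB(16, 16, 64)
club_opt = torch.optim.Adam(club.parameters(), lr=1e-4)
main_opt = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=1e-4)

for step in range(100):
    output1, output2 = model1(inputs), model2(inputs)

    # (a) fit the variational net q(output2 | output1) on detached features
    club_opt.zero_grad()
    club.learning_loss(output1.detach(), output2.detach()).backward()
    club_opt.step()

    # (b) use the CLUB estimate as the MI loss to update the two models
    mi_loss = club(output1, output2)
    main_opt.zero_grad()
    mi_loss.backward()
    main_opt.step()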

The symmetry problem of the CLUB MI estimator?

Thanks for your sharing. Your work is very interesting. The standard MI is symmetric, but the CLUB MI estimator is not. Have you tried building two variational estimators between x and y, one estimating the x -> y direction and one the y -> x direction? Thanks again.
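
Purely as an illustration of the idea in this question (not something the paper proposes): two estimators in opposite conditioning directions could be averaged. The import below makes the same repo-API assumption as the sketch above.

import torch
from mi_estimators import CLUB   # assumed: CLUB(x_dim, y_dim, hidden_size)

club_xy = CLUB(64, 64, 128)      # variational q(y|x), the x -> y direction
club_yx = CLUB(64, 64, 128)      # variational q(x|y), the y -> x direction

x, y = torch.randn(128, 64), torch.randn(128, 64)
mi_symmetric = 0.5 * (club_xy(x, y) + club_yx(y, x))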

The log-likelihood training loss when (x, y) are completely independent

As I understand it, in Section 3.2 we need to learn a variational distribution $q_\theta(y|x)$ to approximate $p(y|x)$. If we can learn this distribution well enough, the MI estimator will be good. However, when (x, y) are completely independent, we cannot do this, which means the loss when training the log-likelihood will explode. Did I understand it right?

Questions about the bound

Dear author,
Thanks for your great work and implementation. I have a network architecture similar to the MI Minimization in Domain Adaptation setting in the paper. When I use CLUB to minimize the mutual information between content features and domain features, the bound decreases during the first few epochs. However, at some epoch it drops below 0 and then becomes larger and larger in the following epochs. Have you encountered such a situation? Do you have any advice for this? Thank you very much!

Understanding question - which value of the estimator to take while evaluating?

First, thanks for this great work and implementation; I want to use it in my own work.
I have a basic question about the implementation:
assume I have fixed embeddings (size 512) with many samples (about 2 million).

I saw in the examples that the value of the MI changes throughout the optimization; moreover, the values have high variance but extremely good MSE.

As I understand it, I will use all 2 million samples to train the CLUB estimator. When is the best time to take the evaluation of the MI? Is it best to monitor the loss until it stops changing, or some other measure? What is your suggestion? And then, what portion of the 2 million examples would you use for the evaluation of the true MI? All of them? And then take the MSE over all the examples?

Second question, regarding the architecture of the hidden layer and the network: any suggestions for the case where I have two variables of 512 dimensions each?

The last question is about the robustness of the optimizer. Let's assume the two vectors change over time, as they are optimized for a different task, and I want to measure the MI again after they change. Would you re-initialize the optimizer when measuring the MI of the modified vectors, or reuse the last optimizer that was trained?
thanks!
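
One possible recipe (a sketch of one reasonable choice, not the authors' prescription, again assuming the repo's CLUB class): fit the estimator while watching its learning loss on a held-out split, and once that loss plateaus, average the CLUB estimate over the full data in mini-batches.

import torch
from mi_estimators import CLUB   # assumed: CLUB(x_dim, y_dim, hidden_size)

x = torch.randn(100_000, 512)    # placeholders for the fixed 512-d embeddings
y = torch.randn(100_000, 512)
train_x, val_x = x[:90_000], x[90_000:]
train_y, val_y = y[:90_000], y[90_000:]

club = CLUB(512, 512, 1024)
opt = torch.optim.Adam(club.parameters(), lr=1e-4)

def estimate_mi(estimator, x, y, batch_size=4096):
    # average the CLUB estimate over all samples in mini-batches
    estimator.eval()
    vals = []
    with torch.no_grad():
        for i in range(0, x.shape[0], batch_size):
            vals.append(estimator(x[i:i + batch_size], y[i:i + batch_size]))
    return torch.stack(vals).mean()

for epoch in range(20):
    club.train()
    perm = torch.randperm(train_x.shape[0])
    for i in range(0, train_x.shape[0], 4096):
        idx = perm[i:i + 4096]
        opt.zero_grad()
        club.learning_loss(train_x[idx], train_y[idx]).backward()
        opt.step()
    with torch.no_grad():
        print(epoch, club.learning_loss(val_x, val_y).item())  # stop when this stops improving

print(estimate_mi(club, x, y).item())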

About the computation of loglikeli

Hello,
I found that the computation of loglikeli differs from Appendix D (Implementation Details) of the paper. It would be the same if the -logvar in the code below were removed.

return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)

And the loglikeli will always be a large positive value if I use the code with -logvar.
I wonder whether I should remove the -logvar. Thanks!

The reliability of this method.

I just modified the dimension of the samples to 10, and the results were not good in this situation.
[screenshot of estimation results]
So does CLUB need careful hyperparameter tuning?
I need to estimate the mutual information between a 1-dimensional c and a 10-dimensional z, and I found that the MI estimated by CLUB has a large bias. Is this due to some restriction on when CLUB can be used, or do the parameters need to be tuned for each case? Any suggestions on estimating I(c_i; z) would also be appreciated.

Comparison between InfoNCE and CLUB in the experiment:
[image]

About the logvar prediction

Thank you for your excellent code! I have encountered some problems when using the mutual information constraint in a speech processing task. During training, I found that the logvar prediction network, whose last layer is a Tanh, always outputs -1 no matter what the input is. The overall mutual information prediction network also seems to lose its effect, as the log-likelihood of the positive samples in the training batch is always a very small value, something like -1,000,000. Have other users met this problem before? Or do you have any advice? Thank you a lot!

Yours,
Daxin

Target accuracy only reaches 0.79 for the domain adaptation experiment from MNIST to MNIST-M.

Thanks for your paper and code, which are really enlightening and helpful. We tested the MI_DA code on MNIST to MNIST-M; however, the mi_loss did not decrease at all, and the best accuracy on the target only reached 0.79. Since iteration 7700, the accuracy on the target has not increased any more.
We used the following command:
python main_DANN.py --data_path /path/to/data_folder/ --save_path /path/to/save_dir/ --source mnist --target mnistn
and also with mnistm. Could you please give some more tips to help?
Some of the log follows:
mnist-mnistm_DANN_0.1 iter 5 mi_loss: -0.9424 p_loss:2.5805 p_acc:0.1406 d_loss: 0.1033 d_acc: 0.9922
mnist-mnistm_DANN_0.1 iter 10 mi_loss: -1.4276 p_loss:1.7437 p_acc:0.4688 d_loss: 0.0160 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 15 mi_loss: -0.4924 p_loss:1.1728 p_acc:0.6250 d_loss: 0.0263 d_acc: 0.9922
mnist-mnistm_DANN_0.1 iter 20 mi_loss: -0.4116 p_loss:0.8434 p_acc:0.7969 d_loss: 0.0077 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 25 mi_loss: -0.4137 p_loss:0.6652 p_acc:0.8281 d_loss: 0.0045 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 30 mi_loss: 0.1388 p_loss:0.4122 p_acc:0.9219 d_loss: 0.0025 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 35 mi_loss: -0.3266 p_loss:0.6541 p_acc:0.8438 d_loss: 0.0019 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 40 mi_loss: -0.3420 p_loss:0.3561 p_acc:0.9375 d_loss: 0.0206 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 45 mi_loss: 0.2164 p_loss:0.4421 p_acc:0.8750 d_loss: 0.0063 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 50 mi_loss: 0.1182 p_loss:0.3425 p_acc:0.9062 d_loss: 0.0145 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 55 mi_loss: 0.2110 p_loss:0.2694 p_acc:0.9531 d_loss: 0.0056 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 60 mi_loss: 0.0381 p_loss:0.1792 p_acc:0.9688 d_loss: 0.0045 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 65 mi_loss: 0.0803 p_loss:0.2354 p_acc:0.9219 d_loss: 0.0581 d_acc: 0.9922
mnist-mnistm_DANN_0.1 iter 70 mi_loss: 0.0865 p_loss:0.3442 p_acc:0.8594 d_loss: 0.0034 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 75 mi_loss: 0.0698 p_loss:0.2371 p_acc:0.9062 d_loss: 0.0033 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 80 mi_loss: 0.0388 p_loss:0.1283 p_acc:0.9844 d_loss: 0.0020 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 85 mi_loss: 0.1110 p_loss:0.2033 p_acc:0.9375 d_loss: 0.0011 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 90 mi_loss: 0.3213 p_loss:0.1764 p_acc:0.9688 d_loss: 0.0025 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 95 mi_loss: -0.1905 p_loss:0.2329 p_acc:0.9375 d_loss: 0.0010 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 100 mi_loss: -0.1830 p_loss:0.2695 p_acc:0.9219 d_loss: 0.0007 d_acc: 1.0000
src valid: 0.9604 tgt valid: 0.4618 tgt test: 0.4432 best: 0.4432
...
mnist-mnistm_DANN_0.1 iter 1005 mi_loss: -0.0052 p_loss:0.0333 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1010 mi_loss: 0.1306 p_loss:0.0836 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1015 mi_loss: 0.2094 p_loss:0.0845 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1020 mi_loss: 0.2839 p_loss:0.0486 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1025 mi_loss: 0.1201 p_loss:0.0662 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1030 mi_loss: 0.2046 p_loss:0.1356 p_acc:0.9531 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1035 mi_loss: -0.1627 p_loss:0.0418 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1040 mi_loss: 0.1611 p_loss:0.0498 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1045 mi_loss: 0.1236 p_loss:0.0359 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1050 mi_loss: 0.1721 p_loss:0.0430 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1055 mi_loss: 0.3918 p_loss:0.1005 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1060 mi_loss: -0.0353 p_loss:0.1065 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1065 mi_loss: 0.1441 p_loss:0.1181 p_acc:0.9531 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1070 mi_loss: -0.5737 p_loss:0.1507 p_acc:0.9531 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1075 mi_loss: 0.2092 p_loss:0.0518 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1080 mi_loss: 0.2084 p_loss:0.1272 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1085 mi_loss: -0.1327 p_loss:0.0545 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1090 mi_loss: -0.1265 p_loss:0.0931 p_acc:0.9531 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1095 mi_loss: 0.1478 p_loss:0.0519 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 1100 mi_loss: -0.1057 p_loss:0.0659 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
src valid: 0.9864 tgt valid: 0.6650 tgt test: 0.6611 best: 0.6725
...
mnist-mnistm_DANN_0.1 iter 2005 mi_loss: 0.0830 p_loss:0.0511 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2010 mi_loss: 0.1078 p_loss:0.0347 p_acc:0.9844 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2015 mi_loss: 0.0853 p_loss:0.0646 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2020 mi_loss: 0.0153 p_loss:0.0474 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2025 mi_loss: -0.0367 p_loss:0.1069 p_acc:0.9531 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2030 mi_loss: -0.0450 p_loss:0.0501 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2035 mi_loss: 0.1163 p_loss:0.0121 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2040 mi_loss: -0.0938 p_loss:0.0082 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2045 mi_loss: 0.1360 p_loss:0.0125 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2050 mi_loss: 0.2207 p_loss:0.0161 p_acc:1.0000 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2055 mi_loss: -0.0965 p_loss:0.0149 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2060 mi_loss: 0.0843 p_loss:0.0110 p_acc:1.0000 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2065 mi_loss: -0.0486 p_loss:0.0794 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2070 mi_loss: 0.0611 p_loss:0.0556 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2075 mi_loss: 0.0393 p_loss:0.2532 p_acc:0.9219 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2080 mi_loss: 0.1374 p_loss:0.0399 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2085 mi_loss: 0.1120 p_loss:0.0161 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2090 mi_loss: 0.1085 p_loss:0.0277 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2095 mi_loss: -0.0497 p_loss:0.0451 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2100 mi_loss: 0.1559 p_loss:0.0184 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
src valid: 0.9908 tgt valid: 0.7214 tgt test: 0.7189 best: 0.7284
...
mnist-mnistm_DANN_0.1 iter 2905 mi_loss: 0.1074 p_loss:0.0956 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2910 mi_loss: 0.1063 p_loss:0.0166 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2915 mi_loss: 0.0295 p_loss:0.0335 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2920 mi_loss: -0.0136 p_loss:0.0238 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2925 mi_loss: -0.0305 p_loss:0.0523 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2930 mi_loss: -0.1295 p_loss:0.0290 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2935 mi_loss: -0.0361 p_loss:0.0067 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2940 mi_loss: 0.0729 p_loss:0.0424 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2945 mi_loss: 0.0316 p_loss:0.0512 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2950 mi_loss: -0.0648 p_loss:0.0263 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2955 mi_loss: -0.1093 p_loss:0.0278 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2960 mi_loss: 0.0914 p_loss:0.0800 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2965 mi_loss: 0.0647 p_loss:0.0458 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2970 mi_loss: 0.0216 p_loss:0.0295 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2975 mi_loss: 0.0742 p_loss:0.0486 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2980 mi_loss: -0.0668 p_loss:0.0755 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2985 mi_loss: -0.0981 p_loss:0.0309 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2990 mi_loss: 0.1510 p_loss:0.1314 p_acc:0.9531 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 2995 mi_loss: 0.0648 p_loss:0.0318 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 3000 mi_loss: -0.0805 p_loss:0.0105 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
src valid: 0.9904 tgt valid: 0.7352 tgt test: 0.7315 best: 0.7392
...
mnist-mnistm_DANN_0.1 iter 5005 mi_loss: 0.0365 p_loss:0.0057 p_acc:1.0000 d_loss: 0.0020 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5010 mi_loss: 0.1144 p_loss:0.0103 p_acc:1.0000 d_loss: 0.0007 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5015 mi_loss: 0.0933 p_loss:0.0057 p_acc:1.0000 d_loss: 0.0024 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5020 mi_loss: 0.1498 p_loss:0.0209 p_acc:0.9844 d_loss: 0.0005 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5025 mi_loss: -0.0083 p_loss:0.0251 p_acc:0.9844 d_loss: 0.0022 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5030 mi_loss: 0.0578 p_loss:0.0156 p_acc:1.0000 d_loss: 0.0226 d_acc: 0.9922
mnist-mnistm_DANN_0.1 iter 5035 mi_loss: -0.0284 p_loss:0.0148 p_acc:1.0000 d_loss: 0.0006 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5040 mi_loss: 0.0945 p_loss:0.0548 p_acc:0.9844 d_loss: 0.0100 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5045 mi_loss: 0.1155 p_loss:0.0210 p_acc:0.9844 d_loss: 0.0007 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5050 mi_loss: -0.0337 p_loss:0.0139 p_acc:1.0000 d_loss: 0.0006 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5055 mi_loss: 0.0090 p_loss:0.0299 p_acc:1.0000 d_loss: 0.0008 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5060 mi_loss: 0.1848 p_loss:0.0505 p_acc:0.9844 d_loss: 0.0007 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5065 mi_loss: -0.0112 p_loss:0.0187 p_acc:0.9844 d_loss: 0.0618 d_acc: 0.9844
mnist-mnistm_DANN_0.1 iter 5070 mi_loss: 0.0610 p_loss:0.0212 p_acc:1.0000 d_loss: 0.0148 d_acc: 0.9922
mnist-mnistm_DANN_0.1 iter 5075 mi_loss: 0.0432 p_loss:0.0338 p_acc:0.9844 d_loss: 0.0022 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5080 mi_loss: -0.0142 p_loss:0.0224 p_acc:0.9844 d_loss: 0.0061 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5085 mi_loss: 0.0673 p_loss:0.0254 p_acc:0.9844 d_loss: 0.0012 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5090 mi_loss: -0.0667 p_loss:0.0071 p_acc:1.0000 d_loss: 0.0006 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5095 mi_loss: 0.0387 p_loss:0.0437 p_acc:0.9844 d_loss: 0.0006 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 5100 mi_loss: -0.0622 p_loss:0.0460 p_acc:0.9844 d_loss: 0.0056 d_acc: 1.0000
src valid: 0.9916 tgt valid: 0.7566 tgt test: 0.7519 best: 0.7639
...
mnist-mnistm_DANN_0.1 iter 7605 mi_loss: 0.1248 p_loss:0.1195 p_acc:0.9688 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7610 mi_loss: -0.0163 p_loss:0.0107 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7615 mi_loss: 0.0306 p_loss:0.0046 p_acc:1.0000 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7620 mi_loss: 0.0334 p_loss:0.0022 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7625 mi_loss: 0.0361 p_loss:0.0202 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7630 mi_loss: 0.1049 p_loss:0.0154 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7635 mi_loss: 0.0222 p_loss:0.0645 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7640 mi_loss: 0.0745 p_loss:0.0103 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7645 mi_loss: -0.0437 p_loss:0.0088 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7650 mi_loss: 0.0109 p_loss:0.0111 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7655 mi_loss: 0.0242 p_loss:0.0113 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7660 mi_loss: 0.0371 p_loss:0.0839 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7665 mi_loss: 0.0469 p_loss:0.0180 p_acc:1.0000 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7670 mi_loss: 0.0002 p_loss:0.0495 p_acc:0.9844 d_loss: 0.0003 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7675 mi_loss: 0.0007 p_loss:0.0142 p_acc:1.0000 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7680 mi_loss: -0.0140 p_loss:0.0402 p_acc:0.9688 d_loss: 0.0002 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7685 mi_loss: 0.0904 p_loss:0.0482 p_acc:0.9688 d_loss: 0.0016 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7690 mi_loss: 0.0314 p_loss:0.0375 p_acc:0.9844 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7695 mi_loss: 0.1335 p_loss:0.0131 p_acc:1.0000 d_loss: 0.0001 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 7700 mi_loss: -0.0124 p_loss:0.0778 p_acc:0.9844 d_loss: 0.0003 d_acc: 1.0000
src valid: 0.9940 tgt valid: 0.7930 tgt test: 0.7907 best: 0.7907
...
mnist-mnistm_DANN_0.1 iter 10105 mi_loss: 0.0118 p_loss:0.0226 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10110 mi_loss: 0.0434 p_loss:0.0452 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10115 mi_loss: 0.0104 p_loss:0.0040 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10120 mi_loss: 0.0316 p_loss:0.0234 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10125 mi_loss: 0.1281 p_loss:0.0341 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10130 mi_loss: -0.0660 p_loss:0.0087 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10135 mi_loss: -0.0071 p_loss:0.0055 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10140 mi_loss: -0.0431 p_loss:0.0118 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10145 mi_loss: -0.0692 p_loss:0.0028 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10150 mi_loss: 0.0551 p_loss:0.0390 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10155 mi_loss: 0.0370 p_loss:0.0023 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10160 mi_loss: 0.1024 p_loss:0.0493 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10165 mi_loss: -0.0352 p_loss:0.0212 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10170 mi_loss: -0.0365 p_loss:0.0044 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10175 mi_loss: 0.0778 p_loss:0.0028 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10180 mi_loss: 0.0429 p_loss:0.0024 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10185 mi_loss: -0.0403 p_loss:0.1771 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10190 mi_loss: -0.0852 p_loss:0.0144 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10195 mi_loss: -0.0129 p_loss:0.0045 p_acc:1.0000 d_loss: 0.0000 d_acc: 1.0000
mnist-mnistm_DANN_0.1 iter 10200 mi_loss: -0.0011 p_loss:0.0461 p_acc:0.9844 d_loss: 0.0000 d_acc: 1.0000
src valid: 0.9950 tgt valid: 0.7780 tgt test: 0.7763 best: 0.7907
...

About the mutual information between two tensors

Hello! The paper states that CLUB is an upper bound on mutual information. If I want to maximize the mutual information between two tensors, can I directly maximize the CLUB estimate,
e.g., max CLUB(x1, x2)?

Or is this upper bound only suitable for minimizing the mutual information between two tensors, i.e., minimizing the upper bound minimizes the mutual information between them?
Is that the right way to understand it?
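
For reference, the variational CLUB estimator defined in the paper is
$$\hat I_{\mathrm{vCLUB}}(x;y) = \mathbb{E}_{p(x,y)}\big[\log q_\theta(y|x)\big] - \mathbb{E}_{p(x)}\mathbb{E}_{p(y)}\big[\log q_\theta(y|x)\big],$$
which upper-bounds $I(x;y)$ when $q_\theta(y|x)$ approximates $p(y|x)$ well enough. Minimizing this quantity therefore minimizes an upper bound on the mutual information, whereas maximizing an upper bound does not by itself guarantee that the true mutual information increases.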

Negative Mutual Information Values in Feature Decoupling with MI Minimization

Hello,

I've been exploring feature decoupling using mutual information (MI) minimization and came across your implementation in mi_minimization.ipynb. Inspired by your approach, I've adapted the code for my scenario, where both content and style tensors have dimensions N_sample x 64. I've also adjusted the hidden layer dimension to 128. The modification I made is as follows:

for j in range(5):
    mi_estimator.train()
    mi_loss = mi_estimator.learning_loss(content, style)
    mi_optimizer.zero_grad()
    mi_loss.backward()
    mi_optimizer.step()

However, I've encountered a situation where the calculated mutual information turns out to be negative. Could you help me understand the potential reasons for this negative mutual information value? Is there something I might be missing in the implementation or any specific aspect I should consider adjusting?

Thank you for your time and assistance.

Query about Example Training

Thanks for your nice work.
In the provided mi_minimization.ipynb, you optimize the MI estimator 5 times more often than the sampler. Did you try completely separating the training process, i.e., first fully optimizing the mi_estimator and then minimizing the MI between the random variables?
