
lingkai-kong / sde-net


Code for the paper "SDE-Net: Equipping Deep Neural Networks with Uncertainty Estimates"

License: Apache License 2.0

Python 100.00%
deep-learning open-world-classification robustness uncertainty-quantification

sde-net's People

Contributors

lingkai-kong


sde-net's Issues

Reproduction of the results

@Lingkai-Kong

I ran the Python commands from the repo and could not reproduce the results reported in the paper.

Of course, this was just a single run, but the values seem too low given the reported standard deviations.

MNIST RESNET
_________________________________

Final Accuracy: 9945/10000 (99.45%)

generate log  from out-of-distribution data
calculate metrics for OOD
OOD  Performance of Baseline detector
TNR at TPR 95%:            88.783%
AUROC:                     95.939%
Detection acc:             92.169%
AUPR In:                   86.441%
AUPR Out:                  98.434%

calculate metrics for mis
mis  Performance of Baseline detector
TNR at TPR 95%:            89.791%
AUROC:                     97.510%
Detection acc:             93.041%
AUPR In:                   99.985%
AUPR Out:                  34.000%


MNIST SDENET
_________________________________

Final Accuracy: 9927/10000 (99.27%)

generate log  from out-of-distribution data
calculate metrics for OOD
OOD  Performance of Baseline detector
TNR at TPR 95%:            99.372%
AUROC:                     99.804%
Detection acc:             98.692%
AUPR In:                   99.483%
AUPR Out:                  99.887%
calculate metrics for mis
mis  Performance of Baseline detector
TNR at TPR 95%:            92.544%
AUROC:                     97.525%
Detection acc:             94.485%
AUPR In:                   99.979%
AUPR Out:                  41.739%


SVHN RESNET
_________________________________

Final Accuracy: 24609/25856 (95.18%)

generate log  from out-of-distribution data
calculate metrics for OOD
OOD  Performance of Baseline detector
TNR at TPR 95%:            66.552%
AUROC:                     94.421%
Detection acc:             90.136%
AUPR In:                   97.639%
AUPR Out:                  84.998%
calculate metrics for mis
mis  Performance of Baseline detector
TNR at TPR 95%:            64.376%
AUROC:                     90.458%
Detection acc:             85.371%
AUPR In:                   99.301%
AUPR Out:                  44.899%


SVHN SDENET
_________________________________

Final Accuracy: 24588/25856 (95.10%)

generate log  from out-of-distribution data
calculate metrics for OOD
OOD  Performance of Baseline detector
TNR at TPR 95%:            65.215%
AUROC:                     94.308%
Detection acc:             89.746%
AUPR In:                   97.694%
AUPR Out:                  84.017%
calculate metrics for mis
mis  Performance of Baseline detector
TNR at TPR 95%:            67.831%
AUROC:                     91.267%
Detection acc:             86.501%
AUPR In:                   99.270%
AUPR Out:                  48.871%
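
For reference, two of the metrics in these logs can be sketched from the raw detector scores. This assumes in-distribution samples (or, for the 'mis' task, correctly classified ones) are the positive class and that a higher confidence score means "more in-distribution"; the function names are hypothetical, and calculate_log.py may differ in details such as how the TPR-95% threshold is located.

```python
import numpy as np

def tnr_at_tpr95(in_scores, out_scores):
    """TNR on OOD scores at the threshold that keeps 95% of in-distribution scores."""
    in_scores = np.asarray(in_scores, dtype=float)
    out_scores = np.asarray(out_scores, dtype=float)
    # 5th percentile: 95% of in-distribution scores lie at or above this threshold.
    threshold = np.percentile(in_scores, 5)
    # OOD samples below the threshold are correctly rejected (true negatives).
    return float(np.mean(out_scores < threshold))

def auroc(in_scores, out_scores):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    in_scores = np.asarray(in_scores, dtype=float)
    out_scores = np.asarray(out_scores, dtype=float)
    # Pairwise comparison matrix of shape (n_in, n_out): normalized Mann-Whitney U.
    greater = (in_scores[:, None] > out_scores[None, :]).mean()
    ties = (in_scores[:, None] == out_scores[None, :]).mean()
    return float(greater + 0.5 * ties)

print(auroc([0.9, 0.8, 0.95, 0.7], [0.1, 0.75, 0.2]))  # 11 of 12 pairs ordered correctly
```

The pairwise AUROC is quadratic in memory, which is fine for a quick check; for full test sets a sort-based implementation such as scikit-learn's roc_auc_score would be preferable.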


A serious formula problem

(1) Why does the diffusion term g(t, x_t) in the SDE become a constant g(x_0) in the algorithm and the code? The code also evaluates g(x_0) during the iteration. According to the Euler-Maruyama scheme described in your paper, the diffusion should not be g(x_0); g(x_k) should be evaluated inside the loop. At the same time, the theoretical analysis in formula (5) becomes meaningless, since only the drift term remains.
(2) Is there code for the last two plots in Figure 3?
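
For comparison, the standard Euler-Maruyama update evaluates the diffusion at the current state: x_{k+1} = x_k + f(x_k, t_k) Δt + g(x_k, t_k) ΔW_k, with ΔW_k ~ N(0, Δt). The toy sketch below exposes both variants behind a flag, so the difference the issue describes (g(x_k) inside the loop versus a fixed g(x_0)) is explicit. The drift f, diffusion g, and the function name are hypothetical stand-ins, not the repository's networks.

```python
import numpy as np

def euler_maruyama(x0, f, g, T=1.0, n_steps=10, diffusion_at_x0=False, rng=None):
    """Integrate dx = f(x, t) dt + sigma dW with the Euler-Maruyama scheme.

    If diffusion_at_x0 is True, sigma = g(x_0, 0) is held fixed for the whole
    trajectory; otherwise sigma = g(x_k, t_k) is re-evaluated at every step.
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = T / n_steps
    x = np.asarray(x0, dtype=float)
    g0 = g(x, 0.0)  # diffusion evaluated once at the initial state
    for k in range(n_steps):
        t = k * dt
        dw = rng.normal(0.0, np.sqrt(dt), size=x.shape)  # Brownian increment ~ N(0, dt)
        sigma = g0 if diffusion_at_x0 else g(x, t)       # g(x_0) vs g(x_k, t_k)
        x = x + f(x, t) * dt + sigma * dw
    return x
```

When g is constant the two variants coincide; they only diverge when g depends on the state, which is exactly the case the question is about.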

About the objective function for training

Hi Kong,

Thanks for sharing your code!

I am interested in the objective function (equation (4) in the paper).
The last term maximizes the diffusion for samples drawn from OOD data, but in the code the same criterion2 is used for both loss_in and loss_out: loss_in = criterion2(predict_in, label), loss_out = criterion2(predict_out, label).

As I understand it, both losses minimize the distance between the network output and the label, so I can't see where the max term is.
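
A single criterion can still implement both a minimization and a maximization term if the target label changes between the two calls. The sketch below assumes criterion2 is a binary cross-entropy on the diffusion net's output in [0, 1], and that the label is refilled with 0 for in-distribution and 1 for OOD inputs before loss_out is computed; both the label values and the refill step are assumptions to check against the training script, since the issue only quotes the two loss calls.

```python
import math

def bce(p, y, eps=1e-12):
    """Binary cross-entropy for one predicted probability p and target y in {0, 1}."""
    p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Assumed targets (hypothetical): low diffusion in-distribution, high diffusion for OOD.
IN_LABEL, OUT_LABEL = 0.0, 1.0

def diffusion_losses(p_in, p_out):
    """p_in / p_out: hypothetical diffusion-net outputs for an in-dist and an OOD sample."""
    loss_in = bce(p_in, IN_LABEL)     # = -log(1 - p_in): minimized by pushing p_in toward 0
    loss_out = bce(p_out, OUT_LABEL)  # = -log(p_out): minimized by pushing p_out toward 1
    return loss_in, loss_out
```

Under that assumption, minimizing loss_out = -log(p_out) is the same as maximizing the log of the diffusion output on OOD inputs, which would play the role of the max term.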

Public release of the code

Hi,
I really enjoyed reading the paper and I'm excited about the OOD detection capabilities SDE Networks offer. Do you intend to release the code here anytime soon?

About the regression task

Hi Lingkai, I have some questions about using this for a regression task.
1. There are three arguments I don't know how to set:
"target_scale", "self.delta" and "self.sigma". How should I set them for my own dataset?
2. About SDE-Net training and test detection:
When training_diffusion is False, SDENet outputs a "mean" and a "sigma". What do they represent? Are they the mean and standard deviation of the target? If so, the target is a single value, so how can it have a mean and a sigma?
3. How can I get the aleatoric uncertainty?
I've noticed that "test_detection_sde.py" writes out a file with the model uncertainty. If I want to write out the input uncertainty instead, do I just need to write out the total Mean value? Is this right?
Also, when training on my dataset some metrics are poor. I am using the network's default settings; is there any tuning I can do to improve its performance?
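
One common reading of a (mean, sigma) output, assumed here rather than confirmed from the repository, is that the network predicts a Gaussian N(mean, sigma²) over the single target value and is trained with the Gaussian negative log-likelihood, so sigma captures data (aleatoric) noise. With several stochastic forward passes per input, the law of total variance then splits the predictive variance into an aleatoric part (the average of sigma²) and an epistemic part (the variance of the means). The helper names are hypothetical, and this is a generic decomposition, not necessarily what test_detection_sde.py writes out.

```python
import numpy as np

def gaussian_nll(y, mean, sigma):
    """Per-sample negative log-likelihood of y under N(mean, sigma^2)."""
    var = np.asarray(sigma, dtype=float) ** 2
    return 0.5 * np.log(2 * np.pi * var) + (y - mean) ** 2 / (2 * var)

def decompose_uncertainty(means, sigmas):
    """Split predictive variance via the law of total variance.

    means, sigmas: arrays of shape (n_passes, n_points), one row per stochastic pass.
    Returns (aleatoric, epistemic, total), each of shape (n_points,).
    """
    means = np.asarray(means, dtype=float)
    sigmas = np.asarray(sigmas, dtype=float)
    aleatoric = np.mean(sigmas ** 2, axis=0)  # E[sigma^2]: noise the model expects in the data
    epistemic = np.var(means, axis=0)         # Var[mean]: disagreement between passes
    return aleatoric, epistemic, aleatoric + epistemic
```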

Keras implementation

Thanks for providing the PyTorch implementation.
I will start working on a TensorFlow/Keras one.

About the strategy

Hi Lingkai !

Thanks for your fantastic paper. I would like to ask a few questions, if possible.

(1) Which strategy does this network follow?

  • Strategy 1
    - Input image → DNN → output + total uncertainty (aleatoric + epistemic)
  • Strategy 2
    - Input image → DNN → output + aleatoric uncertainty + epistemic uncertainty

(2) Can this work for semantic segmentation tasks?
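
For classification, Strategy 2 is often realized with an entropy decomposition over several stochastic forward passes (for SDE-Net, different Brownian-motion samples would play that role, assuming the prediction is sampled more than once): total predictive entropy = expected per-pass entropy (aleatoric) + mutual information (epistemic). The sketch below shows this generic decomposition with hypothetical names; it is not a description of what the repository computes.

```python
import numpy as np

def entropy(p, axis=-1, eps=1e-12):
    """Shannon entropy (in nats) along the class axis."""
    p = np.asarray(p, dtype=float)
    return -np.sum(p * np.log(p + eps), axis=axis)

def decompose_classification(probs):
    """probs: softmax outputs of shape (n_passes, n_classes) for a single input.

    Returns (total, aleatoric, epistemic) with total = aleatoric + epistemic.
    """
    probs = np.asarray(probs, dtype=float)
    total = entropy(probs.mean(axis=0))         # entropy of the averaged prediction
    aleatoric = entropy(probs, axis=-1).mean()  # average entropy of individual passes
    epistemic = total - aleatoric               # mutual information between prediction and pass
    return total, aleatoric, epistemic
```

Passes that agree on a 50/50 prediction give high aleatoric and near-zero epistemic uncertainty; confident passes that disagree give the reverse.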

Which version of PyTorch did you use?

I get this error with PyTorch 1.7:

Epoch: 0
Traceback (most recent call last):
  File "sdenet_mnist.py", line 139, in <module>
    train(epoch)
  File "sdenet_mnist.py", line 99, in train
    label = torch.full((args.batch_size,1), real_label, device=device)
RuntimeError: Providing a bool or integral fill value without setting the optional `dtype` or `out` arguments is currently unsupported. In PyTorch 1.7, when `dtype` and `out` are not set a bool fill value will return a tensor of torch.bool dtype, and an integral fill value will return a tensor of torch.long dtype.
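
As the error message states, PyTorch 1.7 rejects an integral fill value unless dtype (or out) is given. A minimal fix is to pass dtype explicitly; float32 is assumed below because the label appears to be used as a loss target, and batch_size and device are hypothetical stand-ins for args.batch_size and the script's device.

```python
import torch

real_label = 0  # integer fill value, as in the traceback
batch_size = 4  # hypothetical; the script uses args.batch_size
device = "cpu"  # hypothetical; the script uses its own device

# An explicit dtype avoids the PyTorch 1.7 RuntimeError and yields
# floating-point targets, which BCE-style losses expect.
label = torch.full((batch_size, 1), real_label, dtype=torch.float32, device=device)
```

Alternatively, passing a float fill value such as float(real_label) lets torch.full infer the default floating-point dtype.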

How to let the model say 'I don't know'?

Hi Lingkai

Thanks for sharing the great code for the fantastic paper.
I would like to ask a few questions, if possible.

  • 1. How can we let the model say 'I don't know'?

For OOD samples, it is wiser to let the model say ‘I don’t know’ than to make an absurdly wrong prediction.

How can we do this with a trained SDE-Net, i.e., flag uncertain examples, especially those from unseen classes?
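
A common way to make a classifier abstain is selective prediction: compute an uncertainty score per input, calibrate a threshold on held-out in-distribution data so that a chosen fraction of it is accepted, and answer "I don't know" above the threshold. The sketch assumes some per-input uncertainty score is available (for SDE-Net this might be the diffusion output or the spread of repeated predictions, which is an assumption); the 95% keep fraction and the function names are hypothetical.

```python
import numpy as np

def fit_threshold(val_uncertainty, keep_fraction=0.95):
    """Uncertainty threshold that accepts `keep_fraction` of in-distribution validation data."""
    return float(np.quantile(np.asarray(val_uncertainty, dtype=float), keep_fraction))

def predict_or_abstain(class_probs, uncertainty, threshold):
    """Return the argmax class per input, or None ('I don't know') above the threshold."""
    preds = []
    for p, u in zip(class_probs, uncertainty):
        preds.append(None if u > threshold else int(np.argmax(p)))
    return preds
```

The threshold is fitted only on in-distribution data, so no OOD examples are needed; the keep fraction trades coverage against how often the model abstains.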

  • 2. What does 'detection accuracy' mean for the OOD task in 'calculate_log.py'?

For the 'mis' task, the detection error is easy to understand.
For the 'OOD' task, a trained SDE-Net outputs softmax values (without ground-truth information) for both in-domain and OOD examples.
I'm confused about the detection accuracy for the OOD task.
Is there a mathematical expression or description of the OOD detection accuracy?

import numpy as np

# calculate the minimum detection error
if task == 'OOD':
    cifar = np.loadtxt('%s/confidence_Base_In.txt'%dir_name, delimiter=',')   # in-distribution confidences
    other = np.loadtxt('%s/confidence_Base_Out.txt'%dir_name, delimiter=',')  # OOD confidences

Y1 = other
X1 = cifar
end = np.max([np.max(X1), np.max(Y1)])
start = np.min([np.min(X1), np.min(Y1)])
gap = (end - start) / 200000  # step of the threshold grid

errorBase = 1.0
for delta in np.arange(start, end, gap):
    # fraction of in-distribution scores below delta (despite the name, this is 1 - TPR)
    tpr = np.sum(np.sum(X1 < delta)) / float(len(X1))  # np.float was removed in NumPy 1.24
    # fraction of OOD scores above delta (FPR)
    error2 = np.sum(np.sum(Y1 > delta)) / float(len(Y1))
    errorBase = np.minimum(errorBase, (tpr + error2) / 2.0)
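
Reading the snippet, X1 < delta counts in-distribution samples that would be rejected (1 - TPR) and Y1 > delta counts OOD samples that would be accepted (FPR), so errorBase approximates detection error = min over delta of ½ [(1 - TPR(delta)) + FPR(delta)], the error of the best single threshold when the in-distribution and OOD sets are weighted equally. The reported detection accuracy is presumably 1 minus this value, though the printing code is not quoted here. The hypothetical helper below computes the same minimum without the 200,000-step grid by testing only thresholds where the error can change.

```python
import numpy as np

def min_detection_error(in_scores, out_scores):
    """min over d of 0.5 * (P(in < d) + P(out > d)); higher score = more in-distribution.

    Candidates are the smallest score and the midpoints between consecutive distinct
    scores, i.e. one threshold per interval on which a fine grid's error is constant.
    """
    x = np.asarray(in_scores, dtype=float)
    y = np.asarray(out_scores, dtype=float)
    u = np.unique(np.concatenate([x, y]))  # sorted distinct scores
    candidates = np.concatenate([u[:1], (u[:-1] + u[1:]) / 2])
    fnr = (x[None, :] < candidates[:, None]).mean(axis=1)  # rejected in-distribution (1 - TPR)
    fpr = (y[None, :] > candidates[:, None]).mean(axis=1)  # accepted OOD (FPR)
    return float(np.min((fnr + fpr) / 2))

def detection_accuracy(in_scores, out_scores):
    return 1.0 - min_detection_error(in_scores, out_scores)
```
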

  • 3. What do the arrows in Figure 3 mean?

In Figure 3, both ID data and OOD data run through the f-net and the g-net, which merge and finally output the predictions. This makes sense during training. However, is the g-net useless at test time? If we want the model to evaluate its confidence/uncertainty for a specific example, should we use the g-net and set a threshold?

Some other potential questions may come later.
Sincere thanks for your kind help :)
