
Comments (7)

VincentStimper commented on June 29, 2024

Hi @Clemente90,

thanks for your kind words and this interesting question.

The normalizing flow you are using models a continuous random variable and, hence, the log_prob values are the log density values and not probabilities. The density can be very high, e.g. if you have a uniform distribution with support on $[0, a]^d$ the density on the support is $\frac{1}{a^d}$, which can get very large if $a$ is small but positive.
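
For instance, here is a quick sketch of this effect (the values of $a$ and $d$ below are chosen purely for illustration):

import torch

# uniform distribution on [0, a]^d with a small a: the density on the support
# is 1 / a^d, so the log density is -d * log(a), which is large and positive
a, d = 0.1, 5
dist = torch.distributions.Independent(
    torch.distributions.Uniform(torch.zeros(d), torch.full((d,), a)), 1
)
x = torch.full((d,), a / 2)  # a point inside the support
print(dist.log_prob(x))      # 5 * log(10) ~ 11.51, i.e. exp(log_prob) is much larger than 1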

As this example shows, the density values will be large if the distribution has (locally) low variance and the number of dimensions is high. In your case, you have a Gaussian mixture whose components have unit variance. Given your code block for evaluating the model, I assume that you normalize the data during training. Since your Gaussian modes are so far apart from each other (means sampled on $[0, 40]^d$), your target is probably narrowly peaked, and when sampling from the target you will only get samples from those modes, which should have high density values.

To confirm whether the density values the flow model gives you are in the right range, you can just compute the target density. In your case, it is not sufficient to use target.log_prob, though, because the normalization will change the density as well. The log determinant for this transform should be -torch.log(torch.prod(stdev, dim=1)), which you have to subtract from what you get with target.log_prob.
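
To spell out the reasoning: if $z = (x - \mu) / \sigma$ is the normalized variable, with $\sigma_i$ the per-dimension standard deviations, the change-of-variables formula gives

$\log p_z(z) = \log p_x(x) + \sum_{i=1}^{d} \log \sigma_i = \log p_x(x) - \Big( -\log \prod_{i=1}^{d} \sigma_i \Big),$

so the flow's log_prob of the normalized samples should be compared to target.log_prob minus this log determinant term.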

Let me know whether the values match. I think they should given that the marginals match visually.

Best regards,
Vincent

Clemente90 commented on June 29, 2024

Thanks for the quick response.
Would running target.log_prob() on the samples prior to normalizing work the same way?

I apologize if I am applying the log determinant incorrectly:

model = model_spline

model.eval()
model.cpu()

# draw samples from the target and normalize them the same way as the training data
sample = target.sample(100).float()
sample_scaled = sample - mean
sample_scaled /= stdev

# flow log density of the normalized samples
probs_from_model = model.log_prob(sample_scaled)
# target log density of the raw samples
probs_from_target_dist = target.log_prob(sample)
# target log density corrected for the normalization, i.e. subtracting -log(prod(stdev))
probs_from_target_dist_scaled = target.log_prob(sample) - -torch.log(torch.prod(stdev))

print(probs_from_model)
print(torch.exp(probs_from_model)[0].item())

print()

print(probs_from_target_dist)
print(torch.exp(probs_from_target_dist)[0].item())

print()

print(probs_from_target_dist_scaled)
print(torch.exp(probs_from_target_dist_scaled)[0].item())

Output:

tensor([ 6.1033,  0.7600,  1.8968,  2.4201, -0.7029,  3.3429,  2.3394,  4.3479,
         4.4622,  5.7613,  2.9099,  4.3977,  2.7825,  3.5368,  3.8866,  3.3209,
         2.1554,  3.7185,  2.3784,  2.1182,  4.6006,  4.6628, -0.0999,  5.9435,
        -0.3684,  3.8626,  5.4402,  3.8243,  5.5309,  5.5979,  7.0800,  6.0813,
         4.0405,  2.9375,  4.5547,  3.7730,  4.2642,  6.3171,  4.0666,  1.8609,
         4.6220,  5.1381,  2.6969,  4.9265,  3.0525,  4.4102,  3.9816,  2.0883,
         3.9413,  5.7845,  5.9226,  4.1662,  3.3151,  3.9377,  3.5246,  2.9500,
         0.5910,  6.1415,  3.2517,  3.9900,  5.0201,  4.9463,  2.3083,  2.9241,
         4.9734,  3.5453,  3.7991,  4.3043,  2.6089,  6.0158,  0.8470,  4.8524,
         2.0211,  2.4848,  4.5143,  3.8469,  1.3725,  1.8134,  4.0068,  1.8383,
         6.9317,  4.2226,  5.3134,  7.4885,  3.0080,  4.2976,  5.6979,  0.6751,
         2.5332,  2.7945,  2.9892,  5.6636,  3.4289,  2.8324,  5.0012,  4.1906,
         2.5434,  3.5022, -1.6965,  1.9983], grad_fn=<AddBackward0>)
447.34735107421875

tensor([-11.9158, -14.8146, -14.5229, -13.3597, -17.5631, -12.3066, -15.1455,
        -10.8366, -12.3581, -11.2851, -14.1784, -14.0277, -15.8781, -12.3894,
        -11.4431, -13.8743, -14.2149, -12.9958, -14.0434, -14.6024, -12.1078,
        -11.0742, -13.8623, -11.0345, -15.7262, -12.3140, -11.2938, -13.4112,
        -12.0714, -10.9705,  -9.5615, -11.5685, -13.0625, -12.3902, -10.8086,
        -12.8052, -13.0418, -10.6799, -12.8476, -15.5437, -12.7075, -11.0828,
        -13.0369, -13.0468, -12.6904, -11.1777, -11.6199, -14.8000, -12.3532,
         -9.9093, -10.1884, -12.4668, -12.8595, -13.0815, -11.5400, -13.7892,
        -15.0917,  -9.9102, -13.6996, -12.5801, -10.9730, -12.6392, -12.3448,
        -11.8199, -11.4093, -12.8751, -13.1467, -11.4624, -13.8282, -11.6199,
        -16.3826, -11.6240, -13.2532, -15.2196, -11.3228, -12.3071, -14.5983,
        -13.7950, -11.9822, -17.0757, -11.0482, -11.4691, -10.4039,  -9.7035,
        -13.0871, -14.1413, -11.6800, -13.9928, -12.7807, -12.9996, -12.0244,
        -10.8223, -14.7816, -12.1493, -13.7105, -12.1933, -13.8362, -11.5132,
        -19.0384, -14.8261], dtype=torch.float64, grad_fn=<LogsumexpBackward0>)
6.683640800042695e-06

tensor([ 5.6426e+00,  2.7439e+00,  3.0356e+00,  4.1987e+00, -4.5920e-03,
         5.2519e+00,  2.4130e+00,  6.7219e+00,  5.2004e+00,  6.2733e+00,
         3.3801e+00,  3.5308e+00,  1.6803e+00,  5.1690e+00,  6.1153e+00,
         3.6842e+00,  3.3435e+00,  4.5626e+00,  3.5151e+00,  2.9561e+00,
         5.4506e+00,  6.4843e+00,  3.6962e+00,  6.5239e+00,  1.8323e+00,
         5.2445e+00,  6.2646e+00,  4.1473e+00,  5.4870e+00,  6.5879e+00,
         7.9970e+00,  5.9900e+00,  4.4960e+00,  5.1683e+00,  6.7499e+00,
         4.7532e+00,  4.5167e+00,  6.8786e+00,  4.7109e+00,  2.0147e+00,
         4.8510e+00,  6.4756e+00,  4.5216e+00,  4.5117e+00,  4.8681e+00,
         6.3808e+00,  5.9386e+00,  2.7585e+00,  5.2052e+00,  7.6491e+00,
         7.3700e+00,  5.0917e+00,  4.6990e+00,  4.4770e+00,  6.0185e+00,
         3.7693e+00,  2.4668e+00,  7.6482e+00,  3.8588e+00,  4.9784e+00,
         6.5854e+00,  4.9192e+00,  5.2137e+00,  5.7386e+00,  6.1491e+00,
         4.6833e+00,  4.4117e+00,  6.0961e+00,  3.7303e+00,  5.9386e+00,
         1.1759e+00,  5.9345e+00,  4.3053e+00,  2.3388e+00,  6.2356e+00,
         5.2514e+00,  2.9602e+00,  3.7634e+00,  5.5763e+00,  4.8281e-01,
         6.5103e+00,  6.0893e+00,  7.1546e+00,  7.8549e+00,  4.4713e+00,
         3.4171e+00,  5.8784e+00,  3.5657e+00,  4.7777e+00,  4.5589e+00,
         5.5341e+00,  6.7362e+00,  2.7769e+00,  5.4092e+00,  3.8480e+00,
         5.3652e+00,  3.7222e+00,  6.0453e+00, -1.4799e+00,  2.7324e+00],
       dtype=torch.float64, grad_fn=<SubBackward0>)
282.20062910317984

I apologize for my ignorance or misunderstanding.
[image: cartoon illustrating the shifted and scaled distribution next to the original one]
Are you saying that, due to the shift and scale I am applying, the values of the target density function will be higher than they are in the original distribution, as illustrated in the cartoon above? Wouldn't the volume under the distribution always need to be 1.0? I assume it does, so having positive values come out of log_prob doesn't seem to make sense to me.

VincentStimper commented on June 29, 2024

Hi @Clemente90,

this seems fine to me and the values are now in the same range. They are not identical, but this is probably because the model does not fit the target perfectly yet, which could be overcome with a larger model and longer training.

Indeed, the volume under the density has to be one. However, the samples you draw from the target are all located within the support of the target, i.e. where the density is high.

To check whether the density is normalized, you would have to do numerical integration, e.g. defining a grid of points covering the support of the distribution, evaluating the model (or target) at these points, summing the results, and multiplying by the volume of a single bin. This approach might be sufficient for your problem, but scales poorly to high dimensions. Alternatively, you could do importance sampling, e.g. with a uniform distribution covering the support of your model.
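
As a rough sketch of what both checks could look like (using a 2-dimensional example; the grid bounds, the resolution, and the number of samples below are placeholder values for your setup):

import torch

# numerical integration on a grid covering the support (assumed here to be [-5, 5]^2)
lo, hi, n = -5.0, 5.0, 200
xs = torch.linspace(lo, hi, n)
grid = torch.stack(torch.meshgrid(xs, xs, indexing="ij"), dim=-1).reshape(-1, 2)
bin_volume = ((hi - lo) / (n - 1)) ** 2
with torch.no_grad():
    integral = torch.exp(model.log_prob(grid)).sum() * bin_volume
print(integral)  # should be close to 1 if the model density is normalized

# importance sampling with a uniform proposal over the same box
u = torch.rand(100000, 2) * (hi - lo) + lo
with torch.no_grad():
    estimate = torch.exp(model.log_prob(u)).mean() * (hi - lo) ** 2
print(estimate)  # should also be close to 1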

Best regards,
Vincent

Clemente90 commented on June 29, 2024

Thank you, I appreciate your responses.

The one thing I am not understanding is this: exp(0.0) == 1.0, so any log-probability value greater than 0.0 would imply a probability density value greater than 1.0, and then the probability density function could not integrate to 1.0.

VincentStimper commented on June 29, 2024

Hi @Clemente90,

this is not true. In the first method to compute the integral mentioned above, you would multiply by the volume of a bin, which is small, pushing the value of each term in the sum below 1. In importance sampling, you would compute the mean of the importance weights, so some of the terms can be larger than 1 while the result can still be 1. Feel free to implement them to convince yourself that the distribution is indeed normalized.
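
As a standalone numerical example (not taken from your data): a one-dimensional Gaussian with $\sigma = 0.01$ has peak density $\frac{1}{\sigma \sqrt{2 \pi}} \approx 39.9$, i.e. a log density of about $3.7 > 0$ at the mode, yet with a grid spacing of $\Delta x = 10^{-3}$ every term $p(x_i) \, \Delta x$ in the Riemann sum is at most about $0.04$, and the full sum still comes out to about $1$.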

Best regards,
Vincent

Clemente90 commented on June 29, 2024

I am sorry, you are right. It makes complete sense for those log_prob values to be greater than 0.0. That is very embarrassing to make such a basic conceptual mistake. Thank you for your patient explanations.

Best,

VincentStimper commented on June 29, 2024

Hi @Clemente90,

no worries, this is indeed counter-intuitive at first glance.

I'm glad I could help. I'm closing the issue now.

Best regards,
Vincent
