### Description
When computing the log probability with FMPE's `log_prob` method, the resulting values depend on the other elements in the input batch. The differences I saw were on the order of the third or fourth decimal place.
In any case, thanks a lot already for your work on LAMPE ☺️
### Reproduce
Following the example, the two ways to compute log probabilities for a given configuration `theta` and a batch of corresponding simulated results `x` produce different results:
```python
from itertools import islice

import torch
import torch.nn as nn
import torch.optim as optim
import zuko
from lampe.data import JointLoader
from lampe.inference import FMPE, FMPELoss
from lampe.utils import GDStep
from tqdm import tqdm

LABELS = [r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"]
LOWER = -torch.ones(3)
UPPER = torch.ones(3)

prior = zuko.distributions.BoxUniform(LOWER, UPPER)


def simulator(theta: torch.Tensor) -> torch.Tensor:
    x = torch.stack(
        [
            theta[..., 0] + theta[..., 1] * theta[..., 2],
            theta[..., 0] * theta[..., 1] + theta[..., 2],
        ],
        dim=-1,
    )
    return x + 0.05 * torch.randn_like(x)


theta = prior.sample()
x = simulator(theta)

loader = JointLoader(prior, simulator, batch_size=256, vectorized=True)

estimator = FMPE(3, 2, hidden_features=[64] * 5, activation=nn.ELU)

loss = FMPELoss(estimator)
optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 128)
step = GDStep(optimizer, clip=1.0)  # gradient descent step with gradient clipping

estimator.train()

with tqdm(range(128), unit="epoch") as tq:
    for epoch in tq:
        losses = torch.stack(
            [
                step(loss(theta, x))
                for theta, x in islice(loader, 256)  # 256 batches per epoch
            ]
        )

        tq.set_postfix(loss=losses.mean().item())
        scheduler.step()

theta_star = prior.sample()
X = torch.stack([simulator(theta_star) for _ in range(10)])

estimator.eval()

with torch.no_grad():
    # All ten observations evaluated in a single batch, e.g.
    # [3.1956, 1.8184, 2.4533, 1.6461, 3.0488, 2.5868, 2.7055, 2.7679, 3.3405, 1.5554]
    log_p_one_batch = estimator.flow(X).log_prob(theta_star.repeat(len(X), 1))

    # Each observation evaluated individually, e.g.
    # [3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558]
    log_p_individual = [estimator.flow(x).log_prob(theta_star) for x in X]
```
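For reference, a minimal sketch of how the discrepancy can be quantified, reusing the variables from the snippet above (the printed magnitude in my runs matches the third-to-fourth-decimal differences reported in the description):

```python
with torch.no_grad():
    log_p_one_batch = estimator.flow(X).log_prob(theta_star.repeat(len(X), 1))
    log_p_individual = torch.stack(
        [estimator.flow(x).log_prob(theta_star) for x in X]
    )

    # Maximum absolute deviation between batched and individual evaluation;
    # I would expect this to be negligible, but it comes out around 1e-3.
    print((log_p_one_batch - log_p_individual).abs().max())
```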
### Expected behavior
I would expect the individual log probability values for one `theta` and `x` pair to be unaffected by the other entries in the `X` batch. This is corroborated by the official implementation, which does not show this behaviour when evaluating `log_prob_batch` with different subsets of the batch. In the above example, I would expect both calls to return, e.g., `[3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558]`.
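As a sketch, this is the kind of batch-composition invariance I have in mind (again reusing the variables from the reproduction above): evaluating a subset of `X` should reproduce the corresponding entries of the full-batch result.

```python
with torch.no_grad():
    full = estimator.flow(X).log_prob(theta_star.repeat(len(X), 1))
    half = estimator.flow(X[:5]).log_prob(theta_star.repeat(5, 1))

    # Should be True if log_prob is independent of the batch composition;
    # with the behaviour reported above, this check fails.
    print(torch.allclose(full[:5], half, atol=1e-5))
```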
### Causes and solution
I have no clear intuition why that would be the case. I suspected a stochastic influence and hoped that the `FreeFormJacobianTransform` exact mode might help, but the difference appears to be deterministic, and setting `exact=True` did not change it.
I also noticed that the LAMPE implementation uses a trigonometric embedding of the time dimension for the vector field computation, whereas the official implementation by the authors does not, but it is not obvious to me that this would explain the difference either.
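For context, a rough illustration of what I mean by a trigonometric time embedding (this is only a sketch of the general technique, not LAMPE's actual embedding code; the function name and feature count are made up):

```python
import torch

def time_embedding(t: torch.Tensor, features: int = 64) -> torch.Tensor:
    # Map scalar time t to sine/cosine features of increasing frequency,
    # which are then fed to the vector field network alongside theta and x.
    freqs = torch.arange(1, features // 2 + 1, dtype=t.dtype) * torch.pi
    angles = freqs * t[..., None]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```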
### Environment
- LAMPE version: 0.8.2
- PyTorch version: 2.3.0
- Python version: 3.10.13
- OS: Ubuntu 20.04.6 LTS