I'm trying to run your code to get a better understanding of how the ensemble attention works. I downloaded TIGGE-IFS (50 ensemble members) and ERA5 reanalysis data for 2017 and ran both through the pre-processing pipeline. Everything seems to work until I run `fit_normalizer.ipynb`, where the following cell raises an error:
# Accumulate per-batch first and second moments of the IFS and ERA tensors
# over the whole training dataloader (indentation was lost when this cell
# was pasted; the statements below form the body of the loop).
for ifs_data, era_data in tqdm(data_module.train_dataloader(), total=len(data_module.train_dataloader())):
# Debug output: the shapes printed for the first batch are
# (16, 50, 3, 32, 64) for ifs_data and (16, 32, 64) for era_data.
print(ifs_data.shape)
print(era_data.shape)
# Average over dims 1, 3 and 4, then sum over the batch dim 0. For the printed
# 5-D batch only dim 2 (size 3) is left unreduced.
# presumably dim 1 is the 50 ensemble members and dim 2 the variables — TODO confirm
ifs_mean = ifs_data.mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
ifs_squared_mean = ifs_data.pow(2).mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
# Average over dims 1, 2 and 3, which requires era_data to have at least 4 dims.
# NOTE(review): the printed era batch is 3-D (16, 32, 64), so dim 3 is out of
# range and this statement raises the reported IndexError. Presumably the ERA
# tensor is expected to carry an extra (channel/variable) axis — confirm
# against the pre-processing output.
era_mean = era_data.mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
era_squared_mean = era_data.pow(2).mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
# Add this batch's contribution to the running totals. If any of the updates
# fails, the except branch (re)initialises all four accumulators from the
# current batch instead.
try:
rolling_sum['ifs'] = rolling_sum['ifs']+ifs_mean
rolling_squared_sum['ifs'] = rolling_squared_sum['ifs']+ifs_squared_mean
rolling_sum['era'] = rolling_sum['era']+era_mean
rolling_squared_sum['era'] = rolling_squared_sum['era']+era_squared_mean
# NOTE(review): a bare except catches every exception, not only the
# "accumulator not set yet" case on the first batch, so an unrelated failure
# in the additions above would silently reset the running totals.
except:
rolling_sum['ifs'] = ifs_mean
rolling_squared_sum['ifs'] = ifs_squared_mean
rolling_sum['era'] = era_mean
rolling_squared_sum['era'] = era_squared_mean
# Count the samples seen so far (dim 0 of ifs_data is the batch size).
rolling_elems += ifs_data.shape[0]
0%| | 0/29 [00:00<?, ?it/s]
torch.Size([16, 50, 3, 32, 64])
torch.Size([16, 32, 64])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
/tmp/ipykernel_20355/3535360696.py in <module>
4 ifs_mean = ifs_data.mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
5 ifs_squared_mean = ifs_data.pow(2).mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
----> 6 era_mean = era_data.mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
7 era_squared_mean = era_data.pow(2).mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
8 try:
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)