
tobifinn / ensemble_transformer

Stars: 12 · Watchers: 1 · Forks: 1 · Size: 853 KB

Official PyTorch implementation of "Self-Attentive Ensemble Transformer: Representing Ensemble Interactions in Neural Networks for Earth System Models".

License: MIT License

Python 5.26% Jupyter Notebook 94.74%
machine-learning ensemble earth-system-model

ensemble_transformer's People

Contributors

tobifinn

Forkers

ensemblexinxin

ensemble_transformer's Issues

expected shape of era5 data?

Hi @tobifinn, nice work, and thanks for open-sourcing it!

I'm trying to run your code to get a better understanding of how the ensemble attention works. I downloaded TIGGE-IFS (50 ensemble members) and ERA5 reanalysis data for 2017 and ran both through the pre-processing pipeline. Everything seems to work until I run fit_normalizer.ipynb, which raises an error:

for ifs_data, era_data in tqdm(data_module.train_dataloader(), total=len(data_module.train_dataloader())):
    print(ifs_data.shape)
    print(era_data.shape)
    ifs_mean = ifs_data.mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
    ifs_squared_mean = ifs_data.pow(2).mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
    era_mean = era_data.mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
    era_squared_mean = era_data.pow(2).mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
    try:
        rolling_sum['ifs'] = rolling_sum['ifs']+ifs_mean
        rolling_squared_sum['ifs'] = rolling_squared_sum['ifs']+ifs_squared_mean
        rolling_sum['era'] = rolling_sum['era']+era_mean
        rolling_squared_sum['era'] = rolling_squared_sum['era']+era_squared_mean
    except:
        rolling_sum['ifs'] = ifs_mean
        rolling_squared_sum['ifs'] = ifs_squared_mean
        rolling_sum['era'] = era_mean
        rolling_squared_sum['era'] = era_squared_mean
    rolling_elems += ifs_data.shape[0]

Output

 0%|          | 0/29 [00:00<?, ?it/s]
torch.Size([16, 50, 3, 32, 64])
torch.Size([16, 32, 64])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_20355/3535360696.py in <module>
      4     ifs_mean = ifs_data.mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
      5     ifs_squared_mean = ifs_data.pow(2).mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
----> 6     era_mean = era_data.mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
      7     era_squared_mean = era_data.pow(2).mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
      8     try:

IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

What's the expected shape of the ERA5 t2m dataset? Thanks!
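The traceback shows the ERA5 batch arriving as (16, 32, 64), so reducing over dims (1, 2, 3) asks for a fourth axis that does not exist. A minimal sketch of one possible workaround, assuming the loop expects the single-variable t2m target to carry an explicit singleton channel axis, i.e. (batch, 1, lat, lon); that expected shape is an assumption, not something confirmed by the repository. The helper names `batch_moments` and `finalize` are hypothetical; `finalize` shows how the accumulated sums could become a mean and standard deviation via E[x²] − E[x]².

```python
import torch


def batch_moments(ifs_data: torch.Tensor, era_data: torch.Tensor):
    """Per-batch sums of sample means and squared means, as in the notebook loop.

    Assumed shapes (hypothetical): ifs_data (batch, ensemble, channel, lat, lon)
    and era_data either (batch, lat, lon) or (batch, 1, lat, lon).
    """
    # Assumption: the loop expects a singleton channel axis on the ERA5 target,
    # so a 3-D batch such as (16, 32, 64) is lifted to (16, 1, 32, 64).
    if era_data.dim() == 3:
        era_data = era_data.unsqueeze(1)
    # Average over ensemble, lat and lon (keep channels), then sum over the batch.
    ifs_mean = ifs_data.mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
    ifs_sq = ifs_data.pow(2).mean(dim=(1, 3, 4), keepdim=True).sum(dim=0, keepdim=True)
    # Average over the singleton channel, lat and lon, then sum over the batch.
    era_mean = era_data.mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
    era_sq = era_data.pow(2).mean(dim=(1, 2, 3), keepdim=True).sum(dim=0, keepdim=True)
    return ifs_mean, ifs_sq, era_mean, era_sq


def finalize(rolling_sum: torch.Tensor, rolling_squared_sum: torch.Tensor, n_samples: int):
    """Turn accumulated sums into mean and standard deviation (E[x^2] - E[x]^2)."""
    mean = rolling_sum / n_samples
    var = rolling_squared_sum / n_samples - mean.pow(2)
    # Clamp tiny negative variances caused by floating-point round-off.
    return mean, var.clamp_min(0.0).sqrt()
```

For the batch in the traceback, `batch_moments(torch.randn(16, 50, 3, 32, 64), torch.randn(16, 32, 64))` yields IFS statistics of shape (1, 1, 3, 1, 1) and ERA5 statistics of shape (1, 1, 1, 1). Dividing by the number of samples is valid here because every sample contributes the same number of grid points.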

attention mechanism in the ensemble transformer

Hi Tobias (@tobifinn), I'm trying to understand your implementation of the ensemble attention mechanism. If I'm not mistaken, EnsConv2D flattens the leading tensor dimensions (bs, ens), and the SelfAttentionModule then attends to the (partially) flattened tensor of shape (bs * ens, n_channels, h, w). Is that right? The paper seems to indicate that K, w_i, etc. all depend on the ensemble dimension, which would seem to preclude using a different ensemble size during inference. The code suggests otherwise, though.
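To illustrate why ensemble-wise attention need not tie any weights to the ensemble size, here is a simplified sketch. It is not the repository's EnsConv2D or SelfAttentionModule; the class name `EnsembleSelfAttention` and its parameters are hypothetical. The assumption is that queries, keys, and values come from 1x1 convolutions shared by all members (applied on the folded (bs * ens, c, h, w) tensor), so only the ens × ens attention matrix depends on the ensemble size, and it is computed on the fly from the input.

```python
import math

import torch
from torch import nn


class EnsembleSelfAttention(nn.Module):
    """Hypothetical sketch: self-attention across ensemble members."""

    def __init__(self, in_channels: int, hidden_channels: int):
        super().__init__()
        # 1x1 convolutions shared by every member: no parameter depends on ens.
        self.query = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, ens, c, h, w = x.shape
        # Fold the ensemble into the batch so the convolutions see (bs * ens, c, h, w).
        flat = x.reshape(bs * ens, c, h, w)
        # Unfold again and flatten channels and space into one feature vector per member.
        q = self.query(flat).reshape(bs, ens, -1)
        k = self.key(flat).reshape(bs, ens, -1)
        v = self.value(flat).reshape(bs, ens, -1)
        # Scaled dot-product scores between members: shape (bs, ens, ens).
        scores = q @ k.transpose(1, 2) / math.sqrt(q.shape[-1])
        weights = scores.softmax(dim=-1)
        # Each member becomes a weighted mix of all members' values, plus a residual.
        out = (weights @ v).reshape(bs, ens, c, h, w)
        return x + out
```

Because the same module accepts, say, 50 members during training and 10 at inference, the ensemble size only changes the shape of `weights`. Shuffling the members simply shuffles the outputs in the same way, which is the property one would expect from an ensemble-wise attention layer.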
