openclimatefix / metnet
PyTorch Implementation of Google Research's MetNet and MetNet-2
License: MIT License
The HF model hub allows model cards for models uploaded there, and they have a section for more standardized reporting of the carbon cost associated with training a model. Adding that information helps keep OCF transparent about making a net decrease in carbon emissions.
Hi, MetNet has a small target patch. How are these small patches merged into a bigger one that covers the whole country? Are there any discontinuities between the target patches?
As discussed in tech planning meetings, having a national forecast model might be helpful, and MetNet fits the bill with its ability to deal with very large context images.
Improving the national forecast is good
The code for both MetNet-2 and MetNet is already set up for NWP, satellite, and topographic inputs, as long as they are the same resolution. The trickier bit is including PV and GSP data. One way would be to put them at the pixels they cover in the overall image as extra channels, which would probably be the simplest: the GSP level would be one channel, with the values being the PV generation for that timestep at each pixel containing the GSP. PV would be similar, although for pixels with multiple PV systems we could take their average, or have multiple channels for PV so that every PV system is included. A rough sketch of this rasterization is below.
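A minimal sketch of one way to do that rasterization, assuming the point readings have already been mapped to pixel indices on the image grid (the function and argument names here are hypothetical, not part of this repo):

```python
import numpy as np


def rasterize_point_data(values, pixel_rows, pixel_cols, grid_shape):
    """Turn point readings (e.g. PV generation at one timestep) into one image channel.

    Readings that fall in the same pixel are averaged, as suggested above for
    pixels containing multiple PV systems.
    """
    channel = np.zeros(grid_shape, dtype=np.float32)
    counts = np.zeros(grid_shape, dtype=np.float32)
    for value, row, col in zip(values, pixel_rows, pixel_cols):
        channel[row, col] += value
        counts[row, col] += 1
    occupied = counts > 0
    channel[occupied] /= counts[occupied]  # mean where several systems share a pixel
    return channel


# usage sketch: build one channel per timestep and stack it with the
# satellite/NWP/topographic channels before feeding the model
# pv_channel = rasterize_point_data(pv_kw, rows, cols, grid_shape=(891, 1843))
```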
The other part needed for this is probably to implement the new data loader with data pipes.
MetNet was trained on GOES and MRMS data, and MetNet-2 on that plus some NWP initialization parameters. All of that data is publicly accessible from NOAA with no restrictions, so adding it to HF Datasets, to make it easy to (ideally) reproduce the paper, would be great.
This would add GOES-16/17 and MRMS precipitation data to HF Datasets, along with probably the HRRR NWP model for high-resolution data over the US. The GOES and HRRR data are all available through the Amazon open data program, while the MRMS data has been archived and is fairly easily accessible. A small listing sketch follows.
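A minimal sketch of browsing those archives anonymously with s3fs; the bucket names and product paths below are assumptions to check against the NOAA documentation (MRMS is archived separately):

```python
import s3fs

# Anonymous access works for the AWS open-data buckets; no credentials needed.
fs = s3fs.S3FileSystem(anon=True)

# GOES-16 ABI Level-2 cloud/moisture imagery, laid out as product/year/day-of-year/hour
print(fs.ls("noaa-goes16/ABI-L2-CMIPF/2022/001/00/")[:5])

# HRRR forecasts (assumed bucket: noaa-hrrr-bdp-pds), laid out as hrrr.YYYYMMDD/conus
print(fs.ls("noaa-hrrr-bdp-pds/hrrr.20220101/conus/")[:5])
```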
Describe the bug
When I run test_model.py, I get the following error message:
metnet\layers\ConvLSTM.py", line 54, in ConvLSTMCell
def forward(self, x: torch.Tensor, prev_state: list) -> tuple[torch.Tensor, torch.Tensor]:
TypeError: 'type' object is not subscriptable
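A minimal sketch of the likely fix: built-in generics such as tuple[...] are only subscriptable at runtime from Python 3.9 onwards, so on older interpreters the annotation either needs typing.Tuple/typing.List or postponed evaluation of annotations.

```python
from typing import List, Tuple

import torch


# Option 1: use typing generics, which work on older Python versions.
def forward(self, x: torch.Tensor, prev_state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    ...


# Option 2: keep tuple[...] but add this at the top of metnet/layers/ConvLSTM.py,
# so annotations are no longer evaluated at definition time:
# from __future__ import annotations
```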
I need help understanding what the best input format is and how to convert the model's output value.
In the original MetNet paper, the model output is 512 values per pixel, which are converted with a softmax. In the current implementation the model output is a single value; what is the best thing to do with it? And does MetNet take inputs scaled to [-1, 1] or to [0, max_val]?
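A minimal sketch of what the paper's 512-way output implies: treat the channel dimension as logits over precipitation-rate bins, apply a softmax per pixel, then reduce to a scalar, e.g. by taking the expected rate. The bin values here are illustrative assumptions, not the paper's exact edges.

```python
import torch

# Dummy output: (batch, bins, H, W)
logits = torch.randn(1, 512, 28, 28)
probs = torch.softmax(logits, dim=1)  # categorical distribution over rate bins per pixel

# Assumed mm/h value represented by each bin (spacing chosen for illustration only)
bin_rates = torch.linspace(0.0, 102.2, 512)

# Expected precipitation rate per pixel: sum of bin value * bin probability
expected_rate = (probs * bin_rates.view(1, -1, 1, 1)).sum(dim=1)  # (batch, H, W)
```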
Thanks for sharing this project. I'm very interested in it, but I have no idea how to get the data. Although the GOES and MRMS data are free to download from the site, there are many different product types; for GOES-16 alone there are ABI-L1, ABI-L2, GLM, SUVI, etc.
Are there any resources available to guide me in collecting the dataset? It would be really helpful. Thanks in advance and regards.
There is a Hugging Face repo that can construct our MetNet and run an example through it with Gradio. It still needs to be updated to the latest version, and to have a pre-trained model added as well, so that the outputs make sense.
https://arxiv.org/pdf/2211.01001.pdf
This paper describes ML models for predicting thunderstorms, taking in very similar data to MetNet, handling certain modalities dropping out while still working, and producing probabilistic forecasts. It forecasts 60 min ahead at 5-minute intervals.
They did find the satellite data and radar were the most important inputs to the model:
This seems like it's trying to do a similar thing to MetNet/MetNet-2, with similar inputs, so it might be helpful for expanding MetNet or adding it as a new option.
Code: https://github.com/MeteoSwiss/c4dl-multi
Data: https://zenodo.org/record/6802292
Pretrained models: https://zenodo.org/record/7157986
For example, if I wanted the target patch to be 1/2 of the input size instead of 1/4, would that be possible?
I am running on a slightly downsampled dataset with spatial dimensions 448x448. The shape of my input tensor is (None, t = 7, c = 75, w = 112, h = 112) and the output tensor shape is (None, t = 60, c = 51, w = 28, h = 28). I have implemented a version that works with pytorch-lightning for parallelization and would be happy to share it if anyone wants. I got the following parameter counts:
Downsampler (same as paper): 1.6 M parameters
Temporal encoder (hidden = 384): 6.6M parameters
Temporal Aggregation (4 layers, heads=8, num_dims=2): 4.7M parameters
But when I run a single training epoch with batch_size = 1 on an NVIDIA A100 GPU, I get the error:
"RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 39.59 GiB total capacity; 36.22 GiB already allocated; 6.19 MiB free; 37.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
When I reduce the number of lead times this error disappears (this is the bottleneck).
The issue is that the effective batch size with this training loop is 60*batch_size. The paper only does one random lead time per sample, which now makes sense to me: it solves the memory issue by allowing the effective minimum batch size to be lower than 60. I am, however, not certain how to implement this, since the training step is automated by the pytorch-lightning module. A quick fix would be to generate 60 different copies of the input tensor, encode them with all 60 different lead times, and pair them with an output tensor of shape t=1 instead of t=60. However, I see some potential issues with this solution because it is very memory-inefficient, since we would have 60 nearly identical input tensors. A sketch of sampling one lead time per sample is below.
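A minimal sketch of the "one random lead time per sample" scheme described above, done in the Dataset rather than the LightningModule so the effective batch size stays at batch_size instead of 60*batch_size. The attribute names (inputs, targets) and the one-hot lead-time conditioning are assumptions, not this repo's API.

```python
import random

import torch
from torch.utils.data import Dataset


class RandomLeadTimeDataset(Dataset):
    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor, n_lead_times: int = 60):
        self.inputs = inputs          # (N, t, c, h, w) preloaded input tensors
        self.targets = targets        # (N, n_lead_times, c_out, h_out, w_out)
        self.n_lead_times = n_lead_times

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx):
        lead = random.randrange(self.n_lead_times)   # one random lead time per sample
        x = self.inputs[idx]
        y = self.targets[idx, lead]                   # target with t=1 instead of t=60
        lead_onehot = torch.zeros(self.n_lead_times)
        lead_onehot[lead] = 1.0                       # condition the model on the chosen lead
        return x, lead_onehot, y
```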
MetNet-2: https://ai.googleblog.com/2021/11/metnet-2-deep-learning-for-12-hour.html
Some notes about MetNet-2:
Hat tip to @jacobbieker for spotting Google AI's blog post about MetNet-2!
Hello everyone!
I would like to know what format the input data for the trained model should have. I found some information in question #52, but I have a number of clarifications. Is the following true for [batch, channels, timesteps, width, height]: are the channels GOES-16/18? And then, what are these layers for the trained MetNet-2: wind, humidity, temperature? In what order are they arranged in the tensor? In general, it would be nice to have a more complete example of running the model (what format and structure the inputs and outputs have in the trained MetNet-2 example).
Thanks in advance for the answer!
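A minimal sketch with dummy data only, to make the shape question concrete. The axis ordering below simply follows the one asked about above ([batch, channels, timesteps, width, height]); whichever ordering the trained model actually expects should be confirmed against its forward() signature.

```python
import torch

# Illustrative sizes only; real channel counts depend on the satellite/NWP inputs used.
batch, channels, timesteps, width, height = 1, 12, 6, 256, 256
x = torch.randn(batch, channels, timesteps, width, height)
print(x.shape)

# out = model(x)  # hypothetical call; inspect the model's forward() to confirm ordering
```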
This is apparently the official MetNet-2 version. It would be good to compare it with and incorporate it into this version. Their version is JAX-based.
We plan to use pydocstyle to automatically check that our docstrings conform to a standard format. It would be amazing to have help tweaking our existing docstrings to conform!
pydocstyle's list of criteria is here: http://www.pydocstyle.org/en/stable/error_codes.html
Here's the Pull Request (and discussion) about enabling pydocstyle for nowcasting_utils: openclimatefix/nowcasting_utils#23
Found a bottleneck: the attention layer
I have found a potential cause for why bug #22 occurred: the axial attention layer seems to be some kind of bottleneck. I ran the network for 1000 epochs to try to overfit a small subset of 4 samples (see the run at WandB). The network is barely able to reduce the loss at all and does not overfit the data; it yields a very bad result, something like a mean. See image below:
After removing the axial attention layer, the model behaves as expected and overfits the training data; see below after 100 epochs:
The message from the author quoted in #19 does mention that our implementation of axial attention seems to be very different from theirs; he says: "Our (Google's) heads were small MLPs as far as I remember (I'm not at google anymore so do not have access to the source code)." I am not experienced enough to dig into the source code of the axial attention library we use to see how it differs from theirs.
Training loss is not decreasing
I have implemented the network in the PR "lightning" branch with pytorch lightning and tried to find any bugs. The network compiles without issues and seems to generate gradients, but it still fails to learn anything. I have tried playing around with the learning rate and plotting the data at different stages, but even with 4 training samples (it should be able to overfit these) it fails to decrease the loss even after 100 epochs...
Here is the training loss plotted:
It seems like it's doing something, but not nearly quickly enough to overfit the small dataset. Something is wrong...
Hyperparameters:
n_samples = 4
hidden_dim=8,
forecast_steps=1,
input_channels=15,
output_channels=6, #512
input_size=112, # 112
n_samples = 100,
num_workers = 8,
batch_size = 1,
learning_rate = 1e-2
Below is a Weights & Biases gradient report. As you can see, most gradients are non-zero; I'm not sure why the image_encoder has very small gradients for its biases...
wandb: epoch 83
wandb: grad_2.0_norm/head.bias_epoch 0.0746
wandb: grad_2.0_norm/head.bias_step 0.049
wandb: grad_2.0_norm/head.weight_epoch 0.0862
wandb: grad_2.0_norm/head.weight_step 0.081
wandb: grad_2.0_norm/image_encoder.module.module.0.bias_epoch 0.0
wandb: grad_2.0_norm/image_encoder.module.module.0.bias_step 0.0
wandb: grad_2.0_norm/image_encoder.module.module.0.weight_epoch 0.06653
wandb: grad_2.0_norm/image_encoder.module.module.0.weight_step 0.043
wandb: grad_2.0_norm/image_encoder.module.module.2.bias_epoch 0.00017
wandb: grad_2.0_norm/image_encoder.module.module.2.bias_step 0.0001
wandb: grad_2.0_norm/image_encoder.module.module.2.weight_epoch 0.003
wandb: grad_2.0_norm/image_encoder.module.module.2.weight_step 0.0019
wandb: grad_2.0_norm/image_encoder.module.module.3.bias_epoch 0.0
wandb: grad_2.0_norm/image_encoder.module.module.3.bias_step 0.0
wandb: grad_2.0_norm/image_encoder.module.module.3.weight_epoch 0.16387
wandb: grad_2.0_norm/image_encoder.module.module.3.weight_step 0.1125
wandb: grad_2.0_norm/image_encoder.module.module.4.bias_epoch 0.00013
wandb: grad_2.0_norm/image_encoder.module.module.4.bias_step 0.0001
wandb: grad_2.0_norm/image_encoder.module.module.4.weight_epoch 0.00203
wandb: grad_2.0_norm/image_encoder.module.module.4.weight_step 0.0012
wandb: grad_2.0_norm/image_encoder.module.module.5.bias_epoch 0.0
wandb: grad_2.0_norm/image_encoder.module.module.5.bias_step 0.0
wandb: grad_2.0_norm/image_encoder.module.module.5.weight_epoch 0.15237
wandb: grad_2.0_norm/image_encoder.module.module.5.weight_step 0.1151
wandb: grad_2.0_norm/image_encoder.module.module.6.bias_epoch 0.0032
wandb: grad_2.0_norm/image_encoder.module.module.6.bias_step 0.0018
wandb: grad_2.0_norm/image_encoder.module.module.6.weight_epoch 0.00157
wandb: grad_2.0_norm/image_encoder.module.module.6.weight_step 0.0012
wandb: grad_2.0_norm/image_encoder.module.module.7.bias_epoch 0.00497
wandb: grad_2.0_norm/image_encoder.module.module.7.bias_step 0.003
wandb: grad_2.0_norm/image_encoder.module.module.7.weight_epoch 0.11753
wandb: grad_2.0_norm/image_encoder.module.module.7.weight_step 0.0915
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_kv.weight_epoch 0.03763
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_kv.weight_step 0.0277
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.bias_epoch 0.0412
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.bias_step 0.0289
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.weight_epoch 0.05167
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_out.weight_step 0.0369
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_q.weight_epoch 0.0008
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.0.fn.to_q.weight_step 0.0008
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_kv.weight_epoch 0.04393
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_kv.weight_step 0.0216
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.bias_epoch 0.0412
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.bias_step 0.0289
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.weight_epoch 0.04287
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_out.weight_step 0.027
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_q.weight_epoch 0.0014
wandb: grad_2.0_norm/temporal_agg.0.axial_attentions.1.fn.to_q.weight_step 0.0009
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.bias_epoch 0.00197
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.bias_step 0.001
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.weight_epoch 0.03313
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h1.weight_step 0.0216
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.bias_epoch 0.00103
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.bias_step 0.0004
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.weight_epoch 0.00353
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_h2.weight_step 0.002
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.bias_epoch 0.00133
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.bias_step 0.0009
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.weight_epoch 0.02123
wandb: grad_2.0_norm/temporal_enc.rnn.cell_list.0.conv_zr.weight_step 0.0147
wandb: grad_2.0_norm_total_epoch 0.31513
wandb: grad_2.0_norm_total_step 0.2254
wandb: train/loss_epoch 1.72826
wandb: train/loss_step 1.73303
wandb: trainer/global_step 251
wandb: validation/loss_epoch 1.76064
I have plotted the inputs as they flow through the layers, and none of them seems to do anything unexpected:
I'm out of ideas and would appreciate any input.
Hello team at Open Climate Fix,
I am really interested in the code you developed here in this repo! Is there an easy way to demo this code with real images from GOES and MRMS? I have the described data but am not sure how to interface with the code base.
Thanks again,
Sailesh
We want to implement MetNet-3 (see #54). For this, we need to:
The dataset associated with the project is not yet accessible to the public.
@JackKelly @peterdudfield @jacobbieker
Also allow sub-selecting satellite types (HRV+non-HRV, just HRV, just non-HRV, only some non-HRV channels) or channels in NWPs, etc. Also try with the GFS NWP.
We want to determine how important satellite imagery is.
Output changes too little depending on lead time
The model is able to learn something, but the output image seems to change too little depending on the lead time encoding (one-hot from the input layer). Here are some examples of the output from 2 different models, one with 60 lead times and one with only 8. The left-hand plots show the ground truth precipitation in the prediction zone at different lead times; the right-hand side shows P(rain_rate > 0.2 mm/h), which means I sum the softmax probabilities of all the 127 bins corresponding to rain > 0.2 mm/h (a sketch of this computation follows the plots below).
60 lead times (5,10,15... 300 min)
Only 8 lead times (15, 30, 45... 120 minutes):
(I changed the cmap to make it clearer that the two plots are not plotting the same thing).
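A minimal sketch of the exceedance-probability map described above: sum the softmax probabilities of all bins whose precipitation rate exceeds the 0.2 mm/h threshold. The per-bin rate values are illustrative assumptions; only the 127-bin count matches the run discussed here.

```python
import torch

# Dummy model output: (batch, bins, H, W)
logits = torch.randn(1, 127, 28, 28)
probs = torch.softmax(logits, dim=1)

# Assumed mm/h value represented by each bin (spacing chosen for illustration only)
bin_rates = torch.linspace(0.0, 25.2, 127)

# P(rain_rate > 0.2 mm/h) per pixel: sum probabilities of all bins above the threshold
p_rain = probs[:, bin_rates > 0.2, :, :].sum(dim=1)  # (batch, H, W)
```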
For full 60 lead time network check out: w&b, 60 leads
For 8 lead time network check out: w&b, 8 leads
It should be noted that the 8 lead time network has not yet started overfitting.
I have implemented a sampling quality pass that, during training, makes sure each sample only picks a lead time when there are at least 5 rain pixels (sketch below).
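A minimal sketch of that sampling quality pass, under the assumption that the per-lead-time target frames are available as a tensor; the names (targets, rain_threshold) are illustrative, not this repo's API.

```python
import random

import torch


def sample_lead_time(targets: torch.Tensor, n_lead_times: int = 60,
                     min_rain_pixels: int = 5, rain_threshold: float = 0.0) -> int:
    """Pick a lead time whose target frame has at least `min_rain_pixels` rainy pixels.

    targets: (n_lead_times, H, W) precipitation frames for one training sample.
    Falls back to a uniform random lead time if no frame qualifies.
    """
    candidates = [t for t in range(n_lead_times)
                  if (targets[t] > rain_threshold).sum() >= min_rain_pixels]
    return random.choice(candidates) if candidates else random.randrange(n_lead_times)
```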
I suspect the axial attention layer again as a bottleneck; maybe I'm not using it right. We added a positional embedding so that it would know which pixel was where in the input layer, and I was wondering if we should also add an embedding for which channel it is looking at, since the model seems to be forgetting which lead time it is handling. The ConvGRU spits out a 256x28x28 tensor.
Why is it performing so poorly?
https://arxiv.org/pdf/2306.06079v2.pdf
MetNet-3 was released as a paper
It would be better to have all versions of MetNet in this repo. This densified forecast could also be useful for the irradiance modelling, going from PV sites to a dense forecast.
Currently the CI is failing; we should probably fix that.
For reference:
From the paper:
Tasks: