nasaharvest / presto

Lightweight, Pre-trained Transformers for Remote Sensing Timeseries

Home Page: https://arxiv.org/abs/2304.14065

License: MIT License

Jupyter Notebook 6.51% Python 93.34% Shell 0.16%

presto's Introduction

The Pretrained Remote Sensing Transformer (Presto)

This code accompanies our paper, Lightweight, Pre-trained Transformers for Remote Sensing Timeseries.

Environment Setup

python -m venv venv
source venv/bin/activate
pip install -e .

wandb can additionally be installed for full functionality of the train.py script.

Entrypoints

Three entrypoints to the code are available: train.py, eval.py and mosaiks.py.

In addition, a Jupyter notebook is available demonstrating how Presto can be finetuned on different downstream tasks.

Finally, Presto can also be loaded directly from the Python package. We have also included Presto in a single file (i.e. with no imports from elsewhere in the package) at single_file_presto.py, for easy integration into other applications. We test that the two models are equivalent:

# either import works. The single_file_presto has no load_pretrained function, since this
# requires knowing where the pretrained file is. The state dict can be loaded directly
# from data/default_models.pt
from single_file_presto import Presto
from presto import Presto

# to make a randomly initialized encoder-decoder model
encoder_decoder = Presto.construct()
# alternatively, the pre-trained model can also be loaded
encoder_decoder = Presto.load_pretrained()

# to isolate the encoder
encoder_only = encoder_decoder.encoder
# to add a linear transformation to the encoder's output for finetuning
finetuning_model = encoder_decoder.construct_finetuning_model(num_outputs=1, regression=True)

The default arguments to construct are the same as the default parameters described in default.json.

Presto expects the following values as input, and returns the following outputs:

reconstructed_x, reconstructed_dynamic_world = encoder_decoder(x, dynamic_world, latlons, mask, month)

globally_pooled_tokens = encoder_only(x, dynamic_world, latlons, mask, month, eval_task=True)

predictions = finetuning_model(x, dynamic_world, latlons, mask, month)
  • x: torch.Tensor of shape [batch_size, num_timesteps, bands] where bands is described by NORMED_BANDS.
  • dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.
  • latlons: torch.Tensor of shape [batch_size, 2] describing the latitude and longitude of each input instance.
  • mask: An optional torch.Tensor of shape [batch_size, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.
  • month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.

The number of timesteps is flexible: any value between 1 and 24 (2 years of data) can be passed.
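The shape conventions listed above can be checked before calling the model. The following is a minimal sketch in plain Python; `check_presto_shapes` and its arguments are hypothetical names used for illustration and are not part of the presto package. The band count of 17 in the example is a placeholder, not a documented value.

```python
from typing import Optional, Tuple

MAX_TIMESTEPS = 24  # 2 years of data, as stated in the text above


def check_presto_shapes(
    x_shape: Tuple[int, int, int],
    dynamic_world_shape: Tuple[int, int],
    latlons_shape: Tuple[int, int],
    mask_shape: Optional[Tuple[int, int, int]] = None,
) -> None:
    """Raise ValueError if the shapes disagree with the documented conventions."""
    batch_size, num_timesteps, _ = x_shape
    if not 1 <= num_timesteps <= MAX_TIMESTEPS:
        raise ValueError(f"num_timesteps must be in [1, {MAX_TIMESTEPS}], got {num_timesteps}")
    if tuple(dynamic_world_shape) != (batch_size, num_timesteps):
        raise ValueError("dynamic_world must have shape [batch_size, num_timesteps]")
    if tuple(latlons_shape) != (batch_size, 2):
        raise ValueError("latlons must have shape [batch_size, 2]")
    if mask_shape is not None and tuple(mask_shape) != tuple(x_shape):
        raise ValueError("mask must have the same shape as x")


# A batch of 8 instances with 12 timesteps; 17 bands is a placeholder value.
check_presto_shapes((8, 12, 17), (8, 12), (8, 2), mask_shape=(8, 12, 17))
```

With real tensors, the same check could be called as `check_presto_shapes(tuple(x.shape), tuple(dynamic_world.shape), tuple(latlons.shape))`.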

Three of the input tensors (x, mask, dynamic_world) can be generated using presto.construct_single_presto_input. An example is in the downstream task Jupyter notebook. For instance, RGB imagery can be turned into Presto-compatible inputs:

import presto
x, mask, dynamic_world = presto.construct_single_presto_input(
    s2=rgb_imagery,  # of shape [num_timesteps, 3]
    s2_bands=["B2", "B3", "B4"]
)

Here, x will contain only the (normalized) RGB values in the correct indices, and mask will communicate to Presto to ignore every other input. Similarly, dynamic_world will contain only DynamicWorld2020_2021.class_amount, so Presto will ignore it.
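To make the band placement and masking concrete, the following plain-Python sketch fills one full-width timestep row from a subset of bands. The band order in `ALL_BANDS` and the helper `place_bands` are hypothetical and chosen for illustration only; the real order is defined by the package, and the real function also normalizes the values.

```python
from typing import Dict, List, Tuple

# Hypothetical band order, for illustration only (the package defines the real one).
ALL_BANDS = ["VV", "VH", "B2", "B3", "B4", "B8"]


def place_bands(values: Dict[str, float]) -> Tuple[List[float], List[int]]:
    """Return (x_row, mask_row) for one timestep.

    Available bands are written at their index in ALL_BANDS with mask 0;
    every other position holds 0.0 and is masked (mask 1).
    """
    unknown = set(values) - set(ALL_BANDS)
    if unknown:
        raise KeyError(f"unknown bands: {sorted(unknown)}")
    x_row = [values.get(band, 0.0) for band in ALL_BANDS]
    mask_row = [0 if band in values else 1 for band in ALL_BANDS]
    return x_row, mask_row


x_row, mask_row = place_bands({"B2": 0.1, "B3": 0.2, "B4": 0.3})
print(x_row)     # [0.0, 0.0, 0.1, 0.2, 0.3, 0.0]
print(mask_row)  # [1, 1, 0, 0, 0, 1]
```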

Training

The train.py script contains code for self-supervised training. This can be run locally on a small subset of the data with:

# Barebones local run
python train.py \
    --train_url "data/dw_144_mini_shard_44.tar" \
    --val_url "data/dw_144_mini_shard_44.tar" \
    --val_per_n_steps 1 \
    --cropharvest_per_n_validations 0 \
    --skip_finetuning

Evaluation

A trained model (or a randomly initialized model) can be run against the evaluation tasks using eval.py. If --id and --epoch are passed to the script, a model will be loaded from models/{id}/{epoch}.pt; otherwise, a randomly initialized model will be evaluated.
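The checkpoint lookup described above can be sketched as follows. `checkpoint_path` is a hypothetical helper, not a function in eval.py; in particular, treating a missing --id or --epoch as "use a randomly initialized model" is an assumption about how partial arguments are handled.

```python
from pathlib import Path
from typing import Optional


def checkpoint_path(
    model_id: Optional[str], epoch: Optional[int], root: str = "models"
) -> Optional[Path]:
    """Return models/{id}/{epoch}.pt when both values are given, else None."""
    if model_id is None or epoch is None:
        # Assumed fallback: the caller evaluates a randomly initialized model.
        return None
    return Path(root) / str(model_id) / f"{epoch}.pt"


print(checkpoint_path("abc123", 10).as_posix())  # models/abc123/10.pt
print(checkpoint_path(None, None))               # None
```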

Mosaiks

The MOSAIKS1D benchmark can be run against evaluation tasks using the mosaiks.py script.

Generating new data

Prerequisites:

  • Account with Google Cloud access and Earth Engine access
    curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-387.0.0-linux-x86_64.tar.gz
    tar -xf google-cloud-cli-387.0.0-linux-x86_64.tar.gz
    ./google-cloud-sdk/install.sh
    # restart the shell so the PATH changes made by install.sh take effect
    exec bash
    gcloud init
    earthengine authenticate
  • Create buckets for processing
    gcloud storage mb -l us-central1 $(python -c "from dataops import EE_BUCKET; print(EE_BUCKET)")
    gcloud storage mb -l us-central1 $(python -c "from dataops import NPY_BUCKET; print(NPY_BUCKET)")
    gcloud storage mb -l us-central1 $(python -c "from dataops import TAR_BUCKET; print(TAR_BUCKET)")
  • Deploy tif-to-np Cloud Function
    sh scripts/deploy_tif_to_np.sh

Once prerequisites are satisfied, data can be generated by running:

python scripts/generate_data.py

⚠️ This script assumes you have a Google Cloud project named presto - you need to change this in the script if the name of the project is different. ⚠️

The script will generate:

  • data/tile_processing.txt A summary of tiles being processed
  • data/tile_stats.yaml Stats for all tiles available for training

Behind the scenes for each tile the script will:

  1. Begin Earth Engine exports to get data for tile from specific data pipeline. Examples:
    • gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_1>/*.tif
    • gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_2>/*.tif
    • gs://<EE_BUCKET>/<SHARD_1>/<PIPELINE_3>/*.tif
  2. Process each tif file to npy. Examples:
    • gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_1>/*.npy
    • gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_2>/*.npy
    • gs://<NPY_BUCKET>/<SHARD_1>/<PIPELINE_3>/*.npy
  3. Combine all npy files into a tar file accessible through webdataset. Example:
    • gs://<TAR_BUCKET>/<DATASET_NAME>/<SHARD_1>.tar
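The three stages above follow a regular naming scheme, which can be sketched as a small helper. `tile_uris` and all bucket, dataset, shard and pipeline names in the example are placeholders for illustration; the actual values come from the dataops package.

```python
from typing import Dict, List


def tile_uris(
    ee_bucket: str,
    npy_bucket: str,
    tar_bucket: str,
    dataset_name: str,
    shard: str,
    pipelines: List[str],
) -> Dict[str, List[str]]:
    """Build the URI patterns for one shard at each stage of the pipeline."""
    return {
        # 1. Earth Engine exports, one folder per data pipeline
        "tif": [f"gs://{ee_bucket}/{shard}/{p}/*.tif" for p in pipelines],
        # 2. tif files converted to npy
        "npy": [f"gs://{npy_bucket}/{shard}/{p}/*.npy" for p in pipelines],
        # 3. all npy files combined into one tar per shard
        "tar": [f"gs://{tar_bucket}/{dataset_name}/{shard}.tar"],
    }


uris = tile_uris("my-ee", "my-npy", "my-tar", "my_dataset", "shard_0", ["pipeline_a", "pipeline_b"])
print(uris["tif"][0])  # gs://my-ee/shard_0/pipeline_a/*.tif
print(uris["tar"][0])  # gs://my-tar/my_dataset/shard_0.tar
```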

Accessing new data

In [0]:
import webdataset as wds
import pandas as pd
uri = "gs://lem-assets2/S1_S2_ERA5_SRTM_2020_2021_DynamicWorld2020_2021_tars/dw_144_shard_0.tar"
dataset = wds.WebDataset(f"pipe:gcloud storage cat {uri}").decode()
for sample in dataset:
    break

In [1]: list(sample.keys())
Out[1]: ['__key__', '__url__', 'dynamicworld2020_2021.npy', 'latlon.npy', 's1_s2_era5_srtm_2020_2021.npy', 'worldcover2020.npy']

In [2]: sample["s1_s2_era5_srtm_2020_2021.npy"].shape
Out[2]: (625, 24, 18)

In [3]: sample["latlon.npy"].shape
Out[3]: (625, 2)

In [4]: sample["worldcover2020.npy"].shape
Out[4]: (625, 1)

In [5]: sample["dynamicworld2020_2021.npy"].shape
Out[5]: (625, 24)

In [6]: pd.Series(sample["dynamicworld2020_2021.npy"].flatten()).value_counts()
Out[6]:
0    14978
7       22
dtype: int64
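The integer values counted above can be turned into readable class names. The sketch below uses the class order of the Dynamic World `label` band; that the stored arrays keep this order is an assumption, and `named_counts` is a hypothetical helper. The input list simply reproduces the counts shown in Out[6].

```python
from collections import Counter
from typing import Dict, Iterable

# Class order of the Dynamic World `label` band (indices 0 to 8).
DYNAMIC_WORLD_CLASSES = [
    "water", "trees", "grass", "flooded_vegetation", "crops",
    "shrub_and_scrub", "built", "bare", "snow_and_ice",
]


def named_counts(values: Iterable[int]) -> Dict[str, int]:
    """Count class indices and key the result by class name."""
    counts = Counter(values)
    return {DYNAMIC_WORLD_CLASSES[index]: n for index, n in counts.items()}


# 625 pixels x 24 timesteps = 15000 entries, matching 14978 + 22 above.
flat = [0] * 14978 + [7] * 22
print(named_counts(flat))  # {'water': 14978, 'bare': 22}
```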

Reference

If you find this code useful, please cite the following paper:

@misc{tseng2023lightweight,
      title={Lightweight, Pre-trained Transformers for Remote Sensing Timeseries},
      author={Gabriel Tseng and Ruben Cartuyvels and Ivan Zvonkov and Mirali Purohit and David Rolnick and Hannah Kerner},
      year={2023},
      eprint={2304.14065},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

presto's People

Contributors

gabrieltseng, kvantricht, rubencart, sabman

presto's Issues

Issues loading pretrained model

When I'm trying to install Presto as a package not in editable mode (e.g. to pack it to run on our cluster), I'm experiencing a couple of issues:

  • The current pretrained model loading method does not work, as the model weights are not considered package data (probably a separate issue).
  • An alternative would be to accept a non-default model weights location in the load_pretrained method here, but at the moment this is not possible.
  • As a last resort, I tried copying the data folder manually and using update_data_dir from here, but while the data_dir variable is successfully updated, default_model_path from here is not, so it actually has no effect on loading the pretrained model.
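The last point is a general Python pitfall: a module-level value derived from another variable at import time is not recomputed when that variable is later rebound. A minimal sketch of the pattern and one possible remedy follows; the file name `weights.pt` and the function `current_model_path` are illustrative and do not mirror the package's actual code.

```python
from pathlib import Path

# Module-level state mirroring the pattern described in the issue (names are illustrative).
data_dir = Path("data")
default_model_path = data_dir / "weights.pt"  # computed once, at import time


def update_data_dir(new_dir: str) -> None:
    """Rebind data_dir; values already derived from it are not recomputed."""
    global data_dir
    data_dir = Path(new_dir)


def current_model_path() -> Path:
    """Derive the path lazily so it always follows the current data_dir."""
    return data_dir / "weights.pt"


update_data_dir("/tmp/presto_data")
print(default_model_path.as_posix())    # data/weights.pt (stale)
print(current_model_path().as_posix())  # /tmp/presto_data/weights.pt
```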

Downloading data for more years

Hi,

Thanks for the awesome work on this project!

I see that your data captures 2 years from 2020-2021. How would I need to change the data download script to download data for other years?

Thanks!

Data Access Issues

Hello,

I seem to have some issues with the data access. Maybe you could clarify them for me.

I am able to complete this part:

curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-387.0.0-linux-x86_64.tar.gz
tar -xf google-cloud-cli-387.0.0-linux-x86_64.tar.gz
exec bash
./google-cloud-sdk/install.sh
gcloud init
earthengine authenticate

however, when I try any of these:

gcloud storage mb -l us-central1 $(python -c "from dataops import EE_BUCKET; print(EE_BUCKET)")
gcloud storage mb -l us-central1 $(python -c "from dataops import NPY_BUCKET; print(NPY_BUCKET)")
gcloud storage mb -l us-central1 $(python -c "from dataops import TAR_BUCKET; print(TAR_BUCKET)")

I get the following error:

~/presto$ gcloud storage mb -l us-central1 $(python -c "from dataops import EE_BUCKET; print(EE_BUCKET)")
Traceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'dataops'
ERROR: (gcloud) Invalid choice: 'storage'.
Maybe you meant:
  gcloud alpha storage
  gcloud composer environments

To search the help text of gcloud commands, run:
  gcloud help -- SEARCH_TERMS

What exactly am I doing wrong?

Thanks!

Presto import triggers directory creation which does not always work

Hi guys,

during a Presto import, this line always gets triggered:

data_dir.mkdir(exist_ok=True, parents=True)

On some systems like our cluster, where a Python wheel containing Presto is shipped and deployed, this does not work. I fixed it for now locally by catching the error without failing and the rest of the code (computing Presto encodings) worked without issue on the cluster. So it's something to look at.
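The local workaround described above might look roughly like the sketch below. `ensure_dir` is a hypothetical helper, not code from the package; it reports failure through its return value instead of raising, so an import would not fail on systems where the directory cannot be created.

```python
import tempfile
from pathlib import Path


def ensure_dir(path: Path) -> bool:
    """Try to create path; return False instead of raising if that is not possible."""
    try:
        path.mkdir(exist_ok=True, parents=True)
        return True
    except OSError:  # e.g. read-only file system or permission denied
        return False


with tempfile.TemporaryDirectory() as tmp:
    created = ensure_dir(Path(tmp) / "data")
    blocker = Path(tmp) / "file.txt"
    blocker.write_text("not a directory")
    blocked = ensure_dir(blocker / "data")  # parent is a regular file

print(created)  # True
print(blocked)  # False
```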

Lat/long looks flipped by pyproj.Transformer

Hi, this is really interesting! I'm trying out the downstream task notebook as I am looking into applying this approach to wetland classification.

This repo uses pyproj's Transformer.transform() (e.g. in the downstream task notebook), which can sometimes swap the order of coordinates depending on the coordinate system. Specifically, EPSG:4326 uses lat/long whereas the input UTM data uses long/lat. In the example I linked, this results in the two variables being switched.

Adding always_xy=True to the transformer would keep the ordering consistent throughout (see the warning on the docs).

It's not clear to me whether this would affect the pre-trained model, but I see the same line in at least the eurosat and treesat eval files. It doesn't seem to significantly affect the F1 score on the example notebook however.

Even if there aren't other impacts I think that updating the order so that the variables are named correctly would avoid confusion.

Thanks!

Issue with Files Needed for CropHarvestMultiClassValidation Class

Hi @gabrieltseng

I am currently working on implementing the CropHarvestMultiClassValidation class within presto/eval/cropharvest_eval.py. To facilitate this, I require access to the data accessible via the download_cropharvest_data() function.

However, I am encountering difficulties accessing the "features/dynamic_world_arrays" and "test_dynamic_world_features" files necessary for this task. Could you please provide me with direct links or alternative methods to download these folders?

Your assistance in resolving this access issue would be greatly appreciated.

Kind Regards,
Mahrokh

Data access

Hey, I am trying to generate new data and could use some help.
Firstly, I cannot find the script generate_data.py, so I'm not sure where it is.
Secondly, I just wanted to know how you collect the S1 and S2 data. I do not see any ee.ImageCollection in the scripts for this and was wondering how you did it.

Would be great if I could get some help on this. Thanks in advance.

Required units for precipitation unclear

Hi guys,

We're still working on setting up our first Presto encoder attempt and are facing issues with precipitation as an input. In general, I think there's a bit of a lack of documentation regarding the expected units for data fed into Presto (e.g. Sentinel-2 required to be in 1E4 scaling, Sentinel-1 required to be in decibels, etc.). I think we figured most of it out on our own, but for precipitation we're puzzled.
We see that ERA5_DIV_VALUES is 0.03 for precipitation. At the moment, we are feeding just the monthly sum in mm, but after normalization, the values in x explode due to this 0.03 factor. So what are the expected units?

https://github.com/nasaharvest/presto/blob/613265efac15c8c97a28d94fc2110a2d38e8a5e0/presto/dataops/pipelines/s1_s2_era5_srtm.py#L53:L55

Also note that the commented notebook where the values supposedly come from is inaccessible to us.
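The scale mismatch described above can be shown with a short calculation. This sketch assumes that normalization reduces to dividing by the ERA5_DIV_VALUES entry (any shift is ignored). The second call only illustrates what the result would be if the input were in metres, the unit in which ERA5 reports total precipitation; the thread does not confirm that this is the expected unit.

```python
PRECIP_DIV_VALUE = 0.03  # ERA5_DIV_VALUES entry for precipitation, as quoted above


def normalize(value: float, div: float = PRECIP_DIV_VALUE) -> float:
    """Assumed normalization: a plain division by the per-band value."""
    return value / div


monthly_sum_mm = 60.0
in_mm = normalize(monthly_sum_mm)              # input in millimetres
in_metres = normalize(monthly_sum_mm / 1000)   # same amount expressed in metres

print(round(in_mm, 2))      # 2000.0
print(round(in_metres, 2))  # 2.0
```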

Data access

Data is currently stored in a private google cloud bucket. We will need to move it to a more generally accessible location (likely Zenodo).

  • Data in Zenodo
  • Download functions updated to pull from Zenodo instead of google cloud

Confusing use of dynamic and static inputs

As per the paper, topography and location are static variables, not having a temporal dimension

However, as per the encoder itself, topography is part of x, while latlons is a separate input: https://github.com/nasaharvest/presto/blob/main/presto/presto.py#L390:L398

In practice, this means the user needs to duplicate altitude/slope along the temporal dimension (e.g. to be able to feed it to construct_single_presto_input), which is rather confusing. What is the reasoning behind this?
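The duplication along the temporal dimension can be sketched in plain Python as follows. `append_static` is a hypothetical helper, and the values are made up for illustration.

```python
from typing import List


def append_static(dynamic: List[List[float]], static: List[float]) -> List[List[float]]:
    """Append the same static values (e.g. elevation, slope) to every timestep row."""
    return [row + list(static) for row in dynamic]


dynamic = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 timesteps, 2 dynamic bands
combined = append_static(dynamic, [120.0, 3.5])
print(combined)
# [[0.1, 0.2, 120.0, 3.5], [0.3, 0.4, 120.0, 3.5], [0.5, 0.6, 120.0, 3.5]]
```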

What is ignored DynamicWorld value?

Sorry for the many questions, but again a bit confused here.

When using presto.construct_single_presto_input without DynamicWorld, the returned dynamic_world has a fixed value of 9. As per the README, it should be 10:

dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount + 1 (i.e. 10), in which case it is ignored.

If I manually try to feed a fixed tensor of value 10 to Presto, I get IndexError: index out of range in self, while with a value of 9 it works. Could you clarify what we're supposed to do in absence of DynamicWorld?
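The reported behaviour (9 works, 10 raises) is what one would expect from a lookup table with class_amount + 1 rows, where valid indices run from 0 to 9. The sketch below imitates such a table with a plain list; the table size is an assumption inferred from the error, not taken from the model code.

```python
CLASS_AMOUNT = 9  # number of Dynamic World classes, per the README

# Stand-in for an embedding table with one extra row for "no Dynamic World data"
# (assumed size: CLASS_AMOUNT + 1 rows, so valid indices are 0..9).
embedding_table = [f"row_{i}" for i in range(CLASS_AMOUNT + 1)]


def lookup(index: int) -> str:
    """Return the row for index, rejecting anything outside the table."""
    if not 0 <= index < len(embedding_table):
        raise IndexError("index out of range in self")
    return embedding_table[index]


print(lookup(9))  # row_9: the last valid row, usable as the "missing" value
```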

TreeSatAI evaluation

My understanding is that the TreeSatAI dataset is multilabel and the labels can be loaded from the labels.json file provided with the dataset. However, on this line it appears that you are taking the label from the filename and treating it not as a multilabel task but just a multiclass task. Is this correct?

LICENSE

I cannot see one, could you add? Thanks

`RuntimeError` "Expected all tensors to be on the same device..." when cuda is available.

I get an error if I have cuda available on my computer.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

values = torch.stack(xs, axis=0).float().to(device)
dynamic_world = torch.stack(dynamic_worlds, axis=0).long().to(device)
mask = torch.stack(masks, axis=0).bool().to(device)
latlons = torch.stack(latlonss, axis=0).float().to(device)
month = torch.stack(months, axis=0).long().to(device)

print(f"{values.device=}")
print(f"{dynamic_world.device=}")
print(f"{mask.device=}")
print(f"{latlons.device=}")
print(f"{month.device=}")

with torch.no_grad():
    features = (
        pretrained_model.encoder(
            values,
            dynamic_world=dynamic_world,
            mask=mask,
            latlons=latlons,
            month=month,
        )
        .cpu()
        .numpy()
    )

All values that I pass to the encoder are on the same device as you can see from the code. Here's the output of the printed debug messages:

values.device=device(type='cuda', index=0)
dynamic_world.device=device(type='cuda', index=0)
mask.device=device(type='cuda', index=0)
latlons.device=device(type='cuda', index=0)
month.device=device(type='cuda', index=0)

Nothing changes if I replace device="cuda" with device="cpu"; I still get the same error.

The full stack trace of the error:

 Traceback (most recent call last):
  File "/home/mikhail/source/presto_features/main.py", line 188, in <module>
    process_tile(
  File "/home/mikhail/source/presto_features/main.py", line 156, in process_tile
    pretrained_model.encoder(
  File "/home/mikhail/.cache/pypoetry/virtualenvs/presto-features-bmBP-FwO-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mikhail/.cache/pypoetry/virtualenvs/presto-features-bmBP-FwO-py3.10/lib/python3.10/site-packages/presto/presto.py", line 415, in forward
    tokens = self.eo_patch_embed[channel_group](x[:, :, channel_idxs])
  File "/home/mikhail/.cache/pypoetry/virtualenvs/presto-features-bmBP-FwO-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mikhail/.cache/pypoetry/virtualenvs/presto-features-bmBP-FwO-py3.10/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

Pre-training on a subset of the original input channels

When using construct_single_presto_input, the code conveniently handles the normalization of the inputs and the construction of the mask. If certain inputs (bands) are missing, the respective mask values are automatically set to 1. However, there seems to be no way to deal with missing timesteps in the inputs. Imagine a monthly compositing of Sentinel-2 resulting in no valid observations for some month. We could deal with this by linearly interpolating the missing values, but it seems Presto should be able to natively handle missing timesteps by setting the respective mask values to 1.

At the moment, the only way to do this is to keep track of the missing value positions in the original inputs and, after the call to construct_single_presto_input, set the mask at these positions to 1. Would there be a more convenient way of doing this? For example, certain no-data values could be treated by this method as missing, with the mask set accordingly.

Specific side note on automatic computation of NDVI: we were testing with NaN inputs for S2 to see how the code behaves. Interestingly, this line actually makes up a valid NDVI value of 0 in x when the inputs are invalid.
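The manual workaround described in this issue can be sketched in plain Python. `mask_missing` is a hypothetical helper; it assumes the raw values and the mask already share the same [num_timesteps, bands] layout and that missing observations are encoded as NaN.

```python
import math
from typing import List


def mask_missing(raw: List[List[float]], mask: List[List[int]]) -> List[List[int]]:
    """Return a copy of mask with 1 wherever raw holds NaN."""
    return [
        [1 if math.isnan(value) else flag for value, flag in zip(raw_row, mask_row)]
        for raw_row, mask_row in zip(raw, mask)
    ]


nan = float("nan")
raw = [[0.1, 0.2], [nan, nan], [0.3, 0.4]]  # second timestep has no valid observation
mask = [[0, 0], [0, 0], [0, 0]]              # mask before the fix-up
fixed = mask_missing(raw, mask)
print(fixed)  # [[0, 0], [1, 1], [0, 0]]
```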
