satclip's Introduction

🛰️ SatCLIP - A Global, General-Purpose Geographic Location Encoder

[Figure] Overview of the pretraining and deployment pipeline for SatCLIP.

Approach

SatCLIP trains location and image encoders via contrastive learning by matching satellite images to their corresponding locations. This is analogous to the CLIP approach, which matches images to their corresponding text. Through this process, the location encoder learns the characteristics of a location as represented by satellite imagery. For more details, check out our paper.
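
For intuition, the sketch below shows a generic CLIP-style symmetric contrastive loss in PyTorch. It is an illustrative stand-in, not the repository's exact SatCLIPLoss implementation; the function name and temperature value are placeholders.

import torch
import torch.nn.functional as F

def clip_style_loss(image_emb, loc_emb, temperature=0.07):
    # Normalize so that dot products become cosine similarities.
    image_emb = F.normalize(image_emb, dim=-1)
    loc_emb = F.normalize(loc_emb, dim=-1)
    # Entry (i, j) scores image i against location j; matching pairs lie on the diagonal.
    logits = image_emb @ loc_emb.t() / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)
    # Symmetric cross-entropy over both matching directions (image->location and location->image).
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))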

Overview

Usage of SatCLIP is simple:

import torch

from model import *
from location_encoder import *

model = SatCLIP(
    embed_dim=512,
    image_resolution=224, in_channels=13, vision_layers=4, vision_width=768, vision_patch_size=32, # Image encoder
    le_type='sphericalharmonics', pe_type='siren', legendre_polys=10, frequency_num=16, max_radius=360, min_radius=1, harmonics_calculation='analytic'  # Location encoder
)

img_batch = torch.randn(32, 13, 224, 224) # Represents a batch of 32 images
loc_batch = torch.randn(32, 2) # Represents the corresponding 32 locations (lon/lat)

with torch.no_grad():
    logits_per_image, logits_per_coord = model(img_batch, loc_batch)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()

Training

You first need to download the S2-100K dataset into data/s2. Start by downloading the index file:

cd data/s2
wget https://satclip.z13.web.core.windows.net/satclip/index.csv

Within data/s2, navigate to the images directory, then download and unpack all images:

cd images
wget https://satclip.z13.web.core.windows.net/satclip/satclip.tar
tar -xf satclip.tar
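
As a quick sanity check after unpacking, you can inspect the index and open a single image. This is only a sketch: the exact column names in index.csv and the image filenames are assumptions, so check them against the downloaded files.

import pandas as pd
import rasterio

df = pd.read_csv("data/s2/index.csv")
print(df.columns)        # check which columns hold the filenames and lon/lat coordinates
print(len(df), "rows")

# Open one multi-spectral GeoTIFF (replace the filename with an entry from the index).
with rasterio.open("data/s2/images/<filename-from-index>.tif") as src:
    image = src.read()   # array of shape (bands, height, width)
    print(image.shape, src.crs)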

Now, to train SatCLIP models, set the paths correctly, adapt the training config in satclip/configs/default.yaml, and run:

cd satclip
python main.py

Use of the S2-100K dataset

The S2-100K dataset consists of 100,000 multi-spectral satellite images sampled from Sentinel-2 via the Microsoft Planetary Computer. The Copernicus Sentinel data was captured between January 1, 2021 and May 17, 2023. The dataset is sampled approximately uniformly over landmass and only includes images without cloud coverage. The dataset is available for research purposes only. If you use the dataset, please cite our paper, which also contains more information about the dataset.
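
For reference, scenes of this kind can be queried from the Planetary Computer's STAC API with a cloud-cover filter. The snippet below is an illustrative sketch, not the script used to build S2-100K; the cloud threshold and item limit are arbitrary, and it requires the pystac-client and planetary-computer packages.

import planetary_computer
import pystac_client

catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)
search = catalog.search(
    collections=["sentinel-2-l2a"],
    datetime="2021-01-01/2023-05-17",
    query={"eo:cloud_cover": {"lt": 10}},  # illustrative cloud-cover threshold
    max_items=10,
)
for item in search.items():
    print(item.id, item.properties["eo:cloud_cover"])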

Pretrained Models

[Figure] Visualization of embeddings obtained by different location encoders for locations around the globe.

We provide six pretrained SatCLIP models, trained with different vision encoders and spatial resolution hyperparameters $L$ (the number of Legendre polynomials used for spherical harmonics location encoding; please refer to our paper for more details). The pretrained models can be downloaded as follows:

  • SatCLIP-ResNet18-L10: wget https://satclip.z13.web.core.windows.net/satclip/satclip-resnet18-l10.ckpt
  • SatCLIP-ResNet18-L40: wget https://satclip.z13.web.core.windows.net/satclip/satclip-resnet18-l40.ckpt
  • SatCLIP-ResNet50-L10: wget https://satclip.z13.web.core.windows.net/satclip/satclip-resnet50-l10.ckpt
  • SatCLIP-ResNet50-L40: wget https://satclip.z13.web.core.windows.net/satclip/satclip-resnet50-l40.ckpt
  • SatCLIP-ViT16-L10: wget https://satclip.z13.web.core.windows.net/satclip/satclip-vit16-l10.ckpt
  • SatCLIP-ViT16-L40: wget https://satclip.z13.web.core.windows.net/satclip/satclip-vit16-l40.ckpt

Usage of pretrained models is simple:

import torch

from load import get_satclip

device = 'cuda' if torch.cuda.is_available() else 'cpu'

c = torch.randn(32, 2) # Represents a batch of 32 locations (lon/lat)

model = get_satclip('path_to_satclip', device=device) # Only loads the location encoder by default
model.eval()
with torch.no_grad():
    emb = model(c.double().to(device)).detach().cpu()

You can also access SatCLIP model weights directly via Hugging Face.
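
For example, a checkpoint can be fetched programmatically with huggingface_hub and passed straight to get_satclip. The repository id below is an assumption for illustration; check the SatCLIP model pages on the Hub for the correct one.

import torch
from huggingface_hub import hf_hub_download

from load import get_satclip

ckpt_path = hf_hub_download(
    repo_id="microsoft/SatCLIP-ViT16-L40",   # assumed repo id; verify on the Hub
    filename="satclip-vit16-l40.ckpt",       # matches the checkpoint name listed above
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = get_satclip(ckpt_path, device=device)
model.eval()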

Examples

Examples of how to obtain and use pretrained SatCLIP embeddings can be found in the notebooks folder. We provide notebooks (optimized for use with Google Colab) for the following use cases:

  • Setup
  • Example use cases
  • Use of baseline pretrained location encoders

Citation

@article{klemmer2023satclip,
  title={SatCLIP: Global, General-Purpose Location Embeddings with Satellite Imagery},
  author={Klemmer, Konstantin and Rolf, Esther and Robinson, Caleb and Mackey, Lester and Ru{\ss}wurm, Marc},
  journal={arXiv preprint arXiv:2311.17179},
  year={2023}
}

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.

satclip's Issues

[Question] Coordinate Reference System

Hi,

I was taking a look at the code base and I had a question about the coordinate reference system (CRS) of the original satellite data. As I understand, the Sentinel-2 data has its original CRS and I imagine the lat-lon coordinates are on a curvilinear grid. Are you all doing some sort of interpolation to a rectilinear or regular grid in your preprocessing? Or do you keep everything in its original projection?

I see immense benefits to keeping it in its original projection because one could potentially lose a lot of information going back and forth between grids with different interpolators, e.g. Curvilinear -> Rectilinear -> Train NN -> Curvilinear. In addition, it definitely makes sense to avoid this because your architecture features a position encoder for the coordinates. However, I just wanted to check to make sure I am understanding all of the preprocessing involved beforehand.

Thanks!

Missing requirements.txt file

Description:
I noticed that the project is missing a requirements.txt file, which is commonly used in Python projects to list dependencies and make it easier for users to install them. Having a requirements.txt file is beneficial for users who want to set up the project environment.

Importance:
Including a requirements.txt file improves the project's accessibility and makes it easier for users to get started with the codebase. It also ensures that everyone working on the project is using the same versions of dependencies, reducing compatibility issues and potential errors.

Labels:
enhancement
help wanted

Compatibility issue with Pydantic version

I've been encountering an error while using Pydantic, even after trying both Pydantic==1.10.11 and Pydantic==2.0.3 versions. It seems there is a compatibility issue with both of these versions.

Error Message:

ImportError                               Traceback (most recent call last)
/tmp/ipykernel_488/4215400009.py in <module>
----> 1 from satclip.load import get_satclip

/mnt/satclip/satclip/__init__.py in <module>
      1 __all__ = ["configs", "datamodules", "positional_encoding"]
      2 
----> 3 from . import *
      4 from .main import *
      5 from .model import *

/mnt/satclip/satclip/datamodules/__init__.py in <module>
      1 from .transforms import *
----> 2 from .s2geo_dataset import *

/mnt/satclip/satclip/datamodules/s2geo_dataset.py in <module>
     10 import torch
     11 
---> 12 import lightning.pytorch as pl
     13 from torch.utils.data import DataLoader
     14 

/opt/conda/lib/python3.8/site-packages/lightning/__init__.py in <module>
     30 from lightning.__about__ import *  # noqa: E402, F401, F403
     31 from lightning.__version__ import version as __version__  # noqa: E402, F401
---> 32 from lightning.app import storage  # noqa: E402
     33 from lightning.app.core.app import LightningApp  # noqa: E402
     34 from lightning.app.core.flow import LightningFlow  # noqa: E402

/opt/conda/lib/python3.8/site-packages/lightning/app/__init__.py in <module>
     23 
     24 from lightning.app import __about__  # noqa: E402
---> 25 from lightning.app import components  # noqa: E402, F401
     26 from lightning.app.__about__ import *  # noqa: E402, F401, F403
     27 

/opt/conda/lib/python3.8/site-packages/lightning/app/components/__init__.py in <module>
----> 1 from lightning.app.components.database.client import DatabaseClient
      2 from lightning.app.components.database.server import Database
      3 from lightning.app.components.multi_node import (
      4     LightningTrainerMultiNode,
      5     LiteMultiNode,

/opt/conda/lib/python3.8/site-packages/lightning/app/components/database/__init__.py in <module>
----> 1 from lightning.app.components.database.client import DatabaseClient
      2 from lightning.app.components.database.server import Database
      3 
      4 __all__ = ["Database", "DatabaseClient"]

/opt/conda/lib/python3.8/site-packages/lightning/app/components/database/client.py in <module>
     20 from urllib3.util.retry import Retry
     21 
---> 22 from lightning.app.components.database.utilities import _GeneralModel
     23 
     24 _CONNECTION_RETRY_TOTAL = 5

/opt/conda/lib/python3.8/site-packages/lightning/app/components/database/utilities.py in <module>
     18 from typing import Any, Dict, Generic, List, Type, TypeVar
     19 
---> 20 from fastapi import Response, status
     21 from fastapi.encoders import jsonable_encoder
     22 from pydantic import BaseModel, parse_obj_as

/opt/conda/lib/python3.8/site-packages/fastapi/__init__.py in <module>
      5 from starlette import status as status
      6 
----> 7 from .applications import FastAPI as FastAPI
      8 from .background import BackgroundTasks as BackgroundTasks
      9 from .datastructures import UploadFile as UploadFile

/opt/conda/lib/python3.8/site-packages/fastapi/applications.py in <module>
     13 )
     14 
---> 15 from fastapi import routing
     16 from fastapi.datastructures import Default, DefaultPlaceholder
     17 from fastapi.encoders import DictIntStrAny, SetIntStr

/opt/conda/lib/python3.8/site-packages/fastapi/routing.py in <module>
     20 )
     21 
---> 22 from fastapi import params
     23 from fastapi.datastructures import Default, DefaultPlaceholder
     24 from fastapi.dependencies.models import Dependant

/opt/conda/lib/python3.8/site-packages/fastapi/params.py in <module>
      2 from typing import Any, Callable, Dict, Optional, Sequence
      3 
----> 4 from pydantic.fields import FieldInfo, Undefined
      5 
      6 

ImportError: cannot import name 'Undefined' from 'pydantic.fields' (/home/ubuntu/.local/lib/python3.8/site-packages/pydantic/fields.py)

Request:

Could someone please confirm which version of Pydantic is compatible? Additionally, any insights into resolving this compatibility issue would be greatly appreciated.

Unable to Load Locally Stored SATClip Model

Issue Description:
I have downloaded the SatCLIP model named satclip-vit16-l40.ckpt from Hugging Face and stored it at /pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt. I attempted to load the locally stored model using the provided code in my working environment, which has no internet access, but I encountered the following error. I am seeking guidance on loading and using a locally downloaded SatCLIP model without internet access.

import sys
sys.path.append("./satclip")
from load import get_satclip


import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# c = torch.randn(32, 2)  # Represents a batch of 32 locations (lon/lat)
# Represents a batch of 4 locations (lon/lat)
a = [[44.963320,-93.244523],
     [33.872022,-84.464836],
    [30.237592,-95.177780],
    [34.738666,-86.646624],
 ]

torch.as_tensor(a).float()

model = get_satclip(
    '/pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt',
    device=device,
)  # Only loads location encoder by default
model.eval()
with torch.no_grad():
    emb = model(torch.as_tensor(a).float().to(device)).detach().cpu()
using pretrained moco vit16
Downloading: "https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth
---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
Cell In[14], line 1
----> 1 model = get_satclip(
      2     '/pretrained_models/satclip/resnet16-l40/satclip-vit16-l40.ckpt',
      3     device=device,
      4 )  # Only loads location encoder by default
      5 model.eval()
      6 with torch.no_grad():

File /mnt/satclip/./satclip/load.py:8, in get_satclip(ckpt_path, device, return_all)
      6 ckpt['hyper_parameters'].pop('air_temp_data_path')
      7 ckpt['hyper_parameters'].pop('election_data_path')
----> 8 lightning_model = SatCLIPLightningModule(**ckpt['hyper_parameters']).to(device)
     10 lightning_model.load_state_dict(ckpt['state_dict'])
     11 lightning_model.eval()

File /mnt/satclip/./satclip/main.py:39, in SatCLIPLightningModule.__init__(self, embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, in_channels, le_type, pe_type, frequency_num, max_radius, min_radius, legendre_polys, harmonics_calculation, sh_embedding_dims, learning_rate, weight_decay, num_hidden_layers, capacity)
     16 def __init__(
     17     self,
     18     embed_dim=512,
   (...)
     35     capacity=256,        
     36 ) -> None:
     37     super().__init__()
---> 39     self.model = SatCLIP(
     40         embed_dim=embed_dim,
     41         image_resolution=image_resolution,
     42         vision_layers=vision_layers,
     43         vision_width=vision_width,
     44         vision_patch_size=vision_patch_size,
     45         in_channels=in_channels,
     46         le_type=le_type,
     47         pe_type=pe_type,
     48         frequency_num=frequency_num,
     49         max_radius=max_radius,
     50         min_radius=min_radius,
     51         legendre_polys=legendre_polys,
     52         harmonics_calculation=harmonics_calculation,
     53         sh_embedding_dims=sh_embedding_dims,
     54         num_hidden_layers=num_hidden_layers,
     55         capacity=capacity,
     56     )
     58     self.loss_fun = SatCLIPLoss()
     59     self.learning_rate = learning_rate

File /mnt/satclip/./satclip/model.py:309, in SatCLIP.__init__(self, embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, in_channels, le_type, pe_type, frequency_num, max_radius, min_radius, harmonics_calculation, legendre_polys, sh_embedding_dims, ffn, num_hidden_layers, capacity, *args, **kwargs)
    307 in_chans = weights.meta["in_chans"]
    308 self.visual = timm.create_model("vit_small_patch16_224", in_chans=in_chans, num_classes=embed_dim)
--> 309 self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
    310 self.visual.requires_grad_(False)
    311 self.visual.head.requires_grad_(True)

File /opt/conda/lib/python3.9/site-packages/torchvision/models/_api.py:90, in WeightsEnum.get_state_dict(self, *args, **kwargs)
     89 def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
---> 90     return load_state_dict_from_url(self.url, *args, **kwargs)

File /opt/conda/lib/python3.9/site-packages/torch/hub.py:760, in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name, weights_only)
    758         r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
    759         hash_prefix = r.group(1) if r else None
--> 760     download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    762 if _is_legacy_zip_format(cached_file):
    763     return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)

File /opt/conda/lib/python3.9/site-packages/torch/hub.py:622, in download_url_to_file(url, dst, hash_prefix, progress)
    620 file_size = None
    621 req = Request(url, headers={"User-Agent": "torch.hub"})
--> 622 u = urlopen(req)
    623 meta = u.info()
    624 if hasattr(meta, 'getheaders'):

File /opt/conda/lib/python3.9/urllib/request.py:214, in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    212 else:
    213     opener = _opener
--> 214 return opener.open(url, data, timeout)

File /opt/conda/lib/python3.9/urllib/request.py:523, in OpenerDirector.open(self, fullurl, data, timeout)
    521 for processor in self.process_response.get(protocol, []):
    522     meth = getattr(processor, meth_name)
--> 523     response = meth(req, response)
    525 return response

File /opt/conda/lib/python3.9/urllib/request.py:632, in HTTPErrorProcessor.http_response(self, request, response)
    629 # According to RFC 2616, "2xx" code indicates that the client's
    630 # request was successfully received, understood, and accepted.
    631 if not (200 <= code < 300):
--> 632     response = self.parent.error(
    633         'http', request, response, code, msg, hdrs)
    635 return response

File /opt/conda/lib/python3.9/urllib/request.py:561, in OpenerDirector.error(self, proto, *args)
    559 if http_err:
    560     args = (dict, 'default', 'http_error_default') + orig_args
--> 561     return self._call_chain(*args)

File /opt/conda/lib/python3.9/urllib/request.py:494, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)
    492 for handler in handlers:
    493     func = getattr(handler, meth_name)
--> 494     result = func(*args)
    495     if result is not None:
    496         return result

File /opt/conda/lib/python3.9/urllib/request.py:641, in HTTPDefaultErrorHandler.http_error_default(self, req, fp, code, msg, hdrs)
    640 def http_error_default(self, req, fp, code, msg, hdrs):
--> 641     raise HTTPError(req.full_url, code, msg, hdrs, fp)

HTTPError: HTTP Error 503: Service Unavailable

Looking forward to your assistance in resolving this issue promptly. Thank you.

Models and dataset on Hugging Face Hub

Really cool work! To make it a bit easier to find and use the models and dataset I've uploaded copies to the Hub. The model and dataset cards are a work in progress but it might also be nicer to transfer the models and datasets to the Microsoft organisation on the Hub.

The dataset is here: https://huggingface.co/datasets/davanstrien/satclip
The models are split into different repos for each model, for example, here: https://huggingface.co/davanstrien/SatCLIP-ResNet18-L10

Let me know if you'd like me to move them to the Microsoft org?

Longer training time than expected

Hi there,

I'm trying to reproduce the pre-training of SatCLIP on the S2-100K dataset. In default.yaml, I changed the following:

  • the in_channels parameter to 13 and vision_layers to moco_resnet50, as @konstantinklemmer recommended
  • the batch_size to 1000. It is 8k in the paper, but when I set 8k, I run into RuntimeError: DataLoader worker (pid 2968098) is killed by signal: Killed.
  • num_workers to 11 (to match the number of iterations in one epoch).

I'm also using a single A100 GPU, 11 cores, and up to 256 GB of RAM.

The problem I'm facing is that one epoch takes a really long time (probably due to loading all the images). My data is stored on an SSD with a decent connection to the A100 tower. The time is approximately 36 minutes per epoch, which is about 6 times more than indicated in the paper (i.e., 2 days for 500 epochs on a single A100 GPU).
Do you know what might be the problem?
May I ask which parameters and machines you used for training with moco_resnet50?

Kind regards,
Elena

Question about 13th band

Hi there,

I've noticed that in the code of transforms.py, in the get_pretrained_s2_train_transform function, a zero-filled B10 band is inserted:

B10 = np.zeros((1, *image.shape[1:]), dtype=image.dtype)
image = np.concatenate([image[:10], B10, image[10:]], axis=0)

I'm just curious - why do you do this?

Kind regards,
Elena

Number of input channels

Hi there,

I'm trying to reproduce the pre-training of SatCLIP on the S2-100K dataset. I downloaded S2-100K and changed the paths in the config file default.yaml and in s2geo_dataset.py.
Now, this is the output error that I'm trying to solve:

using vision transformer
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
skipped 8142/100000 images because they were smaller than 10000 bytes... they probably contained nodata pixels
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | model    | SatCLIP     | 88.9 M
1 | loss_fun | SatCLIPLoss | 0     
-----------------------------------------
88.9 M    Trainable params
0         Non-trainable params
88.9 M    Total params
355.445   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params
-----------------------------------------
0 | model    | SatCLIP     | 88.9 M
1 | loss_fun | SatCLIPLoss | 0     
-----------------------------------------
88.9 M    Trainable params
0         Non-trainable params
88.9 M    Total params
355.445   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.
Sanity Checking DataLoader 0:   0%|                                                                                                                   | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/data/eplekh/code/satclip/satclip/main.py", line 159, in <module>
    cli_main(config_fn)
  File "/data/eplekh/code/satclip/satclip/main.py", line 144, in cli_main
    cli.trainer.fit(
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/data/eplekh/code/satclip/satclip/main.py", line 75, in validation_step
    loss = self.common_step(batch, batch_idx)
  File "/data/eplekh/code/satclip/satclip/main.py", line 66, in common_step
    logits_per_image, logits_per_coord = self.model(images, t_points)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/eplekh/code/satclip/satclip/model.py", line 365, in forward
    image_features = self.encode_image(image)     
  File "/data/eplekh/code/satclip/satclip/model.py", line 358, in encode_image
    return self.visual(image.type(self.dtype))
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/eplekh/code/satclip/satclip/model.py", line 230, in forward
    x = self.conv1(x)  # shape = [*, width, grid, grid]
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/eplekh/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [768, 4, 32, 32], expected input[64, 13, 256, 256] to have 4 channels, but got 13 channels instead

It seems the images are found, but somehow the CNN expects 4 channels and gets 13, and I'm not sure why.
I tried to change the row
in_channels: 4
to
in_channels: 13
in ./satclip/configs/default.yaml, but this did not help. Also, the file
"/data/eplekh/code/satclip/lightning_logs/version_14077673/./configs/default-latest.yaml" that is created while running the script still contains
in_channels: 4
even though I changed ./satclip/configs/default.yaml.
This might be the reason, but I don't know how to fix it.

Small additional question: is "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]" expected output, or does it mean that the GPU is not being seen?

Would very much appreciate any help.
Kind regards,
Elena

ImportError when importing NonGeoDataset from torchgeo.datasets.geo

Description: I am encountering an ImportError while attempting to import NonGeoDataset from torchgeo.datasets.geo module in the satclip codebase. This issue arises when executing the load.py script, specifically when attempting to import NonGeoDataset in the s2geo_dataset.py file.

Steps to Reproduce:

  1. Clone the satclip codebase.
  2. Navigate to the satclip directory.
  3. Execute the load.py script.

Expected Behavior: The load.py script should execute without errors and successfully import NonGeoDataset from torchgeo.datasets.geo module.

Actual Behavior: The following ImportError occurs:

import sys
sys.path.append("./satclip")
from load import get_satclip
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
/tmp/ipykernel_290/977908425.py in <module>
      1 import sys
      2 sys.path.append("./satclip")
----> 3 from load import get_satclip

/mnt/satclip/satclip/load.py in <module>
----> 1 from main import *
      2 
      3 def get_satclip(ckpt_path, device, return_all=False):
      4     ckpt = torch.load(ckpt_path,map_location=device)
      5     ckpt['hyper_parameters'].pop('eval_downstream')

/mnt/satclip/satclip/main.py in <module>
      5 import lightning.pytorch
      6 import torch
----> 7 from datamodules.s2geo_dataset import S2GeoDataModule
      8 from lightning.pytorch.callbacks import ModelCheckpoint
      9 from lightning.pytorch.cli import LightningCLI

/mnt/satclip/satclip/datamodules/__init__.py in <module>
      1 from .transforms import *
----> 2 from .s2geo_dataset import *

/mnt/satclip/satclip/datamodules/s2geo_dataset.py in <module>
      5 import rasterio
      6 from torch import Tensor
----> 7 from torchgeo.datasets.geo import NonGeoDataset
      8 import matplotlib.pyplot as plt
      9 import numpy as np

ImportError: cannot import name 'NonGeoDataset' from 'torchgeo.datasets.geo' (/opt/conda/lib/python3.8/site-packages/torchgeo/datasets/geo.py)

Environment:

  • Operating System: Linux
  • Python Version: 3.8
  • Dependencies:
    • lightning==2.2.2
    • rasterio==1.3.10
    • torchgeo==0.4.1
    • torch==2.2.1

Additional Information:

  • I have confirmed that the torchgeo library is installed correctly.
  • I am using the latest version of the satclip codebase from the main branch.
  • I have checked the torchgeo documentation and source code, but I couldn't find any references to NonGeoDataset within the torchgeo.datasets.geo module.

Please let me know if there are any additional steps I can take to troubleshoot this issue or if there is any further information needed to assist with resolving this problem. Thank you.

Non-uniform distribution of S100 dataset

Hi there,

While exploring the pre-training data, I noticed an issue with the S2-100K dataset that I think can be fixed easily.
I visualized it here:
[Figure: s100_points_distribution]

So, basically, the problems are:

  1. oversampling in Greenland (probably due to the Sentinel-2 orbit path, which has a much higher revisit frequency near the poles)
  2. undersampling in the tropics and at 50-70°N latitude

The problem arises due to 1) the Sentinel-2 orbit path and 2) filtering out dates with high cloud cover, which affects the tropics a lot.

I was thinking of a solution for uniform sampling and realized that the first step of creating S2-100K is to pick a Sentinel-2 tile, and tiles are distributed approximately uniformly. So forcing the algorithm to pick approximately the same number of images per Sentinel-2 tile should fix it.
My easy fix suggestion is to sample uniformly by tile name (the tiles have the attribute 's2:mgrs_tile'), like this:

df['weight'] = 1./df.groupby('s2:mgrs_tile')['s2:mgrs_tile'].transform('count')
sampledf = df.sample(100000, weights = df.weight)

I know that the SatCLIP trained on S2-100K is only a prototype and a proof of concept, but just in case you want to run experiments with more uniformly distributed pre-training data, this seems quite easy to fix :)

Kind regards,
Elena

Environment file

Hello!

I am having trouble getting the code to run out of the box due to the environment/packages, specifically websocket. Would it be possible to provide an environment file?

Thanks!
