orrzohar / FOMO
Official PyTorch code for Open World Object Detection in the Era of Foundation Models
License: Apache License 2.0
Thanks for your great work. I'd like to ask whether FOMO has a training process, or whether we only need to load the pretrained parameters from OWL-ViT. It seems that all the start scripts evaluate different cases, so I'm confused about how training works.
Thank you for your outstanding work! Could you kindly help me with the following three questions:
1) Why is an additional dimension concatenated here?
code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L351
if args.use_attributes:
    self.att_embeds = torch.cat(
        [self.att_embeds,
         torch.matmul(self.att_embeds.squeeze().T, self.att_W).mean(1, keepdim=True).T.unsqueeze(0)],
        dim=1,
    )
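For clarity, here is a minimal shape sketch of that concatenation using dummy tensors (the dimension names A/K/D and the sizes are my assumptions, not the repo's):

import torch

A, K, D = 25, 5, 512                                 # assumed: attributes, classes, embedding dim
att_embeds = torch.randn(1, A, D)                    # attribute embeddings
att_W = torch.rand(A, K)                             # attribute-to-class weights

# (A, D) -> (D, A) @ (A, K) -> (D, K): per-class embeddings as weighted attribute sums
extra = torch.matmul(att_embeds.squeeze().T, att_W)
extra = extra.mean(1, keepdim=True).T.unsqueeze(0)   # (1, 1, D): mean over the K classes
out = torch.cat([att_embeds, extra], dim=1)          # (1, A + 1, D)
print(out.shape)                                     # torch.Size([1, 26, 512])

So, if I read it correctly, the concatenation appends one extra query embedding: the mean over classes of the attribute-weighted class embeddings.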
2) FOMO selects attributes for each category, but the current implementation doesn't guarantee that an equal number of attributes is selected per category.
code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L441
_, top_indices = torch.topk(self.att_W.view(-1), num_classes * self.num_attributes_per_class)
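A toy example I put together (my own values, not repo code) showing why a global top-k over the flattened weight matrix need not pick the same number of attributes per class:

import torch

att_W = torch.tensor([[0.9, 0.1],
                      [0.8, 0.2],
                      [0.7, 0.3]])   # (num_attributes=3, num_classes=2)
k = 2 * 1                            # num_classes * num_attributes_per_class
_, top_indices = torch.topk(att_W.view(-1), k)
rows, cols = top_indices // 2, top_indices % 2
print(cols)                          # tensor([0, 0]): both picks land on class 0, none on class 1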
3) Are the training and evaluation stages consistent in how they compute attribute scores?
Training stage: without the learnable parameters logit_shift and logit_scale.
code for attribute_refinement: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L394
code for attribute_selection: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L431
cos_sim = cosine_similarity(image_embeddings, self.att_embeds, dim=-1)
Eval stage: with the learnable parameters logit_shift and logit_scale.
pred_logits = (pred_logits + logit_shift) * logit_scale
code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L643
(pred_logits, class_embeds) = self.model.class_predictor(
    image_feats, self.att_embeds.repeat(batch_size, 1, 1), self.att_query_mask
)
def class_predictor(
    self,
    image_feats: torch.FloatTensor,
    query_embeds: Optional[torch.FloatTensor] = None,
    query_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor]:
    """
    Args:
        image_feats:
            Features extracted from the `image_text_embedder`.
        query_embeds:
            Text query embeddings.
        query_mask:
            Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
    """
    (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
    return (pred_logits, image_class_embeds)
class OwlViTClassPredictionHead(nn.Module):
    def __init__(self, config: OwlViTConfig):
        super().__init__()
        out_dim = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size

        self.dense0 = nn.Linear(self.query_dim, out_dim)
        self.logit_shift = nn.Linear(self.query_dim, 1)
        self.logit_scale = nn.Linear(self.query_dim, 1)
        self.elu = nn.ELU()

    def forward(
        self,
        image_embeds: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor],
        query_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.FloatTensor]:
        image_class_embeds = self.dense0(image_embeds)
        if query_embeds is None:
            device = image_class_embeds.device
            batch_size, num_patches = image_class_embeds.shape[:2]
            pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
            return (pred_logits, image_class_embeds)

        # Normalize image and text features
        image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
        query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6

        # Get class predictions
        pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

        # Apply a learnable shift and scale to logits
        logit_shift = self.logit_shift(image_embeds)
        logit_scale = self.logit_scale(image_embeds)
        logit_scale = self.elu(logit_scale) + 1
        pred_logits = (pred_logits + logit_shift) * logit_scale

        if query_mask is not None:
            if query_mask.ndim > 1:
                query_mask = torch.unsqueeze(query_mask, dim=-2)

            pred_logits = pred_logits.to(torch.float64)
            pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
            pred_logits = pred_logits.to(torch.float32)

        return (pred_logits, image_class_embeds)
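To make the asymmetry in question 3 concrete, here is a toy numeric sketch (my own hypothetical values, not repo code): the attribute-selection score is a raw cosine similarity, while the eval-time head additionally applies the learned affine transform above.

import torch
import torch.nn.functional as F

img = F.normalize(torch.randn(1, 512), dim=-1)      # dummy image embedding
att = F.normalize(torch.randn(1, 512), dim=-1)      # dummy attribute embedding

cos_sim = (img * att).sum(-1)                       # training-stage score
logit_shift, logit_scale = 0.5, 2.0                 # hypothetical learned values
pred_logit = (cos_sim + logit_shift) * logit_scale  # eval-stage score

# The transform is monotonic per image (logit_scale = ELU(x) + 1 > 0), so
# per-image rankings agree, but the absolute scores differ between stages.
print(cos_sim.item(), pred_logit.item())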
Your excellent work will be a great help to my research!
I have downloaded the five datasets and reorganized them into data/RWD following your instructions. However, I found that the number of images in each dataset does not match your train/test splits in data/ImageSets/*/train.txt and test.txt. Did you select a subset of the original images from each dataset to set up the RWD benchmark?
Looking forward to your reply! @orrzohar
Great work!!! Respect!
I can't wait for the release of your benchmark!
What kind of computing resources did you use during training? Are four NVIDIA RTX 3090s enough to run the experiments in your paper? If so, roughly how long would training take?
First of all, thank you for the very impressive work.
While working with gen_attributes.ipynb to generate attributes, I've noticed that only the 'medical' dataset has a corresponding domain variable, specifically 'xray images of the bones in the hands.' I was wondering if you could kindly provide guidance on how to specify the domain information for the other datasets (Aquatic, Aerial, Game, Surgery)? Your assistance would be greatly appreciated.
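In case it helps frame the question, something like the following is what I have in mind (the 'medical' string is the one from the notebook; every other entry is a placeholder guess, not a value from the repo):

# Hypothetical sketch only; all entries except 'medical' are placeholders.
DOMAINS = {
    "medical": "xray images of the bones in the hands",      # given in gen_attributes.ipynb
    "aquatic": "underwater photographs of marine animals",   # placeholder
    "aerial": "aerial drone images of the ground",           # placeholder
    "game": "screenshots from a video game",                 # placeholder
    "surgery": "endoscopic images from surgery",             # placeholder
}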
Thank you very much for your excellent study. I have some questions about the experiments:
Are there any additional experiments on FOMO's performance on the OWD benchmark?
How is the inference speed (FPS) evaluated? (See the generic timing sketch below.)
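For reference, a generic way to measure inference FPS on GPU (not necessarily how the paper measured it; the model and inputs are placeholders):

import time
import torch

@torch.no_grad()
def measure_fps(model, images, warmup=10, iters=100):
    # Warm up so kernel compilation and caching don't skew the timing
    for _ in range(warmup):
        model(images)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(images)
    torch.cuda.synchronize()              # wait for all queued GPU work to finish
    return iters * images.shape[0] / (time.time() - start)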
Hello, I currently have only one machine but it has two graphics cards. How should I modify the code to train on such a machine?
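Not FOMO-specific, but the standard single-node, multi-GPU PyTorch pattern (DistributedDataParallel) looks like this; the model and script name are placeholders:

# Launch with: torchrun --nproc_per_node=2 train.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("nccl")       # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(10, 10).cuda(local_rank)  # placeholder model
    model = DDP(model, device_ids=[local_rank])
    # ... build a DataLoader with a DistributedSampler and train as usual ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()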