leeesangwon / pytorch-image-retrieval

A PyTorch framework for an image retrieval task including implementation of N-pair Loss (NIPS 2016) and Angular Loss (ICCV 2017).

License: MIT License

metric-learning n-pair-loss angular-loss image-retrieval pytorch deep-metric-learning

pytorch-image-retrieval's Introduction

PyTorch Image Retrieval

A PyTorch framework for an image retrieval task including implementation of N-pair Loss (NIPS 2016) and Angular Loss (ICCV 2017).

Loss functions

We implemented the loss functions used to train the network for image retrieval.
The batch sampler for these loss functions is borrowed from here.

  • N-pair Loss (NIPS 2016): Sohn, Kihyuk. "Improved Deep Metric Learning with Multi-class N-pair Loss Objective." Advances in Neural Information Processing Systems, 2016.
  • Angular Loss (ICCV 2017): Wang, Jian, et al. "Deep Metric Learning with Angular Loss." ICCV, 2017.
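
For reference, here is a minimal PyTorch sketch of the N-pair loss; the names are illustrative and this is not the repository's exact implementation. Row i of anchors and positives is assumed to hold an embedding pair from the same class, with one class per row.

    import torch
    import torch.nn.functional as F

    def n_pair_loss(anchors: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
        # (N, N) similarity matrix: entry (i, j) = anchor_i . positive_j
        logits = anchors @ positives.t()
        # The diagonal holds the positive pairs, so the N-pair objective
        # reduces to softmax cross-entropy with target class i for row i.
        targets = torch.arange(anchors.size(0), device=anchors.device)
        return F.cross_entropy(logits, targets)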

Self-attention module

We attached the self-attention module of the Self-Attention GAN to conventional classification networks (e.g. DenseNet, ResNet, or SENet).
The implementation of the module is borrowed from here.
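
A minimal sketch of such a SAGAN-style self-attention block follows; parameter names are illustrative and not copied from the borrowed implementation.

    import torch
    import torch.nn as nn

    class SelfAttention(nn.Module):
        def __init__(self, in_channels: int):
            super().__init__()
            self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
            self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
            self.value = nn.Conv2d(in_channels, in_channels, 1)
            self.gamma = nn.Parameter(torch.zeros(1))  # learned blend, starts at 0

        def forward(self, x):
            b, c, h, w = x.shape
            q = self.query(x).flatten(2).transpose(1, 2)  # (B, HW, C/8)
            k = self.key(x).flatten(2)                    # (B, C/8, HW)
            attn = torch.softmax(q @ k, dim=-1)           # (B, HW, HW) attention map
            v = self.value(x).flatten(2)                  # (B, C, HW)
            out = (v @ attn.transpose(1, 2)).view(b, c, h, w)
            return self.gamma * out + x                   # residual connection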

Data augmentation

We adopted the data augmentation techniques used in the Single Shot MultiBox Detector (SSD).
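
In a retrieval setting (no bounding boxes), the SSD-style photometric and geometric distortions can be approximated with torchvision roughly as below; the exact parameters used here may differ.

    from torchvision import transforms

    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.125, contrast=0.5,
                               saturation=0.5, hue=0.05),     # photometric distortion
        transforms.RandomResizedCrop(224, scale=(0.3, 1.0)),  # random zoom/crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])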

Post processing

We utilized the following post-processing techniques in the inference phase: database-side feature augmentation (db_augmentation) and average query expansion (average_query_expansion).
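
As an illustration, average query expansion can be sketched as follows; function and variable names are illustrative rather than taken from inference.py, and features are assumed L2-normalized so that the dot product equals cosine similarity. Database-side augmentation works analogously, averaging each reference feature with its own nearest references.

    import torch
    import torch.nn.functional as F

    def average_query_expansion(query_feats, ref_feats, top_k=5):
        sims = query_feats @ ref_feats.t()          # (Q, R) cosine similarities
        top_idx = sims.topk(top_k, dim=1).indices   # indices of top-k references
        neighbors = ref_feats[top_idx].mean(dim=1)  # mean of each query's neighbors
        # Average each query with its neighbors, re-normalize, then re-rank.
        return F.normalize(query_feats + neighbors, dim=1)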

pytorch-image-retrieval's People

Contributors

gymlab, leeesangwon, mikekook


pytorch-image-retrieval's Issues

Ensemble models issue

Hello! Congratulations on winning first place in the competition. I am curious whether the Cheat Key team used an ensemble.
In the Kaggle image retrieval competition, the first-place solution concatenated the features extracted by three models in the feature-extraction stage. Did you try something similar? If not, did you get your strong result with a single model? I am also curious which part contributed most to the performance gain.

How to use this code?

Hi, I just arrived here, but after 30 minutes of going through this repo I still don't know how to run it.
Can you help me?

What is the metrics attribute of BlendedLoss()?

Thank you for sharing your code with others.
There is one thing in the code I do not understand:

class BlendedLoss(object):
    def __init__(self, main_loss_type, cross_entropy_flag):
        super(BlendedLoss, self).__init__()
        self.main_loss_type = main_loss_type
        assert main_loss_type in MAIN_LOSS_CHOICES, "invalid main loss: %s" % main_loss_type

        self.metrics = []

Is self.metrics a set of metric functions? It is not explained anywhere.
Hoping for your reply, thank you!

A question about how to test retrieval

Hi, thank you for building such a useful and concise framework for the image retrieval task!

I just read the file inference.py and got confused about the testing process.

query_loader = DataLoader(query_img_dataset, batch_size=infer_batch_size, shuffle=False, num_workers=4, pin_memory=True)
reference_loader = DataLoader(reference_img_dataset, batch_size=infer_batch_size, shuffle=False, num_workers=4, pin_memory=True)

My question is: why is the database for retrieval divided into batches? Shouldn't we keep it whole for every query, since we want to rank all the reference images in the database according to pair-wise similarity?
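
(For readers with the same question: batching in the DataLoader only chunks the feature-extraction forward pass; the per-batch features can be concatenated and then ranked against the entire reference set at once. A rough sketch, not the repository's exact code:)

    import torch

    @torch.no_grad()
    def extract_features(model, loader, device):
        # assumes the loader yields image batches only
        chunks = [model(images.to(device)).cpu() for images in loader]
        return torch.cat(chunks, dim=0)       # (num_images, feat_dim)

    # Ranking then uses the full matrices:
    # sims = query_feats @ ref_feats.t()      # (num_queries, num_references)
    # ranks = sims.argsort(dim=1, descending=True)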

RuntimeError: CUDA out of memory. Tried to allocate 240.00 MiB (GPU 0; 14.76 GiB total capacity; 13.54 GiB already allocated; 73.75 MiB free; 13.63 GiB reserved in total by PyTorch)

I am new to PyTorch. I am using Google Colaboratory with the GPU accelerator.

Attempted Solutions (same error):

  • torch.cuda.empty_cache(), suggested here.
  • torch.cuda.memory_summary(device=None, abbreviated=False), suggested here.
  • batch_size=1 for dm instance.

Attempted Solutions (different error):

  • torch.cuda.clear_memory_allocated(), suggested here.
    AttributeError: module 'torch.cuda' has no attribute 'clear_memory_allocated'

Code:

trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=loggers)
#torch.cuda.empty_cache()
#torch.cuda.memory_summary(device=None, abbreviated=False)
#torch.cuda.clear_memory_allocated()
trainer.fit(model, dm) # ERROR
model_file = os.path.join(args.modeldir, 'last.ckpt')
trainer.save_checkpoint(model_file, weights_only=True)

Error:

RuntimeError: CUDA out of memory. Tried to allocate 240.00 MiB (GPU 0; 11.17 GiB total capacity; 10.58 GiB already allocated; 63.81 MiB free; 10.66 GiB reserved in total by PyTorch)

Output and Traceback:

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMulticlassSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForMulticlassSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMulticlassSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMulticlassSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifiers.1.bias', 'classifiers.0.weight', 'classifiers.1.weight', 'classifiers.0.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration. You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).
  warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                                    | Params
----------------------------------------------------------------------------
0 | model           | BertForMulticlassSequenceClassification | 108 M 
1 | valid_acc       | Accuracy                                | 0     
2 | valid_f1        | F1                                      | 0     
3 | valid_acc_multi | ModuleList                              | 0     
----------------------------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.413   Total estimated model params size (MB)
###score: val_score### 0.0625
/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric F1 was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric CompositionalMetric was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
Epoch 0: 0% 0/3715 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-89b1728bd5a6> in <module>()
      7       --gpus 1
      8     """.split()
----> 9 run_training(args)

37 frames
<ipython-input-5-7f8e9eed480d> in run_training(input)
     68 
     69     trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=loggers)
---> 70     trainer.fit(model, dm)
     71     model_file = os.path.join(args.modeldir, 'last.ckpt')
     72     trainer.save_checkpoint(model_file, weights_only=True)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    497 
    498         # dispath `start_training` or `start_testing` or `start_predicting`
--> 499         self.dispatch()
    500 
    501         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    544 
    545         else:
--> 546             self.accelerator.start_training(self)
    547 
    548     def train_or_test_or_predict(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     71 
     72     def start_training(self, trainer):
---> 73         self.training_type_plugin.start_training(trainer)
     74 
     75     def start_testing(self, trainer):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    112     def start_training(self, trainer: 'Trainer') -> None:
    113         # double dispatch to initiate the training loop
--> 114         self._results = trainer.run_train()
    115 
    116     def start_testing(self, trainer: 'Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    635                 with self.profiler.profile("run_training_epoch"):
    636                     # run train epoch
--> 637                     self.train_loop.run_training_epoch()
    638 
    639                 if self.max_steps and self.max_steps <= self.global_step:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    490             # ------------------------------------
    491             with self.trainer.profiler.profile("run_training_batch"):
--> 492                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    493 
    494             # when returning -1 from train_step, we end epoch early

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    652 
    653                         # optimizer step
--> 654                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    655 
    656                     else:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    431             on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
    432             using_native_amp=using_native_amp,
--> 433             using_lbfgs=is_lbfgs,
    434         )
    435 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1388             # wraps into LightingOptimizer only for running step
   1389             optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
-> 1390         optimizer.step(closure=optimizer_closure)
   1391 
   1392     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    275         )
    276         if make_optimizer_step:
--> 277             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    278         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    279         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    280 
    281     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
--> 282         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    283 
    284     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    161 
    162     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 163         optimizer.step(closure=lambda_closure, **kwargs)

/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     86                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87                 with torch.autograd.profiler.record_function(profile_name):
---> 88                     return func(*args, **kwargs)
     89             return wrapper
     90 

/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

/usr/local/lib/python3.7/dist-packages/torch/optim/adam.py in step(self, closure)
     64         if closure is not None:
     65             with torch.enable_grad():
---> 66                 loss = closure()
     67 
     68         for group in self.param_groups:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    647                         def train_step_and_backward_closure():
    648                             result = self.training_step_and_backward(
--> 649                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    650                             )
    651                             return None if result is None else result.loss

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    740         with self.trainer.profiler.profile("training_step_and_backward"):
    741             # lightning module hook
--> 742             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    743             self._curr_step_result = result
    744 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    291             model_ref._results = Result()
    292             with self.trainer.profiler.profile("training_step"):
--> 293                 training_step_output = self.trainer.accelerator.training_step(args)
    294                 self.trainer.accelerator.post_training_step()
    295 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    154 
    155         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 156             return self.training_type_plugin.training_step(*args)
    157 
    158     def post_training_step(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    123 
    124     def training_step(self, *args, **kwargs):
--> 125         return self.lightning_module.training_step(*args, **kwargs)
    126 
    127     def post_training_step(self):

<ipython-input-4-a6cb4f83dcb2> in training_step(self, batch, batch_idx)
    104     def training_step(self, batch, batch_idx):
    105         x, y_true = batch
--> 106         loss, _ = self(x, labels=y_true)
    107         self.log('train_loss', loss)
    108         return loss

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-4-a6cb4f83dcb2> in forward(self, *input, **kwargs)
    100 
    101     def forward(self, *input, **kwargs):
--> 102         return self.model(*input, **kwargs)
    103 
    104     def training_step(self, batch, batch_idx):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-4-a6cb4f83dcb2> in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
     50             output_attentions=output_attentions,
     51             output_hidden_states=output_hidden_states,
---> 52             return_dict=return_dict,
     53         )
     54 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    999             output_attentions=output_attentions,
   1000             output_hidden_states=output_hidden_states,
-> 1001             return_dict=return_dict,
   1002         )
   1003         sequence_output = encoder_outputs[0]

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    587                     encoder_attention_mask,
    588                     past_key_value,
--> 589                     output_attentions,
    590                 )
    591 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    473             head_mask,
    474             output_attentions=output_attentions,
--> 475             past_key_value=self_attn_past_key_value,
    476         )
    477         attention_output = self_attention_outputs[0]

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    406             encoder_attention_mask,
    407             past_key_value,
--> 408             output_attentions,
    409         )
    410         attention_output = self.output(self_outputs[0], hidden_states)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    303 
    304         # Take the dot product between "query" and "key" to get the raw attention scores.
--> 305         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    306 
    307         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":

RuntimeError: CUDA out of memory. Tried to allocate 240.00 MiB (GPU 0; 11.17 GiB total capacity; 10.58 GiB already allocated; 63.81 MiB free; 10.66 GiB reserved in total by PyTorch)

Hard negative mining?

Did you try hard negative mining or some other technique to improve recall and accuracy? For example, I have images similar to the online products or clothing datasets, with about 10,000 identities divided into roughly 30 supercategories, e.g. jeans (300 identities), shoes (400 identities), etc. When I use your code, it is able to distinguish jeans from shoes, but properly distinguishing one shoe from another is problematic. Do you have any tips on how to correctly identify among such a big number of identities that are also clustered into supercategories?
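
(For context, batch-hard mining in the sense of Hermans et al., "In Defense of the Triplet Loss," is one common recipe for this situation; a generic sketch, not part of this repository:)

    import torch

    def batch_hard_triplet_loss(embeddings, labels, margin=0.2):
        dists = torch.cdist(embeddings, embeddings)        # (B, B) pairwise L2
        same = labels.unsqueeze(0) == labels.unsqueeze(1)  # same-class mask
        eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
        # Hardest positive: farthest same-class sample (excluding self).
        hardest_pos = dists.masked_fill(~same | eye, 0).max(dim=1).values
        # Hardest negative: closest different-class sample.
        hardest_neg = dists.masked_fill(same, float('inf')).min(dim=1).values
        return torch.relu(hardest_pos - hardest_neg + margin).mean()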

real world scalability?

Hi,
in db_augmentation and average_query_expansion you are calculating a similarity matrix. But when I want to search, for example, in the cub200_2011 dataset, there are 20k+ reference points and I do not have enough memory for this. In fact, my computer with 16 GB RAM and a GTX 1080 Ti runs out of memory even if I search among only 50 images.

Do you have any suggestions on how to scale this?
Thanks,
T
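
(One generic way to scale this is to compute the similarity matrix in query chunks and keep only the top-k indices per chunk, rather than materializing the full (Q, R) matrix; an illustrative sketch, not part of db_augmentation or average_query_expansion as written:)

    import torch

    def top_k_neighbors(query_feats, ref_feats, k=10, chunk=512):
        results = []
        for start in range(0, query_feats.size(0), chunk):
            sims = query_feats[start:start + chunk] @ ref_feats.t()  # (chunk, R)
            results.append(sims.topk(k, dim=1).indices)  # keep only top-k per row
        return torch.cat(results, dim=0)                 # (Q, k) neighbor indices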

RuntimeError: CUDA out of memory. Tried to allocate

Thanks for your great work on metric learning.
I have a question from running the code.
I ran the training code on the CUB-200 dataset with the following script:

CUDA_VISIBLE_DEVICES='0' python main.py --model inceptionv3 --mode train --dataset-path ../../train_data/CUB_train_test/ --scheduler StepLR --input-size 299 --loss-type angular --model-save-dir ./models --num-classes 32

And I hit the following OOM error:
lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward
self.padding, self.dilation, self.groups)
RuntimeError: CUDA out of memory. Tried to allocate 492.88 MiB (GPU 0; 22.38 GiB total capacity; 13.54 GiB already allocated; 289.06 MiB free; 8.84 MiB cached)

Have you encountered this problem?
Thanks.
