facecoresetnet's Introduction

FaceCoresetNet: Differentiable Coresets for Face Set Recognition

Official GitHub repository for Differentiable Coresets for Face Set Recognition.

Published at AAAI 2024.

Abstract: In set-based face recognition, we aim to compute the most discriminative descriptor from an unbounded set of images and videos showing a single person. A discriminative descriptor balances two policies when aggregating information from a given set. The first is a quality-based policy: emphasizing high-quality and down-weighting low-quality images. The second is a diversity-based policy: emphasizing unique images in the set and down-weighting multiple occurrences of similar images, as found in video clips, which can overwhelm the set representation. This work frames face-set representation as a differentiable coreset selection problem. Our model learns how to select a small coreset of the input set that balances quality and diversity policies using a learned metric parameterized by the face quality, optimized end-to-end. The selection process is a differentiable farthest-point sampling (FPS), realized by approximating the non-differentiable argmax operation with differentiable sampling from the Gumbel-Softmax distribution of distances. The small coreset is later used as queries in a self- and cross-attention architecture to enrich the descriptor with information from the whole set. Our model is order-invariant and linear in the input set size. We set a new SOTA for set face verification on the IJB-B and IJB-C datasets. Our code is publicly available at https://github.com/ligaripash/FaceCoresetNet/.
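
To make the selection step in the abstract more concrete, here is a minimal pure-Python sketch of farthest-point sampling in which the hard argmax is replaced by a Gumbel-Softmax relaxation. It is an illustration only, not the code from this repository: it assumes plain Euclidean distance instead of the learned quality-aware metric, it seeds the coreset with the first point, and the names gumbel_softmax and soft_farthest_point_sampling are hypothetical.

```python
import math
import random


def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for logit in logits:
        u = max(rng.random(), 1e-12)        # keep u strictly inside (0, 1)
        gumbel = -math.log(-math.log(u))    # Gumbel(0, 1) sample
        noisy.append((logit + gumbel) / tau)
    peak = max(noisy)                       # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def soft_farthest_point_sampling(points, k, tau=0.1, seed=0):
    """Pick k soft points; each step favors the point farthest from the picks so far."""
    rng = random.Random(seed)
    dim = len(points[0])
    selected = [list(points[0])]            # assumption: start from the first point
    min_dist = [math.dist(p, selected[0]) for p in points]
    for _ in range(k - 1):
        # Distances act as logits, so far-away points get most of the weight.
        weights = gumbel_softmax(min_dist, tau, rng)
        # Soft pick: a convex combination of all points instead of a hard argmax.
        pick = [sum(w * p[d] for w, p in zip(weights, points)) for d in range(dim)]
        selected.append(pick)
        min_dist = [min(m, math.dist(p, pick)) for m, p in zip(min_dist, points)]
    return selected


points = [(0.0, 0.0), (1.0, 0.0), (100.0, 0.0)]
print(soft_farthest_point_sampling(points, k=2, tau=0.1))
```

With a small temperature tau the weights collapse toward a one-hot vector, so the second pick lands essentially on the farthest point, (100, 0). A larger tau spreads the weight over several points, which is what keeps the selection smooth enough to train through in an autograd framework.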

Installation and Preparation

1. Environment

We use PyTorch (1.10.0) in our experiments.

pip install -r requirements.txt

2. Pretrained Models

We release the FaceCoresetNet model, pretrained on top of an AdaFace backbone. The backbone is trained on the WebFace4M dataset, and FaceCoresetNet is trained on a subset of WebFace4M.

Place FaceCoresetNet.pth under pretrained_models/

pretrained_models/
├── FaceCoresetNet.ckpt                         

Evaluation

IJBB and IJBC

For evaluation with IJBB/IJBC you may download the related files from.

Place the downloaded files in <DATA_ROOT>, i.e.:

<DATA_ROOT>
└── IJB
    ├── aligned (only needed during training)
    └── insightface_helper
        ├── ijb
        └── meta        
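
Before launching a run, it can help to confirm that the layout above is in place. The following is a small sketch, not part of the repository: the helper name missing_dirs and the placeholder path are hypothetical, and the directory names are copied from the tree above.

```python
from pathlib import Path

# Sub-directories taken from the tree above; "aligned" is only needed for training.
REQUIRED = ["IJB/insightface_helper/ijb", "IJB/insightface_helper/meta"]
TRAINING_ONLY = ["IJB/aligned"]


def missing_dirs(data_root, training=False):
    """Return the expected sub-directories that do not exist under data_root."""
    expected = REQUIRED + (TRAINING_ONLY if training else [])
    root = Path(data_root)
    return [rel for rel in expected if not (root / rel).is_dir()]


# Hypothetical placeholder path; replace it with your own DATA_ROOT.
print(missing_dirs("/path/to/DATA_ROOT", training=True))
```

An empty list means every expected directory was found.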

For faster validation please download the IJB AdaFace backbone features:

Place both files in the validation_IJBB_IJBC directory:

validation_IJBB_IJBC/
├── IJBB-AdaFace-Backbone-Features.pickle
└── IJBC-AdaFace-Backbone-Features.pickle                  
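
A quick sanity check that both feature files are present and loadable could look like the sketch below. It makes no assumption about what the pickles contain and only reports the top-level type and length; the helper names are hypothetical. If the files hold NumPy or PyTorch objects, those packages must be installed for unpickling to succeed.

```python
import pickle
from pathlib import Path

# File names copied from the listing above.
FEATURE_FILES = [
    "IJBB-AdaFace-Backbone-Features.pickle",
    "IJBC-AdaFace-Backbone-Features.pickle",
]


def describe_pickle(path):
    """Load one pickle and summarize its top-level object without assuming a schema."""
    with open(path, "rb") as handle:
        obj = pickle.load(handle)  # only unpickle files from a source you trust
    size = len(obj) if hasattr(obj, "__len__") else None
    return type(obj).__name__, size


def describe_features(directory="validation_IJBB_IJBC"):
    """Map each expected file name to its summary, or to None when the file is absent."""
    base = Path(directory)
    return {
        name: describe_pickle(base / name) if (base / name).is_file() else None
        for name in FEATURE_FILES
    }


print(describe_features())
```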

To run the evaluation:

1. Update eval.sh with your DATA_ROOT.
2. Run bash ./eval.sh

Training from scratch

WebFace4M Subset (as in paper)

The model was trained on a WebFace4M subset, which can be downloaded here: AdaFace4M_subset.

  • Get pretrained face recognition model backbone

For training script, refer to

cd FaceCoresetNet
bash ./train.sh  # DATA_ROOT has to be specified. 

facecoresetnet's People

Contributors

ligaripash, yosikeller

facecoresetnet's Issues

Cross-attention error for unaligned q and k, v

Epoch 0: 0%| | 1/34142 [02:11<1247:13:53, 131.51s/it, loss=39.7, v_num=bl_0]
Traceback (most recent call last):
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 90, in launch
return function(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 621, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1058, in _run
results = self._run_stage()
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1137, in _run_stage
self._run_train()
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1160, in _run_train
self.fit_loop.run()
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 214, in advance
batch_output = self.batch_loop.run(kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(optimizers, kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 247, in _run_optimization
self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 357, in _optimizer_step
self.trainer._call_lightning_module_hook(
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1302, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/core/module.py", line 1661, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 169, in step
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 281, in optimizer_step
optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 234, in optimizer_step
return self.precision_plugin.optimizer_step(
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 121, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
return wrapped(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/optim/optimizer.py", line 88, in wrapper
return func(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/optim/adamw.py", line 92, in step
loss = closure()
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 107, in _wrap_closure
closure_result = closure()
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
self._result = self.closure(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
step_output = self._step_fn()
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1440, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 352, in training_step
return self.model(*args, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 98, in forward
output = self._forward_module.training_step(*inputs, **kwargs)
File "/workspace/FaceCoresetNet/train_val_template.py", line 152, in training_step
cos_thetas, norms, embeddings, labels = self.forward(images, labels)
File "/workspace/FaceCoresetNet/train_val_template.py", line 127, in forward
aggregate_embeddings, aggregate_norms, FPS_sample = self.aggregate_model(embeddings, norms, only_FPS)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/workspace/FaceCoresetNet/net_template.py", line 459, in forward
aggregated_feature, FPS_sample = self.aggregate_fps_with_norm_priority(template_features, template_norms, only_FPS)
File "/workspace/FaceCoresetNet/net_template.py", line 449, in aggregate_fps_with_norm_priority
delta = self.decoder_layer1(core_template, norm_encoding_core_template, template_features, norm_encoding_template)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/workspace/FaceCoresetNet/net_template.py", line 152, in forward
q = self.norm2(core_template + self._my_mha_block1(q, k, v))
File "/workspace/FaceCoresetNet/net_template.py", line 177, in _my_mha_block1
x = self.multihead_attn1(q, k, v,
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 1003, in forward
attn_output, attn_output_weights = F.multi_head_attention_forward(
File "/opt/conda/envs/facecoresetnet/lib/python3.8/site-packages/torch/nn/functional.py", line 5044, in multi_head_attention_forward
k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
RuntimeError: shape '[16, 40, 64]' is invalid for input of size 139264
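
The numbers in the final RuntimeError can be unpacked with a little arithmetic. The sketch below only uses values printed in the traceback. It assumes that k has shape (seq_len, key_batch, embed_dim) with embed_dim = num_heads * head_dim, and that bsz is read from the query, as the failing view() call suggests. The number of heads is not shown in the log, so several candidates are listed.

```python
# Numbers copied from the RuntimeError above.
seq_len, bsz_times_heads, head_dim = 16, 40, 64
actual_numel = 139264

requested_numel = seq_len * bsz_times_heads * head_dim
print("view() needs", requested_numel, "elements, tensor has", actual_numel)


def candidate_batches(num_heads_options=(1, 2, 4, 8)):
    """List (num_heads, query_batch, key_batch) triples consistent with the log."""
    rows = []
    for num_heads in num_heads_options:
        embed_dim = num_heads * head_dim
        if bsz_times_heads % num_heads or actual_numel % (seq_len * embed_dim):
            continue  # this head count cannot explain the numbers
        query_batch = bsz_times_heads // num_heads
        key_batch = actual_numel // (seq_len * embed_dim)
        rows.append((num_heads, query_batch, key_batch))
    return rows


for num_heads, query_batch, key_batch in candidate_batches():
    print(f"num_heads={num_heads}: query batch {query_batch}, key batch {key_batch}")
```

Under these assumptions the reshape asks for 40960 elements while the key tensor holds 139264, and in every candidate the key batch is 3.4 times the query batch (for example 17 versus 5 with 8 heads). That points to q and k/v arriving with different batch sizes, which matches the issue title.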
