ximilar-com / tf-metric-learning

Minimalistic TensorFlow2+ library for deep metric/similarity learning, with loss functions, miners, and utilities such as an embedding projector.

License: MIT License

tf-metric-learning

TensorFlow 2.2 Python 3.6

Overview

Minimalistic open-source library for metric learning, built on TensorFlow2, TF-Addons, NumPy, OpenCV (cv2) and Annoy. This repository contains a TensorFlow2+/tf.keras implementation of some of the loss functions and miners. It was inspired by pytorch-metric-learning.

Installation

Prerequisites:

pip install tensorflow
pip install tensorflow-addons
pip install annoy
pip install opencv-contrib-python

This library:

pip install tf-metric-learning
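
After installing, it can be handy to confirm that every dependency is importable. The snippet below is a small sketch using only the standard library; `missing_modules` and `REQUIRED` are hypothetical names introduced here, and the list assumes the usual import names for the packages above (`cv2` for opencv-contrib-python, `tf_metric_learning` for this library).

```python
import importlib.util


def missing_modules(module_names):
    """Return the names from module_names that cannot be found on the import path."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Import names assumed for the packages installed above.
REQUIRED = ["tensorflow", "tensorflow_addons", "annoy", "cv2", "tf_metric_learning"]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("missing:", ", ".join(missing) if missing else "none")
```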

Features

  • All the loss functions are implemented as tf.keras.layers.Layer
  • Callbacks for computing recall and visualizing embeddings in the TensorBoard Projector
  • Simple mining mechanism with Annoy
  • Combine multiple loss functions/layers in one model
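
To illustrate what "loss as a layer" can look like, here is a toy sketch. It is not taken from this library: `PairwiseContrastiveLoss` is a hypothetical layer written for illustration, and the dictionary keys "embeddings" and "labels" are assumptions. It only shows the general Keras pattern of registering a loss with `add_loss` inside `call`.

```python
import tensorflow as tf


class PairwiseContrastiveLoss(tf.keras.layers.Layer):
    """Hypothetical example layer: contrastive loss over all pairs in a batch."""

    def __init__(self, margin=1.0, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, inputs):
        embeddings = inputs["embeddings"]               # shape (batch, dim)
        labels = tf.reshape(inputs["labels"], [-1, 1])  # shape (batch, 1)

        # Pairwise squared Euclidean distances, shape (batch, batch).
        diff = tf.expand_dims(embeddings, 1) - tf.expand_dims(embeddings, 0)
        sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)
        dist = tf.sqrt(sq_dist + 1e-12)

        # 1.0 where two samples share a label, 0.0 otherwise.
        same = tf.cast(tf.equal(labels, tf.transpose(labels)), embeddings.dtype)

        # Pull same-label pairs together, push other pairs beyond the margin.
        pull = same * sq_dist
        push = (1.0 - same) * tf.square(tf.maximum(self.margin - dist, 0.0))

        # Register the loss on the layer; the embeddings pass through unchanged.
        self.add_loss(tf.reduce_mean(pull + push))
        return embeddings


layer = PairwiseContrastiveLoss(margin=1.0)
layer({
    "embeddings": tf.constant([[0.0, 0.0], [3.0, 4.0]]),
    "labels": tf.constant([[0.0], [0.0]]),
})
loss_value = float(layer.losses[0])
print(loss_value)
```

Because the loss is attached through `add_loss`, a model containing such a layer can be compiled with only an optimizer, and Keras sums the losses registered by every layer in the model.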

Open-source repos

This library contains code adapted and modified from the following great open-source repos; without them this would not be possible (THANK YOU):

TODO

  • Discriminative layer optimizer (different learning rates) for losses with weights (Proxy, SoftTriple, ...) [TODO]
  • Some tests 😇
  • Improve and add more miners

Examples

import tensorflow as tf
import numpy as np

from tf_metric_learning.layers import SoftTripleLoss
from tf_metric_learning.utils.constants import EMBEDDINGS, LABELS

num_class, num_centers, embedding_size = 10, 2, 256

# The loss layer takes both the embeddings and their labels as model inputs
inputs = tf.keras.Input(shape=(embedding_size,), name=EMBEDDINGS)
input_label = tf.keras.layers.Input(shape=(1,), name=LABELS)
output_tensor = SoftTripleLoss(num_class, num_centers, embedding_size)({EMBEDDINGS: inputs, LABELS: input_label})

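# No loss is passed to compile(): the loss layer supplies the training loss
model = tf.keras.Model(inputs=[inputs, input_label], outputs=output_tensor)
model.compile(optimizer="adam")

# Dummy data: 1000 all-zero embeddings, all labelled as class 0
data = {EMBEDDINGS: np.zeros((1000, embedding_size), dtype=np.float32), LABELS: np.zeros(1000, dtype=np.float32)}
model.fit(data, None, epochs=10, batch_size=10)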

More complex scenarios:

Features

Loss functions

Miners

  • MaximumLossMiner [TODO]
  • TripletAnnoyMiner ✅

Evaluators

  • AnnoyEvaluator Callback: evaluates Recall@K. You will need to install Spotify's annoy library.

import tensorflow as tf
from tf_metric_learning.utils.recall import AnnoyEvaluatorCallback

# base_network, test_images, test_labels, divide and embedding_size are assumed to be defined beforehand
evaluator = AnnoyEvaluatorCallback(
    base_network,
    {"images": test_images[:divide], "labels": test_labels[:divide]}, # images stored to index
    {"images": test_images[divide:], "labels": test_labels[divide:]}, # images to query
    normalize_fn=lambda images: images / 255.0,
    normalize_eb=True,
    eb_size=embedding_size,
    freq=1,
)
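
For readers unfamiliar with the metric, Recall@K is taken here in its usual metric-learning sense: the fraction of query items that have at least one same-label item among their K nearest neighbours in the index. Whether the callback computes exactly this is an assumption. The sketch below uses exact brute-force search in pure Python as a stand-in for Annoy's approximate search, and `recall_at_k` and the toy data are hypothetical.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def recall_at_k(index_vecs, index_labels, query_vecs, query_labels, k=1):
    """Fraction of queries with a same-label item among their k nearest index items."""
    hits = 0
    for q_vec, q_label in zip(query_vecs, query_labels):
        # Rank all indexed items by distance to the query (exact search).
        ranked = sorted(
            range(len(index_vecs)),
            key=lambda i: squared_distance(q_vec, index_vecs[i]),
        )
        # A query counts as a hit if any of its top-k neighbours shares its label.
        if any(index_labels[i] == q_label for i in ranked[:k]):
            hits += 1
    return hits / len(query_vecs)


# Toy embeddings: three indexed items and two queries.
index_vecs = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]
index_labels = [0, 1, 1]
query_vecs = [[0.1, 0.0], [0.4, 0.0]]
query_labels = [0, 1]

print(recall_at_k(index_vecs, index_labels, query_vecs, query_labels, k=1))  # 0.5
print(recall_at_k(index_vecs, index_labels, query_vecs, query_labels, k=2))  # 1.0
```

The second query's nearest neighbour has the wrong label, so it only counts as a hit once K reaches 2.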

Visualizations

  • TensorBoard Projector Callback

import numpy as np
import tensorflow as tf
from tf_metric_learning.utils.projector import TBProjectorCallback

def normalize_images(images):
    return images/255.0

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
...

projector = TBProjectorCallback(
    base_model,
    "tb/projector",
    test_images, # list of images
    np.squeeze(test_labels),
    normalize_eb=True,
    normalize_fn=normalize_images
)
