johny-c / pylmnn

Large Margin Nearest Neighbors implementation in python

License: BSD 3-Clause "New" or "Revised" License

Python 100.00%
machine-learning metric-learning nearest-neighbor-search lmnn large-margin-nearest-neighbors

pylmnn's Introduction

PyLMNN

PyLMNN is an implementation of the Large Margin Nearest Neighbor algorithm for metric learning in pure Python.

This implementation closely follows the original MATLAB code by Kilian Weinberger, found at https://bitbucket.org/mlcircus/lmnn. It solves the unconstrained optimisation problem and finds a linear transformation using L-BFGS as the backend optimizer.
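As a rough illustration of what "unconstrained optimisation of a linear transformation" means here, the following is a minimal sketch of a simplified LMNN objective minimised with SciPy's L-BFGS-B. It is not the code used by this package: the names `lmnn_loss`, `targets`, and `push_weight` are hypothetical, the toy data is made up, and the gradient is left to SciPy's numerical approximation to keep the example short.

```python
import numpy as np
from scipy.optimize import minimize


def lmnn_loss(L_flat, X, y, targets, push_weight=0.5):
    """Simplified LMNN objective for a linear map L (illustrative only).

    targets[i] holds the indices of the target neighbors of sample i.
    """
    n_features = X.shape[1]
    L = L_flat.reshape(-1, n_features)
    Z = X @ L.T  # samples in the transformed space
    # Squared Euclidean distances between all transformed pairs
    diff = Z[:, None, :] - Z[None, :, :]
    D = np.sum(diff ** 2, axis=2)

    pull, push = 0.0, 0.0
    for i, neighbors in enumerate(targets):
        different = y != y[i]  # differently labeled samples (potential impostors)
        for j in neighbors:
            # Pull term: keep target neighbors close
            pull += D[i, j]
            # Push term (hinge): impostors should be farther than target j by a unit margin
            push += np.sum(np.maximum(0.0, 1.0 + D[i, j] - D[i, different]))
    return (1.0 - push_weight) * pull + push_weight * push


# Toy two-class data (made up for this sketch)
rng = np.random.RandomState(0)
X = np.vstack([rng.randn(10, 2), rng.randn(10, 2) + 2.0])
y = np.repeat([0, 1], 10)

# Target neighbors: the single nearest same-class sample in the input space
targets = []
for i in range(len(X)):
    same = np.flatnonzero((y == y[i]) & (np.arange(len(X)) != i))
    dists = np.sum((X[same] - X[i]) ** 2, axis=1)
    targets.append(same[np.argsort(dists)[:1]])

L0 = np.eye(2).ravel()  # start from the identity transformation
loss_before = lmnn_loss(L0, X, y, targets)
result = minimize(lmnn_loss, L0, args=(X, y, targets), method='L-BFGS-B')
loss_after = result.fun
print('loss before: {:.4f}, loss after: {:.4f}'.format(loss_before, loss_after))
```

The all-pairs distance matrix and the numerical gradient make this sketch practical only for tiny data sets.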

This package can also find optimal hyper-parameters for LMNN via Bayesian Optimization using the excellent GPyOpt package.

Installation

The code was developed in Python 3.5 under Ubuntu 16.04 and was also tested under Ubuntu 18.04 with Python 3.6. You can clone the repo with:

git clone https://github.com/johny-c/pylmnn.git

or install it via pip:

pip3 install pylmnn

Dependencies

  • numpy>=1.11.2
  • scipy>=0.18.1
  • scikit_learn>=0.18.1

Optional dependencies

If you want to use the hyperparameter optimization module, you should also install:

  • GPy>=1.5.6
  • GPyOpt>=1.0.3

Usage

Here is a minimal use case:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

from pylmnn import LargeMarginNearestNeighbor as LMNN


# Load a data set
X, y = load_iris(return_X_y=True)

# Split in training and testing set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.7, stratify=y, random_state=42)

# Set up the hyperparameters
k_train, k_test, n_components, max_iter = 3, 3, X.shape[1], 180

# Instantiate the metric learner
lmnn = LMNN(n_neighbors=k_train, max_iter=max_iter, n_components=n_components)

# Train the metric learner
lmnn.fit(X_train, y_train)

# Fit the nearest neighbors classifier
knn = KNeighborsClassifier(n_neighbors=k_test)
knn.fit(lmnn.transform(X_train), y_train)

# Compute the k-nearest neighbor test accuracy after applying the learned transformation
lmnn_acc = knn.score(lmnn.transform(X_test), y_test)
print('LMNN accuracy on test set of {} points: {:.4f}'.format(X_test.shape[0], lmnn_acc))
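
To judge whether the learned transformation helps, it can be useful to compare against a k-nearest neighbor classifier fitted on the untransformed features. The sketch below uses only scikit-learn and repeats the same split and `k` as above; the names `knn_baseline` and `baseline_acc` are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Same data set and split as in the example above
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.7, stratify=y, random_state=42)

# Baseline: k-NN in the original feature space, without any metric learning
knn_baseline = KNeighborsClassifier(n_neighbors=3)
knn_baseline.fit(X_train, y_train)
baseline_acc = knn_baseline.score(X_test, y_test)
print('Baseline k-NN accuracy on test set of {} points: {:.4f}'.format(X_test.shape[0], baseline_acc))
```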

You can check the examples directory for a demonstration of how to use the code with different datasets and how to estimate good hyperparameters with Bayesian Optimisation.

Documentation can also be found at http://pylmnn.readthedocs.io/en/latest/.

References

If you use this code in your work, please cite the following publication.

@ARTICLE{weinberger09distance,
    title={Distance metric learning for large margin nearest neighbor classification},
    author={Weinberger, K.Q. and Saul, L.K.},
    journal={The Journal of Machine Learning Research},
    volume={10},
    pages={207--244},
    year={2009},
    publisher={MIT Press}
}

License and Contact

This work is released under the 3-Clause BSD License.

Contact John Chiotellis ✉️ for questions, comments, and bug reports.

pylmnn's People

Contributors

johny-c, luccaportes

pylmnn's Issues

Locality Sensitive Hashing / Approximate Nearest Neighbours Techniques

Does anyone know of any similar implementations that use something like Faiss to improve the performance of the nearest neighbour step of the calculation? If not, is it something that would be reasonable to add to this library?

Something like

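import faiss
import numpy as np


class FaissKNeighbors:
    def __init__(self, k=5):
        self.index = None
        self.y = None
        self.k = k

    def fit(self, X, y):
        # IndexFlatL2 performs an exact brute-force search over L2 distances
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X.astype(np.float32))
        # Store labels as an array so they can be indexed by neighbour indices
        self.y = np.asarray(y)

    def predict(self, X):
        # Look up the k nearest training points for every query row
        distances, indices = self.index.search(X.astype(np.float32), k=self.k)
        votes = self.y[indices]
        # Majority vote over the neighbours' (non-negative integer) labels
        predictions = np.array([np.argmax(np.bincount(x)) for x in votes])
        return predictions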

This class, taken from https://gist.github.com/j-adamczyk/74ee808ffd53cd8545a49f185a908584#file-knn_with_faiss-py, provides an sklearn-like interface to Faiss and should work almost as a drop-in replacement.

Because approximate Faiss indexes return approximate results, this should only be added as an option, e.g. nn_backend='faiss' or nn_backend='sklearn'.
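
One way such an option could be wired up is sketched below. This is only an illustration of the dispatch idea: `make_knn` and the `nn_backend` argument are hypothetical and not part of pylmnn, and the sketch selects a classifier for the final k-NN step rather than changing the neighbour search inside LMNN itself. Faiss is imported lazily so that it stays an optional dependency.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


class FaissKNeighbors:
    """k-NN classifier backed by a Faiss index (adapted from the gist above)."""

    def __init__(self, k=5):
        self.k = k
        self.index = None
        self.y = None

    def fit(self, X, y):
        import faiss  # lazy import: only needed when this backend is used
        X = np.ascontiguousarray(X, dtype=np.float32)
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X)
        self.y = np.asarray(y)
        return self

    def predict(self, X):
        X = np.ascontiguousarray(X, dtype=np.float32)
        _, indices = self.index.search(X, self.k)
        votes = self.y[indices]
        # Majority vote; assumes non-negative integer labels
        return np.array([np.argmax(np.bincount(row)) for row in votes])


def make_knn(n_neighbors=5, nn_backend='sklearn'):
    """Return a k-NN classifier for the requested backend (hypothetical helper)."""
    if nn_backend == 'sklearn':
        return KNeighborsClassifier(n_neighbors=n_neighbors)
    if nn_backend == 'faiss':
        return FaissKNeighbors(k=n_neighbors)
    raise ValueError("nn_backend must be 'sklearn' or 'faiss', got {!r}".format(nn_backend))
```

With this layout, the default `make_knn(3)` never touches Faiss, and `make_knn(3, nn_backend='faiss')` only requires it once `fit` is called.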

Exception caused by missing logger_ key

The "del state['logger_']" line in the __getstate__ method is causing problems because the 'logger_' key does not exist.

Wrapping the "del state['logger_']" line in a try/except appears to alleviate the problem.

try:
    del state['logger_']
except KeyError:
    pass
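
An equivalent and slightly shorter alternative is `dict.pop` with a default, which never raises when the key is absent. The sketch below demonstrates this on a toy class; `Model` and its attributes are hypothetical stand-ins, not pylmnn code.

```python
import logging
import pickle


class Model:
    """Toy stand-in for an estimator that may or may not hold a logger."""

    def __init__(self, with_logger=True):
        self.weights = [1.0, 2.0]
        if with_logger:
            self.logger_ = logging.getLogger('toy')

    def __getstate__(self):
        state = self.__dict__.copy()
        # pop() with a default does nothing if 'logger_' is missing,
        # whereas "del state['logger_']" would raise KeyError
        state.pop('logger_', None)
        return state


# Pickling works whether or not the logger attribute was ever set
restored_a = pickle.loads(pickle.dumps(Model(with_logger=True)))
restored_b = pickle.loads(pickle.dumps(Model(with_logger=False)))
```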

Change in "six" import

Hi, I suggest a small change in line 27 of pylmnn.py, which is:

from sklearn.externals.six import integer_types, string_types

It raises the following warning when executed:

DeprecationWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).
  "(https://pypi.org/project/six/).", DeprecationWarning)

So in order to maintain compatibility with both older and future versions, I propose something like this:

try:
    from six import integer_types, string_types
except ImportError:
    try:
        from sklearn.externals.six import integer_types, string_types
    except ImportError:
        raise ImportError("The module six must be installed or the version of scikit-learn must be < 0.23")

I'll open a pull request in case you think it is a good idea.
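
If only Python 3 needs to be supported, another option is to avoid the dependency altogether: on Python 3, six defines `integer_types` as `(int,)` and `string_types` as `(str,)`. The sketch below falls back to those definitions when six is not installed; whether dropping Python 2 is acceptable for this project is an assumption.

```python
try:
    from six import integer_types, string_types
except ImportError:
    # Python 3 equivalents of the six definitions
    integer_types = (int,)
    string_types = (str,)

# Typical use: validating hyperparameter types
n_neighbors = 3
init = 'pca'
is_valid = isinstance(n_neighbors, integer_types) and isinstance(init, string_types)
```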

Option to save the fit model?

The training step takes a long time. Is there a way to save the model after calling the .fit method? If not, should I submit a PR to enable it?
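
For estimators that follow scikit-learn conventions, a fitted model can usually be saved with the standard library's pickle module. The sketch below uses a scikit-learn `KNeighborsClassifier` as a stand-in; that a fitted LMNN object pickles the same way is an assumption that should be checked against the installed pylmnn version.

```python
import os
import pickle
import tempfile

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Stand-in for a fitted model; replace with the fitted LMNN instance
model = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# Save the fitted model to disk
path = os.path.join(tempfile.mkdtemp(), 'model.pkl')
with open(path, 'wb') as f:
    pickle.dump(model, f)

# Load it back later, without refitting
with open(path, 'rb') as f:
    restored = pickle.load(f)

same_predictions = bool((restored.predict(X) == model.predict(X)).all())
```

Pickle files should only be loaded from trusted sources, and loading under a different library version than the one used for saving may fail.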
