adeelh / pytorch-multi-class-focal-loss

An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, generalized to the multi-class case.

License: MIT License

Python 100.00%
pytorch deep-learning loss-functions classification neural-network retinanet multiclass-classification imbalanced-classes pytorch-implementation implementation-of-research-paper machine-learning

pytorch-multi-class-focal-loss's Introduction

Multi-class Focal Loss

An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, https://arxiv.org/abs/1708.02002, generalized to the multi-class case.

It is essentially an enhancement to cross-entropy loss and is useful for classification tasks with a large class imbalance. It down-weights easy (well-classified) examples so that training focuses on the hard ones.
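
For reference, the RetinaNet paper defines the focal loss for a single example as FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The pure-Python sketch below compares it with plain cross-entropy for an easy, a medium, and a hard example. The helper name focal_loss_single is hypothetical and is not part of this repository.

```python
import math


def focal_loss_single(p_t, gamma=2.0, alpha_t=1.0):
    """Focal loss for one example (hypothetical helper, for illustration only).

    p_t is the predicted probability of the true class, with 0 < p_t <= 1.
    """
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


for p_t in (0.9, 0.5, 0.1):
    ce = -math.log(p_t)           # plain cross-entropy for this example
    fl = focal_loss_single(p_t)   # focal loss with gamma=2, alpha_t=1
    print(f"p_t={p_t}: cross-entropy={ce:.4f}, focal={fl:.4f}, ratio={fl / ce:.2f}")
```

With gamma = 2, the loss of an example predicted at p_t = 0.9 is scaled by (1 - 0.9)^2 = 0.01, while an example at p_t = 0.1 keeps 0.81 of its cross-entropy value. Setting gamma = 0 and alpha_t = 1 recovers plain cross-entropy.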

Usage

  • FocalLoss is an nn.Module and behaves very much like nn.CrossEntropyLoss(), i.e. it

    • supports the reduction and ignore_index params, and
    • is able to work with 2D inputs of shape (N, C) as well as K-dimensional inputs of shape (N, C, d1, d2, ..., dK).
  • Example usage

    focal_loss = FocalLoss(alpha, gamma)
    ...
    inp, targets = batch
    out = model(inp)
    loss = focal_loss(out, targets)
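
To make the shape handling described above concrete, here is a minimal functional sketch of a multi-class focal loss. It is an illustration under stated assumptions, not the repository's code: the name focal_loss_sketch and its exact signature are hypothetical, alpha is assumed to be a 1-D tensor with one weight per class (or None), and the 'mean' reduction is taken to be a plain average over the non-ignored elements.

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits, targets, alpha=None, gamma=2.0,
                      reduction="mean", ignore_index=-100):
    """Hypothetical helper illustrating a multi-class focal loss."""
    if logits.ndim > 2:
        # (N, C, d1, ..., dK) -> (N * d1 * ... * dK, C)
        num_classes = logits.shape[1]
        logits = logits.movedim(1, -1).reshape(-1, num_classes)
        # (N, d1, ..., dK) -> (N * d1 * ... * dK,)
        targets = targets.reshape(-1)

    # Drop positions labelled with ignore_index.
    keep = targets != ignore_index
    logits, targets = logits[keep], targets[keep]
    if targets.numel() == 0:
        # Return a 0-dim tensor (not a Python float) when nothing is left.
        return logits.new_zeros(())

    log_p = F.log_softmax(logits, dim=-1)
    # Per-example alpha-weighted cross-entropy: -alpha_t * log(p_t).
    ce = F.nll_loss(log_p, targets, weight=alpha, reduction="none")
    # log(p_t): log-probability assigned to the true class.
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    # Focal term (1 - p_t)^gamma shrinks the loss of easy examples.
    loss = (1.0 - log_pt.exp()) ** gamma * ce

    if reduction == "mean":
        return loss.mean()  # assumption: plain mean over kept elements
    if reduction == "sum":
        return loss.sum()
    return loss


# 2D input of shape (N, C)
x, y = torch.randn(8, 3), torch.randint(0, 3, (8,))
print(focal_loss_sketch(x, y))

# K-dimensional input of shape (N, C, d1, d2), e.g. per-pixel classification
xk, yk = torch.randn(4, 3, 5, 5), torch.randint(0, 3, (4, 5, 5))
print(focal_loss_sketch(xk, yk, alpha=torch.tensor([0.2, 0.3, 0.5])))
```

In this sketch, gamma = 0 with alpha = None reduces to F.cross_entropy with the default 'mean' reduction, which is a convenient sanity check.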

Loading through torch.hub

This repo supports importing modules through torch.hub. FocalLoss can be easily imported into your code via, for example:

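import torch

# Load the FocalLoss entry point from the repository via torch.hub.
focal_loss = torch.hub.load(
	'adeelh/pytorch-multi-class-focal-loss',
	model='FocalLoss',
	alpha=torch.tensor([.75, .25]),  # one weight per class
	gamma=2,
	reduction='mean',
	force_reload=False
)
# Random logits for 10 examples over 2 classes, with random binary targets.
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)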

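Or, via the focal_loss entry point, which here takes alpha as a plain list together with device and dtype: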

focal_loss = torch.hub.load(
	'adeelh/pytorch-multi-class-focal-loss',
	model='focal_loss',
	alpha=[.75, .25],
	gamma=2,
	reduction='mean',
	device='cpu',
	dtype=torch.float32,
	force_reload=False
)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)

pytorch-multi-class-focal-loss's Issues

Different return types

Hi there!

I think one line of the code can cause a ValueError (or a similar error), especially when used with automatic optimization, as in Lightning.
In particular, line 69 should return a torch.Tensor rather than a float, as the (correct) type annotation of the forward function indicates.

Usage on the focal loss

I tried to replace F.cross_entropy(out, target) with focal_loss(out, target); however, there is a problem with the weight tensor.
The implementation calls ce = self.nll_loss(log_p, y), which fails with:
RuntimeError: weight tensor should be defined either for all _ classes or no classes but got weight tensor of shape:[2]

Is the code perhaps not compatible with the latest version of PyTorch?

Focal loss for more than two classes

Hi,

I would like to know whether this implementation is suitable when the problem has more than two classes. Suppose I have four classes and the dataset is imbalanced, with the distribution [90%, 6%, 2%, 2%]. Should I then define an array of weights alpha, one per class? For example, alpha=torch.tensor([0.1, 0.94, 0.98, 0.98]), with the weight assigned to each class inversely related to its frequency? Is that correct?

Thanks!
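
The arithmetic behind the weights in this question can be checked with a short sketch. It shows two common heuristics for deriving per-class weights from class frequencies: the complement (1 - frequency), which reproduces the values quoted above, and normalized inverse frequency. Both are assumptions for illustration; the source does not state which weighting, if any, is recommended.

```python
# Class frequencies from the question: 90%, 6%, 2%, 2%.
freqs = [0.90, 0.06, 0.02, 0.02]

# Heuristic 1 (assumption): complement of the frequency, rounded to 2 decimals.
alpha_complement = [round(1.0 - f, 2) for f in freqs]

# Heuristic 2 (assumption): inverse frequency, normalized to sum to 1.
inverse = [1.0 / f for f in freqs]
total = sum(inverse)
alpha_inverse = [w / total for w in inverse]

print(alpha_complement)
print([round(a, 4) for a in alpha_inverse])
```

Note that the complement weights differ little among the three rare classes (0.94 vs. 0.98), whereas strict inverse-frequency weights give a class at 2% three times the weight of a class at 6%. Either list would need to be wrapped in torch.tensor(...) with one entry per class before being passed as alpha.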
