
Multi class (focal-tversky-unet issue, closed, 20 comments)

nabsabraham commented on August 16, 2024
Multi class


Comments (20)

phernst commented on August 16, 2024

For a C-class problem, you need C output channels. The activation should be softmax (but sigmoid should still work, too). Try this modification of the loss function:

import keras.backend as K

def class_tversky(y_true, y_pred):
    smooth = 1

    # move the class axis to the front: (batch, H, W, C) -> (C, H, W, batch),
    # so that batch_flatten below keeps one row per class
    y_true = K.permute_dimensions(y_true, (3,1,2,0))
    y_pred = K.permute_dimensions(y_pred, (3,1,2,0))

    # flatten everything except the class axis and compute the Tversky terms per class
    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1)
    false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1)
    alpha = 0.7
    return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth)

def focal_tversky_loss(y_true, y_pred):
    pt_1 = class_tversky(y_true, y_pred)  # one Tversky index per class
    gamma = 0.75
    return K.sum(K.pow((1-pt_1), gamma))  # sum of the per-class focal terms

The code assumes 2D images and a channels-last layout (i.e. the shape of y_true and y_pred should be [num_batches, height, width, num_channels]).
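
For completeness, here is roughly how this loss would be wired into a model (just a sketch: the tiny stand-in network, input size and optimizer below are placeholders, not the repository's actual architecture):

from keras.layers import Input, Conv2D
from keras.models import Model

# toy stand-in for the real U-Net; only the final layer matters here:
# C = 3 output channels with softmax activation
inputs = Input((128, 128, 1))
x = Conv2D(16, (3, 3), padding='same', activation='relu')(inputs)
outputs = Conv2D(3, (1, 1), activation='softmax')(x)

model = Model(inputs, outputs)
model.compile(optimizer='adam', loss=focal_tversky_loss)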


DecentMakeover commented on August 16, 2024

@nabsabraham Thanks for the comments, I'll check these suggestions.


phernst commented on August 16, 2024

Basically, if you flatten into 1 dimension, your classes will not be weighted equally any more (which is already true for the Dice loss, btw).

Short example: imagine a 2-class problem. Let's say the object in class 1 is really large and the object in class 2 is only one pixel. Assume there is only one wrong pixel in your prediction compared to the ground truth. Let this pixel be the one pixel of class 2 which was falsely assigned to class 1.

Now, if you calculate the Dice on the flattened output, you will still be very close to 1.
If instead you calculate the Dice per class and take the mean, you will get a value around 0.5.

Note that I used the sum instead of the mean in the code above. For the optimizing task, this is equivalent but doesn't need the division by a constant (i.e. the number of classes).
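
A quick numpy check of that example (hypothetical numbers: 1000 pixels, 999 of class 1, one of class 2, and a prediction that assigns everything to class 1):

import numpy as np

def dice(t, p, eps=1e-7):
    return (2 * np.sum(t * p) + eps) / (np.sum(t) + np.sum(p) + eps)

# one-hot ground truth: 999 pixels of class 1, a single pixel of class 2
y_true = np.zeros((1000, 2)); y_true[:999, 0] = 1; y_true[999, 1] = 1
# prediction: every pixel assigned to class 1, so the class-2 pixel is missed
y_pred = np.zeros((1000, 2)); y_pred[:, 0] = 1

print(dice(y_true, y_pred))  # ~0.999 when flattened over both classes
print(np.mean([dice(y_true[:, c], y_pred[:, c]) for c in range(2)]))  # ~0.5 per-class mean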


nabsabraham commented on August 16, 2024

Thank you for your question @shihlun1208 and your super clean code @phernst :)


DecentMakeover commented on August 16, 2024

@phernst , @nabsabraham

I am working with 3D data; will this work?

def class_tversky(y_true, y_pred):
    smooth = 1

    y_true = K.permute_dimensions(y_true, (1,2,3,4,0))
    y_pred = K.permute_dimensions(y_pred, (1,2,3,4,0))

    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1)
    false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1)
    alpha = 0.7
    return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth)

Thanks


phernst commented on August 16, 2024

The permutation should be (4,1,2,3,0). The reason behind this is that you want to keep the classes (here: channels at dimension 4) as "batches". Basically, every permutation that puts dimension 4 to the first place will work, e.g. (4,1,3,2,0), (4,3,2,1,0), ...
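
Putting that correction together, the 3D version of the earlier snippet would read something like this (untested sketch, assuming channels-last inputs of shape [num_batches, depth, height, width, num_channels]):

import keras.backend as K

def class_tversky_3d(y_true, y_pred):
    smooth = 1

    # put the class axis (dimension 4) first so batch_flatten keeps one row per class
    y_true = K.permute_dimensions(y_true, (4,1,2,3,0))
    y_pred = K.permute_dimensions(y_pred, (4,1,2,3,0))

    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1-y_pred_pos), 1)
    false_pos = K.sum((1-y_true_pos)*y_pred_pos, 1)
    alpha = 0.7
    return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth)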


DecentMakeover commented on August 16, 2024

Hmmm, okay, let me check.

And just to confirm: the lower the loss, the better, right?

Also, a side question: I am using a custom U-Net model to perform instance segmentation with 26 classes. Currently the results are very poor; do you have any suggestions?


nabsabraham commented on August 16, 2024

The Tversky function outputs a value between 0 and 1, and the TL function outputs 1-tversky(*), so the closer the loss is to 0, the better. (FYI, your ground truths must be one-hot encoded.)

I haven't ever done instance segmentation, but I believe cross-entropy is a good starting point. The TL/FTL require some parameter tuning, so I would start with CE, then maybe weighted CE, and then try DL/TL/FTL. Also, with CE, add an extra class for the background; this always seems to improve my results on semantic segmentation.
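
For reference, one-hot encoding integer label masks can be done with Keras' to_categorical (a sketch; the 26 object classes plus one background class below are just the numbers from this thread):

import numpy as np
from keras.utils import to_categorical

# hypothetical integer label mask: 0 = background, 1..26 = object classes
labels = np.random.randint(0, 27, size=(4, 128, 128))

# one-hot encode to (num_batches, height, width, num_classes) for the loss above
y_true = to_categorical(labels, num_classes=27)
print(y_true.shape)  # (4, 128, 128, 27)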


DecentMakeover commented on August 16, 2024

@nabsabraham Hi, since you mention in the readme that this network is good for small lesions, how do you think it would perform on kidney or vertebra segmentation?

Can your model be extended to a 3D model? I mean, is the extension not too complicated?

Thanks


ledakk commented on August 16, 2024

(Quoting phernst's multi-class loss function and code from above.)

Why can't the original implementation be applied to multi-class problems?
If the labels are converted to one-hot encoding, their values are also in the range [0, 1].
(Sorry, my English is not very good.)


phernst commented on August 16, 2024

Could you explain a bit more in detail? Do you mean a pixel-wise one-hot vector like [0,0,1] (so, for a 3-class problem, this one would encode the last class), or a single number per pixel (so, for a 3-class problem, 0 encodes class 1, 0.5 encodes class 2 and 1 encodes class 3)?


ledakk commented on August 16, 2024


Thank you for your reply ^_^
I mean the first one, the pixel-wise one-hot vector.
In my opinion, in a multi-class problem, y_pred and y_true can also be flattened into one dimension, so why can't a multi-class problem directly apply the Tversky loss the way a binary problem can? What is the difference between them?


ledakk commented on August 16, 2024

(Quoting phernst's explanation of per-class weighting from above.)

Cool!!!
I understand completely now. Thank you for your patience in explaining this to me, and I am sorry for my poor wording.


phernst commented on August 16, 2024

you're welcome :)


azizasaber commented on August 16, 2024

(Quoting phernst's multi-class loss function and code from above.)

Thanks for your great answer. However, I am applying it to a 3-class semantic segmentation problem, and I am getting NaN as the loss every time. I think the gradient is exploding, but I am not sure how to fix it.


phernst commented on August 16, 2024

Yes, gradient explosion could be one problem; however, I have never experienced that with U-Nets and these types of loss functions (but it depends on the data, of course).
I don't know what your data looks like, but if your target is not one-hot encoded or has values smaller than 0 or greater than 1, the result of 1-pt_1 might be negative, which would result in NaN values after applying the pow function. That is at least one other problem I could imagine. Can you check that, please?
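
A quick way to run that check on the targets (a sketch; the shapes and the hypothetical good/bad arrays below are only for illustration):

import numpy as np

def check_targets(y_true):
    # values must lie in [0, 1] and every pixel must be one-hot across the class axis
    in_range = (y_true.min() >= 0) and (y_true.max() <= 1)
    one_hot = np.allclose(y_true.sum(axis=-1), 1.0) and set(np.unique(y_true)) <= {0.0, 1.0}
    return in_range, one_hot

# hypothetical batch of shape (num_batches, height, width, num_classes)
good = np.eye(3)[np.random.randint(0, 3, size=(2, 64, 64))].astype(np.float32)
bad = good * 2.0  # values outside [0, 1] make 1-pt_1 negative -> NaN after pow

print(check_targets(good))  # (True, True)
print(check_targets(bad))   # (False, False)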


azizasaber commented on August 16, 2024

@phernst thanks a lot for your quick response. My network is a multi-task network with three decoders, and I am using this loss for the first task, which is supposed to do 3-class segmentation (0, 1, 2). The vectors are one-hot encoded (softmax is used in the last layer, in the form of Conv2D(3, (1,1), activation='softmax', name='segmap')). My dataset is a pancreas dataset where the background is the majority class, the second class is much smaller than the background, and the third class (tumor) occupies a very small portion of the dataset. Right now I don't have access to a GPU, but I will check the value of 1-pt_1 once I get access and post the result here.


djsherm commented on August 16, 2024

Hello, I am trying to determine whether I will be able to use the loss function for a project that I am working on. I am incredibly confused about the K.permute_dimensions function. I am unsure why you included the tuple after the y_true vector. Is it the order of dimensions to take, or is it something else? The documentation is unclear. Thank you for your help.


phernst commented on August 16, 2024

Yes, the tuple determines the order of dimensions. However, K.permute_dimensions is no longer available in TensorFlow 2. Since Keras was merged into TensorFlow, most of the backend functions became redundant and were removed. It should be possible, though, to replace K.permute_dimensions with tf.transpose and get the same functionality. See also: https://www.tensorflow.org/api_docs/python/tf/transpose
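
For example, the 2D permutation from earlier maps directly onto tf.transpose (a small sketch with placeholder shapes):

import tensorflow as tf

# equivalent of K.permute_dimensions(x, (3, 1, 2, 0)) in TensorFlow 2
x = tf.zeros((8, 128, 128, 3))             # (num_batches, height, width, num_classes)
x_t = tf.transpose(x, perm=[3, 1, 2, 0])   # -> (num_classes, height, width, num_batches)
print(x_t.shape)                           # (3, 128, 128, 8)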


phernst commented on August 16, 2024

I haven't tested the code yet, but the TensorFlow 2 implementation should look something like this:

import tensorflow as tf

def class_tversky(y_true: tf.Tensor, y_pred: tf.Tensor):
    smooth = 1

    # sum over batch, height and width (channels last), keeping one value per class
    true_pos = tf.reduce_sum(y_true * y_pred, axis=(0, 1, 2))
    false_neg = tf.reduce_sum(y_true * (1-y_pred), axis=(0, 1, 2))
    false_pos = tf.reduce_sum((1-y_true) * y_pred, axis=(0, 1, 2))
    alpha = 0.7
    return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth)

def focal_tversky_loss(y_true: tf.Tensor, y_pred: tf.Tensor):
    pt_1 = class_tversky(y_true, y_pred)  # one Tversky index per class
    gamma = 0.75
    return tf.reduce_sum((1-pt_1)**gamma)  # sum of the per-class focal terms

