Comments (20)
For a `C`-class problem, you need to have `C` output channels. The activation should be softmax (but sigmoid should still work, too). Try this modification of the loss function:
```python
import keras.backend as K

def class_tversky(y_true, y_pred):
    smooth = 1
    y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
    y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))
    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1)
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
    alpha = 0.7
    return (true_pos + smooth) / (true_pos + alpha*false_neg + (1 - alpha)*false_pos + smooth)

def focal_tversky_loss(y_true, y_pred):
    pt_1 = class_tversky(y_true, y_pred)
    gamma = 0.75
    return K.sum(K.pow((1 - pt_1), gamma))
```
The code assumes 2D images and channels last (i.e. the shape of `y_true` and `y_pred` should be `[num_batches, height, width, num_channels]`).
from focal-tversky-unet.
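If it helps, here is a minimal, hedged sanity check of the loss above. It restates the same functions against `tf.keras.backend` (which still exposes `permute_dimensions` and `batch_flatten`) so that it runs eagerly under current TensorFlow; a perfect prediction should drive the loss to 0:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def class_tversky(y_true, y_pred):
    smooth = 1.0
    # move the class axis to the front so batch_flatten keeps one row per class
    y_true = K.permute_dimensions(y_true, (3, 1, 2, 0))
    y_pred = K.permute_dimensions(y_pred, (3, 1, 2, 0))
    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1)
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
    alpha = 0.7
    return (true_pos + smooth) / (true_pos + alpha*false_neg + (1 - alpha)*false_pos + smooth)

def focal_tversky_loss(y_true, y_pred):
    gamma = 0.75
    return K.sum(K.pow(1 - class_tversky(y_true, y_pred), gamma))

# sanity check with hypothetical data: predicting the ground truth exactly
# should give a loss of (almost) zero
y = np.zeros((2, 8, 8, 3), dtype=np.float32)
y[..., 0] = 1.0  # every pixel belongs to class 0
loss = float(focal_tversky_loss(tf.constant(y), tf.constant(y)))
print(loss)  # 0.0
```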
@nabsabraham Thanks for the comments, I'll check these suggestions.
Basically, if you flatten into 1 dimension, your classes will not be weighted equally any more (which is already true for the Dice loss, btw).
Short example: imagine a 2-class problem. Let's say the object in class 1 is really large and the object in class 2 is only one pixel. Assume there is only one wrong pixel in your prediction compared to the ground truth. Let this pixel be the one pixel of class 2 which was falsely assigned to class 1.
Now, if you calculate the Dice on the flattened output, you will still be very close to 1.
In contrast, if you calculate the Dice per class and take the mean, you will get a value around 0.5.
Note that I used the sum instead of the mean in the code above. For the optimization task this is equivalent, but it avoids the division by a constant (i.e. the number of classes).
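The asymmetry described above can be made concrete with a toy NumPy example (shapes and numbers chosen arbitrarily for illustration):

```python
import numpy as np

# Toy 2-class example: class 1 covers almost everything, class 2 is one pixel.
# The prediction gets every pixel right except that lone class-2 pixel,
# which it falsely assigns to class 1.
h, w = 16, 16
gt = np.zeros((h, w, 2))
gt[..., 0] = 1.0
gt[0, 0, 0] = 0.0
gt[0, 0, 1] = 1.0           # the single class-2 pixel

pred = np.zeros((h, w, 2))
pred[..., 0] = 1.0          # predicts class 1 everywhere

def dice(a, b):
    return 2 * (a * b).sum() / (a.sum() + b.sum())

flat = dice(gt.ravel(), pred.ravel())                          # one global Dice
per_class = np.mean([dice(gt[..., c], pred[..., c]) for c in range(2)])

print(round(flat, 3))       # ~0.996: the mistake is almost invisible
print(round(per_class, 3))  # ~0.5: class 2 is completely missed
```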
Thank you for your question @shihlun1208 and your super clean code @phernst :)
I am working with 3D data; will this work?
```python
def class_tversky(y_true, y_pred):
    smooth = 1
    y_true = K.permute_dimensions(y_true, (1, 2, 3, 4, 0))
    y_pred = K.permute_dimensions(y_pred, (1, 2, 3, 4, 0))
    y_true_pos = K.batch_flatten(y_true)
    y_pred_pos = K.batch_flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos, 1)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos), 1)
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos, 1)
    alpha = 0.7
    return (true_pos + smooth) / (true_pos + alpha*false_neg + (1 - alpha)*false_pos + smooth)
```
Thanks
The permutation should be `(4, 1, 2, 3, 0)`. The reason is that you want to keep the classes (here: channels at dimension 4) as "batches". Basically, every permutation that puts dimension 4 first will work, e.g. `(4, 1, 3, 2, 0)`, `(4, 3, 2, 1, 0)`, ...
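A quick way to convince yourself of this, sketched here with `tf.keras.backend` and a randomly filled 5D tensor (shape `[batch, depth, height, width, channels]` assumed for illustration):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

# 3D segmentation layout: [batch, depth, height, width, channels]
x = tf.constant(np.random.rand(2, 4, 5, 6, 3).astype(np.float32))

# Any permutation that moves the class axis (4) to the front works,
# because batch_flatten only preserves the first axis; the order of the
# remaining axes does not change the per-class sums.
a = K.sum(K.batch_flatten(K.permute_dimensions(x, (4, 1, 2, 3, 0))), 1)
b = K.sum(K.batch_flatten(K.permute_dimensions(x, (4, 3, 2, 1, 0))), 1)

print(np.allclose(a.numpy(), b.numpy()))  # True: identical per-class sums
print(a.shape)                            # (3,): one value per class
```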
Hmmm, okay, let me check.
And just to confirm: the lower the loss, the better, right?
Also, a side question: I am using a custom U-Net model to perform instance segmentation with 26 classes. Currently the results are very poor; do you have any suggestions?
The `tversky` function outputs a value between 0 and 1, and the loss function outputs `1-tversky(*)`, so the closer the loss is to 0, the better. (FYI, your ground truths must be one-hot encoded.)
I haven't ever done instance segmentation, but I believe cross-entropy is a good starting point. The TL/FTL require some parameter tuning, so I would start with CE, then maybe weighted CE, and then try DL/TL/FTL. Also, with CE, add an extra class for the background; this always seems to improve my results on semantic segmentation.
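Since the ground truths must be one-hot encoded, here is a minimal sketch of converting an integer label map with `tf.one_hot` (the label values and class count are hypothetical; the extra background class mentioned above simply becomes one more channel):

```python
import numpy as np
import tensorflow as tf

# Hypothetical integer label map: 0 = background, 1..3 = foreground classes.
mask = np.array([[0, 1],
                 [2, 3]])

onehot = tf.one_hot(mask, depth=4)   # shape (2, 2, 4), channels last
print(onehot.shape)                  # (2, 2, 4)
print(onehot.numpy()[0, 1])          # [0. 1. 0. 0.] -> a pixel of class 1
```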
@nabsabraham Hi, since you mention in the readme that this network is good for small lesions, any thoughts on how it will perform on kidney or vertebra segmentation?
Can your model be extended to a 3D model? I mean, is the extension not very complicated?
Thanks
> For a `C`-class problem, you need to have `C` output channels. […]
Why can't the original model be applied to multi-class?
Convert the labels to one-hot encoding; their values are also in the range [0, 1].
(Sorry, my English is not very good.)
Could you explain a bit more in detail? Do you mean a pixel-wise one-hot vector like [0,0,1] (so, for a 3-class problem, this one would encode the last class), or a single number per pixel (so, for a 3-class problem, 0 encodes class 1, 0.5 encodes class 2 and 1 encodes class 3)?
> Could you explain a bit more in detail? […]
Thank you for your reply ^_^
I mean the first one, a pixel-wise one-hot vector.
In my opinion, in a multi-class problem, y_pred and y_true can also be flattened into 1 dimension, so why can't a multi-class problem directly apply the Tversky loss (a binary problem can)?
What is the difference between them?
> Basically, if you flatten into 1 dimension, your classes will not be weighted equally any more […]
Cool!!!
I totally understand now.
Thank you for your patience in explaining this to me, and I am sorry for my poor expression.
you're welcome :)
> For a `C`-class problem, you need to have `C` output channels. […]
Thanks for your great answer. However, I am applying it to a 3-class semantic segmentation problem, and I get NaN as the loss every time. I suspect the gradient explodes, but I am not sure how to fix it.
Yes, gradient explosion could be one problem; however, I have never experienced that with U-Nets and these types of loss functions (it depends on the data, of course).
I don't know what your data looks like, but if your target is not one-hot encoded, or has values smaller than 0 or greater than 1, the result of `1-pt_1` might be negative, which would produce NaN values after applying the `pow` function. That is at least one other problem I could imagine. Can you check that, please?
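A small sketch of such a check. `check_targets` is a hypothetical helper, not part of the repository; it catches the two target issues discussed above, and the last line shows why a negative base under a fractional `pow` yields NaN in the first place:

```python
import numpy as np

def check_targets(y_true, num_classes):
    """Hypothetical sanity checks on segmentation targets before training.

    Catches out-of-range values and non-one-hot encodings, both of which
    can make (1 - pt_1) ** gamma produce NaNs.
    """
    y = np.asarray(y_true)
    assert y.min() >= 0.0 and y.max() <= 1.0, "values outside [0, 1]"
    assert y.shape[-1] == num_classes, "channel count != num_classes"
    # one-hot targets sum to exactly 1 over the class axis
    assert np.allclose(y.sum(axis=-1), 1.0), "targets are not one-hot"
    return True

# a valid one-hot batch passes
ok = np.zeros((1, 4, 4, 3))
ok[..., 0] = 1.0
print(check_targets(ok, 3))          # True

# and this is the failure mode: negative base, fractional exponent -> NaN
print(np.power(-0.5, 0.75))          # nan
```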
@phernst Thanks a lot for your quick response. My network is a multi-task network with three decoders, and I am using this loss for the first task, which is supposed to do 3-class segmentation (0, 1, 2). The vectors are one-hot encoded (softmax is used in the last layer, in the form `Conv2D(3, (1,1), activation='softmax', name='segmap')`).
My dataset is a pancreas dataset where the background is the majority class; the second class is much smaller than the background, and the third class (tumor) occupies a very small portion of the dataset. Right now I don't have access to a GPU, but I will check the value of `1-pt_1` once I get access and post the result here.
Hello, I am trying to determine whether I will be able to use this loss function for a project I am working on. I am quite confused about the `K.permute_dimensions` function; I am unsure why you included the tuple after the `y_true` vector. Is it the order of dimensions to take? Is it something else? The documentation is unclear. Thank you for your help.
Yes, the tuple determines the order of dimensions. However, `K.permute_dimensions` is no longer available in TensorFlow 2. Since Keras was merged into TensorFlow, most of the backend functions became redundant and were removed. You should be able to replace `K.permute_dimensions` with `tf.transpose` and get the same functionality. See also: https://www.tensorflow.org/api_docs/python/tf/transpose
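For what it's worth, `tf.transpose`'s `perm` argument plays exactly the same role as the tuple discussed above: axis `i` of the output takes axis `perm[i]` of the input. A tiny shape demo (tensor dimensions chosen arbitrarily):

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 8, 8, 3).astype(np.float32))

# perm=(3, 1, 2, 0): output axis 0 <- input axis 3 (channels),
# output axis 3 <- input axis 0 (batch); middle axes stay in place.
y = tf.transpose(x, perm=(3, 1, 2, 0))
print(y.shape)  # (3, 8, 8, 2)
```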
I haven't tested the code yet, but the Tensorflow v2 implementation should look something like this:
```python
import tensorflow as tf

def class_tversky(y_true: tf.Tensor, y_pred: tf.Tensor):
    smooth = 1
    true_pos = tf.reduce_sum(y_true * y_pred, axis=(0, 1, 2))
    false_neg = tf.reduce_sum(y_true * (1 - y_pred), axis=(0, 1, 2))
    false_pos = tf.reduce_sum((1 - y_true) * y_pred, axis=(0, 1, 2))
    alpha = 0.7
    return (true_pos + smooth) / (true_pos + alpha*false_neg + (1 - alpha)*false_pos + smooth)

def focal_tversky_loss(y_true: tf.Tensor, y_pred: tf.Tensor):
    pt_1 = class_tversky(y_true, y_pred)
    gamma = 0.75
    return tf.reduce_sum((1 - pt_1) ** gamma)
```
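One way to gain confidence in the `reduce_sum` version is to compare it numerically against the original permute-and-flatten formulation, reimplemented here with `tf.transpose` and `tf.reshape` (random test data, arbitrary shapes; this is a sketch, not part of the repository):

```python
import numpy as np
import tensorflow as tf

def class_tversky_v2(y_true, y_pred):
    # the reduce_sum formulation: sum over batch, height, width per class
    smooth = 1.0
    tp = tf.reduce_sum(y_true * y_pred, axis=(0, 1, 2))
    fn = tf.reduce_sum(y_true * (1 - y_pred), axis=(0, 1, 2))
    fp = tf.reduce_sum((1 - y_true) * y_pred, axis=(0, 1, 2))
    alpha = 0.7
    return (tp + smooth) / (tp + alpha*fn + (1 - alpha)*fp + smooth)

def class_tversky_v1(y_true, y_pred):
    # original formulation: class axis to the front, flatten the rest
    smooth = 1.0
    yt = tf.reshape(tf.transpose(y_true, (3, 1, 2, 0)), (y_true.shape[-1], -1))
    yp = tf.reshape(tf.transpose(y_pred, (3, 1, 2, 0)), (y_pred.shape[-1], -1))
    tp = tf.reduce_sum(yt * yp, 1)
    fn = tf.reduce_sum(yt * (1 - yp), 1)
    fp = tf.reduce_sum((1 - yt) * yp, 1)
    alpha = 0.7
    return (tp + smooth) / (tp + alpha*fn + (1 - alpha)*fp + smooth)

rng = np.random.default_rng(0)
yt = (rng.random((2, 8, 8, 3)) > 0.5).astype(np.float32)
yp = rng.random((2, 8, 8, 3)).astype(np.float32)
same = np.allclose(class_tversky_v1(yt, yp), class_tversky_v2(yt, yp))
print(same)  # True: both compute the same per-class index
```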