Comments (7)
Hi jie, I appreciate your interest in our work. Please limit your issues to one thread: you have re-opened issues #11 and #10, which are not related to code reproducibility.
Regarding BUS code reproducibility: I resampled the images to 128x128 using OpenCV, which could be a source of error. I also used early stopping to pick the best-performing model, because training becomes unstable close to convergence. Finally, I ran this script 5 times to get the DSC +/- some variance; the train and test sets were split randomly on each run, so this is another source of randomness. Your result of 0.748 is still quite far from mine, so I will have to look for the pre-processing code, because that is the only thing I have not shown in this repo.
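As a rough illustration of the random-split point, the sketch below fixes one seed per run so that repeated runs draw the same train/test split, and reports the mean and standard deviation of per-run scores. The helper names `split_indices` and `summarize`, the sample count of 100, and the 20% test fraction are all hypothetical; none of them come from the repo.

```python
import random
import statistics


def split_indices(n_samples, test_fraction, seed):
    """Shuffle sample indices with a fixed seed and split them into train/test lists."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_test = int(round(n_samples * test_fraction))
    return indices[n_test:], indices[:n_test]


def summarize(scores):
    """Return the mean and sample standard deviation of per-run scores."""
    return statistics.mean(scores), statistics.stdev(scores)


# Five runs with one fixed seed each: rerunning the script reproduces the same splits.
splits = [split_indices(n_samples=100, test_fraction=0.2, seed=s) for s in range(5)]
```

With the seeds recorded, a gap between two people's results can no longer be attributed to a different split, which narrows the search to pre-processing and training.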
Can you try to reproduce the ISIC results? I have recently moved to PyTorch and have reproduced ISIC without any trouble.
from focal-tversky-unet.
Thank you for your reply! First of all, the BUS code resizes the data to 128 * 128, but the ISIC code has no corresponding processing, so I think the BUS results should be easier to reproduce than the ISIC ones. Secondly, I do not think cross-validation should make such a large difference, but I will review the code carefully following your suggestions.
In addition, would you be able to share your PyTorch code? I am a little more familiar with PyTorch. No problem if that is inconvenient. Thank you.
Hi, I tried to use my own data with this code, but why do I get dsc > 1? For example:
16/112 [===>..........................] - ETA: 11s - loss: -0.1059 - dsc: 1.1925 - tp: 1.0000 - tn: 3.8405e-06
32/112 [=======>......................] - ETA: 9s - loss: -0.1233 - dsc: 1.2266 - tp: 1.0000 - tn: 3.8427e-06
48/112 [===========>..................] - ETA: 7s - loss: -0.1497 - dsc: 1.2816 - tp: 1.0000 - tn: 3.8468e-06
64/112 [================>.............] - ETA: 5s - loss: -0.1684 - dsc: 1.3214 - tp: 1.0000 - tn: 3.8500e-06
80/112 [====================>.........] - ETA: 3s - loss: -0.1762 - dsc: 1.3380 - tp: 1.0000 - tn: 3.8510e-06
96/112 [========================>.....] - ETA: 1s - loss: -0.1816 - dsc: 1.3492 - tp: 1.0000 - tn: 3.8519e-06
112/112 [==============================] - 14s 125ms/step - loss: -0.1881 - dsc: 1.3637 - tp: 1.0000 - tn: 3.8533e-06 - val_loss: -0.2142 - val_dsc: 1.4020 - val_tp: 1.0000 - val_tn: 9.6319e-06
Can you give me some advice on how to solve this problem? Thanks a lot in advance.
Hi Kevin, can you check the range of your predictions and the range of your ground truth? Both masks should be in the range [0, 1]. You can just print the max of each and check what it is. Also, try adjusting the alpha and gamma parameters. Maybe your dataset is easier to segment, so the network is becoming heavily biased towards predicting TPs and FPs.
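A minimal sketch of that range check, assuming a soft Dice of the usual smoothed form (written here in NumPy for illustration, not copied from the repo) and a hypothetical 2x2 example. It shows one way a mask left in the 0-255 range can push the score above 1; whether that is the cause in this case is not confirmed.

```python
import numpy as np


def dsc(y_true, y_pred, smooth=1.0):
    """Soft Dice on flattened arrays; bounded by 1 only when both inputs lie in [0, 1]."""
    t = np.asarray(y_true, dtype=np.float64).ravel()
    p = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(t * p)
    return (2.0 * intersection + smooth) / (np.sum(t) + np.sum(p) + smooth)


pred = np.array([[0.9, 0.1], [0.8, 0.2]])   # hypothetical sigmoid outputs
mask_255 = np.array([[255, 0], [255, 0]])   # mask loaded as an 8-bit image, never rescaled
mask_01 = mask_255 / 255.0                  # same mask rescaled to [0, 1]

# Print the value range of each array before computing any metric.
for name, arr in [("pred", pred), ("mask_255", mask_255), ("mask_01", mask_01)]:
    print(name, "min:", arr.min(), "max:", arr.max())

print("dsc with 0-255 mask:", dsc(mask_255, pred))  # greater than 1
print("dsc with 0-1 mask:", dsc(mask_01, pred))     # no greater than 1
```

If the printed maximum of the ground truth is 255, dividing the masks by 255 before training is the usual remedy.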
Hey @nabsabraham, any chance you could share the PyTorch version? I would really like to try it.
Hi @luistelmocosta, thanks for your interest! Is it the model you were looking for or just the loss function? If it is the latter, I just wrote this up quickly but I think it should work:
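def ftl(pred, gt, alpha=0.7, gamma=0.75, smooth=1.0):
    # pred and gt are expected to hold values in [0, 1].
    # smooth was undefined in the quick write-up; the default of 1.0 is an assumption.
    pflat = pred.contiguous().view(-1)
    gtflat = gt.contiguous().view(-1)
    TP = (pflat * gtflat).sum()
    FP = (pflat * (1 - gtflat)).sum()
    FN = ((1 - pflat) * gtflat).sum()
    # Tversky index: alpha weights false negatives, (1 - alpha) weights false positives.
    tversky = (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth)
    # Focal Tversky loss: the exponent applies to (1 - tversky), not to tversky itself.
    return (1 - tversky) ** gamma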
If it's the model, I believe you can get started with Oktay's PyTorch version of attention networks. I will have to rewrite mine because I never hung on to it, but it's essentially a few extra layers added on to Oktay's model.
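To make the roles of alpha and gamma concrete, here is a small worked example in plain Python using the (1 - TI)^gamma form of the focal Tversky loss, with alpha weighting false negatives as in the snippet above. The counts (80 true positives, 20 errors) and the function names are hypothetical and chosen only for illustration.

```python
def tversky_index(tp, fp, fn, alpha=0.7, smooth=1.0):
    """Tversky index; alpha weights false negatives, (1 - alpha) weights false positives."""
    return (tp + smooth) / (tp + alpha * fn + (1 - alpha) * fp + smooth)


def focal_tversky(tp, fp, fn, alpha=0.7, gamma=0.75, smooth=1.0):
    """Focal Tversky loss computed from soft counts: (1 - TI) ** gamma."""
    return (1 - tversky_index(tp, fp, fn, alpha, smooth)) ** gamma


# Same number of errors (20), once as missed pixels and once as false alarms.
loss_missed = focal_tversky(tp=80, fp=0, fn=20)
loss_false_alarm = focal_tversky(tp=80, fp=20, fn=0)

print("20 false negatives:", loss_missed)
print("20 false positives:", loss_false_alarm)
```

With alpha = 0.7 the missed pixels cost more than the same number of false alarms, which is the bias toward recall that the Tversky weighting is meant to provide.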
Thank you for the quick reply.
I was following this one: https://github.com/LeeJunHyun/Image_Segmentation. And yes, I noticed that your Attn Unet has an extra layer (4 skip connections), which should be easy to adapt.
I am currently using this version:
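import torch
import torch.nn as nn


class TverskyLoss(nn.Module):
    __name__ = 'tversky_loss'

    def __init__(self, activation=None, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False, tversky_alpha=0.3, tversky_beta=0.7):
        """
        paper: https://arxiv.org/pdf/1706.05721.pdf
        """
        super(TverskyLoss, self).__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        # tversky_alpha and tversky_beta were undefined in the posted code;
        # the defaults 0.3 (weight on fp) and 0.7 (weight on fn) are assumptions.
        self.alpha = tversky_alpha
        self.beta = tversky_beta
        # Activation is not defined in this snippet; it is assumed to be imported elsewhere.
        self.activation = Activation(activation)

    def forward(self, y_pr, y_gt, loss_mask=None):
        shp_x = y_pr.shape
        # Reduce over the spatial axes, plus the batch axis when batch_dice is set.
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            y_pr = self.apply_nonlin(y_pr)
        y_pr = self.activation(y_pr)
        # tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        # Sum over `axes` only, so the channel axis survives for the do_bg slicing below
        # (the posted code summed over everything, which made that slicing fail).
        tp = torch.sum(y_gt * y_pr, dim=axes)
        fp = torch.sum(y_pr, dim=axes) - tp
        fn = torch.sum(y_gt, dim=axes) - tp
        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        if not self.do_bg:
            # Drop the background channel (channel 0).
            if self.batch_dice:
                tversky = tversky[1:]
            else:
                tversky = tversky[:, 1:]
        tversky = tversky.mean()
        return -tversky


## Tversky Focal Loss
class FocalTversky_loss(nn.Module):
    __name__ = 'tversky_focal_loss'

    def __init__(self, tversky_kwargs, gamma=float(4/3)):
        super(FocalTversky_loss, self).__init__()
        # Note: gamma is applied directly as the exponent here (4/3),
        # whereas the ftl snippet earlier in the thread uses 0.75.
        self.gamma = gamma
        # Forward the keyword arguments, which the posted code accepted but ignored.
        self.tversky = TverskyLoss(**tversky_kwargs)
        #print("t loss" , self.tversky)

    def forward(self, net_output, target):
        #print(self.tversky(net_output, target))
        tversky_loss = 1 + self.tversky(net_output, target)  # = 1 - tversky(net_output, target)
        focal_tversky = torch.pow(tversky_loss, self.gamma)
        return focal_tversky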
Does this look good, or should I add contiguous() to my y_gt and y_pr?
Thank you
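On the contiguous question, a small sketch of the general PyTorch behaviour may help (it uses a toy tensor rather than the loss above): elementwise products and `torch.sum` accept non-contiguous tensors, while `view()` is the call that can fail on them, which is why `contiguous()` usually appears right before a `view(-1)`.

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
xt = x.t()  # transposed view: same storage, non-contiguous strides

print(xt.is_contiguous())

# Elementwise ops and reductions work on non-contiguous tensors as they are.
total = torch.sum(xt * xt)

try:
    xt.view(-1)  # view() needs compatible strides, so it raises for this tensor
    view_failed = False
except RuntimeError:
    view_failed = True

flat_a = xt.contiguous().view(-1)  # copy into contiguous memory, then view
flat_b = xt.reshape(-1)            # reshape() copies only when it has to
```

So a loss that only multiplies and sums its inputs should not need `contiguous()`; it matters once the tensors are flattened with `view`.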