
dice_loss_for_nlp's People

Contributors

littlesulley, xiaoya-li


dice_loss_for_nlp's Issues

ohem_ratio

In the code comment, ohem_ratio refers to the max ratio of positive/negative, defaults to 0.0, which means no ohem.

But later in the code, it is computing keep_num with a formula keep_num = min(int(pos_num * self.ohem_ratio / logits_size), neg_num).

Should the comment be changed to negative/positive?
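
For context, a small hedged sketch of how the quoted formula behaves; the values are hypothetical and only illustrate the scaling:

# Hypothetical values plugged into the quoted formula; only the scaling matters.
pos_num, neg_num, logits_size = 20, 500, 5
for ohem_ratio in (0.5, 2.0):
    keep_num = min(int(pos_num * ohem_ratio / logits_size), neg_num)
    print(ohem_ratio, keep_num)  # 0.5 -> 2, 2.0 -> 8
# keep_num grows with ohem_ratio and is capped at neg_num, which is what
# motivates the question about the positive/negative wording in the comment.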

Dice Loss for multi-class classification

Hello, have you run your dice loss on a multi-class task? I used your code as-is for multi-class classification, but the loss comes out as NaN (the same code with cross-entropy works fine).

dice_loss shows NaN during training

Hello, for a binary classification task I referred to the code in adaptive_dice_loss.py:

intersection = torch.sum((1-flat_input)**self.alpha * flat_input * flat_target, -1) + self.smooth
denominator = torch.sum((1-flat_input)**self.alpha * flat_input) + flat_target.sum() + self.smooth       
return 1 - 2 * intersection / denominator

and wrote a corresponding TensorFlow version of the loss function:

def dice_loss(alpha=0.1, smooth=1e-8):
    def dice_loss_fixed(y_pred, y_true):
        intersection = K.sum((1-y_pred)**alpha * y_pred * y_true, -1) + smooth
        denominator =  K.sum((1-y_pred)**alpha * y_pred,-1) + K.sum(y_true) + smooth
        return 1 - 2 *intersection / denominator
    return dice_loss_fixed

But during training the loss keeps showing NaN, and I do not know why. I would appreciate any pointers. Thanks!

model.compile(optimizer=keras.optimizers.RMSprop(),
             loss=[dice_loss(alpha=0.1,smooth=1e-8)],
             metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=64, epochs=5,
         validation_data=(x_test, y_test))
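
For context, a hedged sketch of the same loss written against the usual Keras custom-loss convention, with the predictions clipped before the fractional power (a negative base under a fractional exponent yields NaN); whether either point is the actual cause of the NaN above is only an assumption:

from tensorflow.keras import backend as K  # assumed backend, matching the snippet above

def dice_loss(alpha=0.1, smooth=1e-8, eps=1e-7):
    # Keras calls a custom loss as loss(y_true, y_pred); note the argument order.
    def dice_loss_fixed(y_true, y_pred):
        # Clip predictions into (0, 1) so the fractional power never sees a
        # negative base -- an assumption about the NaN, not a confirmed fix.
        y_pred = K.clip(y_pred, eps, 1.0 - eps)
        intersection = K.sum((1 - y_pred) ** alpha * y_pred * y_true, -1) + smooth
        denominator = K.sum((1 - y_pred) ** alpha * y_pred, -1) + K.sum(y_true) + smooth
        return 1 - 2 * intersection / denominator
    return dice_loss_fixed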


RuntimeError: "bitwise_and_cpu" not implemented for 'Float' in DiceLoss

I am using:

  • Python 3.8.10
  • torch 1.12.0+cu113

When setting alpha > 0 in DiceLoss, it results in the following error:

RuntimeError: "bitwise_and_cpu" not implemented for 'Float' in DiceLoss

at line:

https://github.com/ShannonAI/dice_loss_for_NLP/blob/master/loss/dice_loss.py#L120

This is due to operator precedence: & binds more tightly than == and >=, so it is evaluated first, which is wrong here. You can avoid it by adding parentheses around the comparisons:

cond = ((torch.argmax(flat_input, dim=1) == label_idx) & (flat_input[:, label_idx] >= threshold)) | pos_example.view(-1)
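
A minimal, self-contained sketch of the precedence behaviour (tensor values are made up):

import torch

preds = torch.tensor([0.7, 0.2, 0.9])
labels = torch.tensor([1, 0, 1])

# `&` binds more tightly than `==` and `>=`, so without parentheses the
# expression attempts `1 & preds` on a float tensor first and fails.
try:
    wrong = (labels == 1 & preds >= 0.5)
except (RuntimeError, TypeError) as e:
    print(e)  # e.g. "bitwise_and" not implemented for 'Float'

# With explicit parentheses each comparison is evaluated first:
right = (labels == 1) & (preds >= 0.5)
print(right)  # tensor([ True, False,  True])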

Dice loss is not optimizing

Thanks for your paper and implementation.
I am using Dice loss for multi-class text classification, however the value of the Dice loss is not being optimized at all.
I can't see where the problem is; is anyone having the same problem?
Thanks in advance.

Error when running on CPU

Why do I get the following error when running tasks/squad/train.py on the CPU?
"""
Traceback (most recent call last):
File "/Users/hodge/Documents/GitHub/diceLoss/tasks/squad/train.py", line 369, in
main()
File "/Users/hodge/Documents/GitHub/diceLoss/tasks/squad/train.py", line 363, in main
trainer.fit(model)
File "/Users/hodge/opt/anaconda3/envs/diceloss/lib/python3.6/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
result = fn(self, *args, **kwargs)
File "/Users/hodge/opt/anaconda3/envs/diceloss/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1083, in fit
self.accelerator_backend.setup(model)
File "/Users/hodge/opt/anaconda3/envs/diceloss/lib/python3.6/site-packages/pytorch_lightning/accelerators/cpu_backend.py", line 26, in setup
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
pytorch_lightning.utilities.exceptions.MisconfigurationException: amp + cpu is not supported. Please use a GPU option
"""
But I have already changed the first argument from --gpus="1" to --gpus="0".

Self-adjustment in the Dice Loss

I read the ACL 2020 paper, and it motivates the self-adjustment in the Dice loss with Figure 1, which shows that the derivative approaches zero right after p exceeds 0.5. This is the case when alpha is 1.0. However, the script for the OntoNotes5 data uses alpha=0.01, which is a very small adjustment and gives almost the same performance as the plain squared form of Dice. When I set alpha=1.0 and train with that script on the CoNLL2003 data, the model does not learn well (the F1 was about 28.96). I wonder why the self-adjustment does not help. Could you explain which value of alpha works best in general?
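
For context, a minimal sketch (values are made up) of the self-adjusting factor being discussed, using the (1 - p)**alpha weighting quoted from adaptive_dice_loss.py earlier on this page:

import torch

p = torch.tensor([0.1, 0.5, 0.9])  # predicted probabilities for the gold class

for alpha in (0.01, 1.0):
    weight = (1 - p) ** alpha      # the self-adjusting factor
    print(alpha, (weight * p).tolist())
# alpha=0.01 leaves p almost unchanged (weights close to 1 everywhere),
# while alpha=1.0 strongly down-weights easy examples with large p.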

Dice loss for Token classification

I've been trying to use dice loss for a token classification task with 9 classes.
After fixing a few errors in _multiple_class (for example, line 143 calls flat_input_idx.view(-1, 1), which throws an error because the tensors are not contiguous), I used this instead:
loss_idx = self._compute_dice_loss(flat_input_idx.reshape(-1, 1), flat_target_idx.reshape(-1, 1))

Now that I've trained a model with this, it seems to me that the loss isn't changing at all, and I don't know what I am doing wrong.
https://github.com/Zhylkaaa/simpletransformers/blob/dice_loss/simpletransformers/ner/ner_model.py#L489 - this is where I am trying to integrate dice_loss.

I can prepare a minimal example if you want to take a look.
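
For context, a minimal sketch of the .view vs .reshape distinction behind that change; the transposed tensor below is only an illustration of a non-contiguous layout, not a claim about how flat_input_idx is produced in the repo:

import torch

x = torch.rand(3, 4)
y = x.t()                      # transpose -> non-contiguous layout
print(y.is_contiguous())       # False

try:
    y.view(-1, 1)              # .view cannot re-lay-out non-contiguous memory
except RuntimeError as e:
    print("view failed:", e)   # the error suggests using .reshape(...) instead

print(y.reshape(-1, 1).shape)  # torch.Size([12, 1]); .reshape copies when needed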

Stuck in validation (SQuAD task)

Training gets stuck in validation on the SQuAD task: GPU memory stays allocated but there is no GPU utilization.

I tried the same pytorch-lightning version (0.9.0) and the newest one.

Could you please help me figure this out?

Dice Loss Error

I have a two-part question.

  1. The example given in the code bugs out, i.e.
    >>> loss = DiceLoss(with_logits=True, ohem_ratio=0.1)

    IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
  2. The other question is related to the implementation: say the classifier has perfectly predicted the labels; there would still be some dice loss because of the smooth term in
    loss = 1 - ((2 * interection + self.smooth) /
    (torch.sum(torch.square(flat_input, ), -1) + torch.sum(torch.square(flat_target), -1) + self.smooth))
    Is this the expected behavior, or am I missing something?
input = torch.FloatTensor([[1., .0, .0, .0],[0., 1, .0, .0]])
target = torch.LongTensor([0, 1])
loss = DiceLoss(with_logits=False,reduction=None,ohem_ratio=0.)
input.requires_grad=True
output = loss(input, target)

Output
tensor([1.9998, 1.9998], grad_fn=)

Errors when running the code

I am hitting a pile of version-incompatibility problems.......

Link for Preprocessed OntoNotes 5.0

Hi,

Thank you very much for your paper and model. I've been trying to replicate your best experimental results on OntoNotes 5.0, however I cannot find the dataset at the link you have provided. Could you please provide a working link? Thanks.

Your model needs a big GPU to run

I still get OOM with two RTX 2080 Ti cards. Also, using three GPUs with your code always throws an error, and multi-GPU mixed precision errors out as well. I guess it is a pytorch-lightning issue? I have never used that library.

Problem reproducing results on the zh_onto4 dataset

Hello, we are trying to reproduce the results on the named entity recognition dataset zh_onto4. Following the README, we ran the scripts/ner_zhonto4/bert_dice.sh script without changing any hyper-parameters, but the span F1 on the test set is only 80.80, which is far from the 84.47 reported in the paper. The run log and test results are below; could you take a look at what might be wrong? Thanks!

h-4.3$sh scripts/ner_zhonto4/bert_dice.sh
DEBUG INFO -> loss sign is dice_1_0.3_0.01
DEBUG INFO -> save hyperparameters
DEBUG INFO -> pred_answerable train_infer
DEBUG INFO -> check bert_config
BertForQueryNERConfig {
  "activate_func": "relu",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "construct_entity_span": "start_and_end",
  "directionality": "bidi",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "pred_answerable": true,
  "truncated_normal": false,
  "type_vocab_size": 2,
  "vocab_size": 21128
}

Some weights of the model checkpoint at /home/ma-user/work/bert-base-chinese were not used when initializing BertForQueryNER: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']

  • This IS expected if you are initializing BertForQueryNER from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).

  • This IS NOT expected if you are initializing BertForQueryNER from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    Some weights of BertForQueryNER were not initialized from the model checkpoint at /home/ma-user/work/bert-base-chinese and are newly initialized: ['start_outputs.dense_layer.weight', 'start_outputs.dense_layer.bias', 'start_outputs.dense_to_labels_layer.weight', 'start_outputs.dense_to_labels_layer.bias', 'end_outputs.dense_layer.weight', 'end_outputs.dense_layer.bias', 'end_outputs.dense_to_labels_layer.weight', 'end_outputs.dense_to_labels_layer.bias', 'answerable_cls_output.dense_layer.weight', 'answerable_cls_output.dense_layer.bias', 'answerable_cls_output.dense_to_labels_layer.weight', 'answerable_cls_output.dense_to_labels_layer.bias']
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    /opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: Checkpoint directory /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01 exists and is not empty with save_top_k != 0.All files in this directory will be deleted when a checkpoint is saved!
    warnings.warn(*args, **kwargs)
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    CUDA_VISIBLE_DEVICES: [0]
    Using native 16bit precision.
    /opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: Could not log computational graph since the model.example_input_array attribute is not set or input_array was not given
    warnings.warn(*args, **kwargs)

  | Name              | Type            | Params
------------------------------------------------
0 | model             | BertForQueryNER | 104 M
1 | evaluation_metric | MRCNERSpanF1    | 0
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the num_workers argument (try 72 which is the number of cpus on this machine) in the DataLoader init to improve performance.
warnings.warn(*args, **kwargs)
Validation sanity check: 0it [00:00, ?it/s]
Truncation was not explicitly activated but max_length is provided a specific value, please use truncation=True to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to truncation.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the num_workers argument (try 72 which is the number of cpus on this machine) in the DataLoader init to improve performance.
warnings.warn(*args, **kwargs)
Epoch 0: 25%|██████████████████████████████▍ | 19168/76678 [09:30<28:32, 33.58it/s, loss=0.318, v_num=4]
Epoch 00000: val_f1 reached 0.69758 (best 0.69758), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=0.ckpt as top 3
/opt/conda/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.
warnings.warn(SAVE_STATE_WARNING, UserWarning)
Epoch 0: 50%|████████████████████████████████████████████████████████████▉ | 38336/76678 [19:07<19:07, 33.41it/s, loss=0.297, v_num=4]
Epoch 00000: val_f1 reached 0.72641 (best 0.72641), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=0_v0.ckpt as top 3
Epoch 0: 50%| 38346/76678 [19:11<19:11, 33.29it/s, loss=0.312, v_num=4]
Epoch 0: 75%| 57504/76678 [28:56<09:38, 33.12it/s, loss=0.210, v_num=4]
Epoch 00000: val_f1 reached 0.76375 (best 0.76375), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=0_v1.ckpt as top 3
Epoch 0: 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 70608/76678 [36:58<03:10, 31.83it/s, loss=0.296, v_num=4]
Epoch 0: 100%| 76672/76678 [38:33<00:00, 33.15it/s, loss=0.296, v_num=4]
Epoch 00000: val_f1 reached 0.73197 (best 0.76375), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=0_v2.ckpt as top 3
Epoch 1: 20%| 15152/76678 [08:12<33:20, 30.75it/s, loss=0.125, v_num=4]
Epoch 1: 25%| 19168/76678 [09:16<27:50, 34.43it/s, loss=0.125, v_num=4]
Epoch 00001: val_f1 reached 0.75373 (best 0.76375), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=1.ckpt as top 3
Epoch 1: 46%|███████████████████████████████████████████████████████▉ | 35128/76678 [17:49<21:04, 32.85it/s, loss=0.175, v_num=4]
Epoch 1: 50%| 38336/76678 [18:40<18:40, 34.21it/s, loss=0.175, v_num=4]
Epoch 00001: val_f1 reached 0.75231 (best 0.76375), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=1_v0.ckpt as top 3
Epoch 1: 75%|███████████████████████████████████████████████████████████████████████████████████████████▍ | 57504/76678 [28:08<09:22, 34.06it/s, loss=0.241, v_num=4]
Epoch 00001: val_f1 was not in top 3
Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 76672/76678 [37:23<00:00, 34.17it/s, loss=0.149, v_num=4]
Epoch 00001: val_f1 reached 0.76958 (best 0.76958), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=1_v1.ckpt as top 3
Epoch 2: 25%|██████████████████████████████▍ | 19168/76678 [09:19<28:00, 34.23it/s, loss=0.159, v_num=4]
Epoch 00002: val_f1 reached 0.76353 (best 0.76958), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=2.ckpt as top 3
Epoch 2: 50%|████████████████████████████████████████████████████████████▉ | 38336/76678 [18:47<18:47, 34.01it/s, loss=0.157, v_num=4]
Epoch 00002: val_f1 was not in top 3
Epoch 2: 75%|███████████████████████████████████████████████████████████████████████████████████████████▍ | 57504/76678 [28:04<09:21, 34.14it/s, loss=0.220, v_num=4]
Epoch 00002: val_f1 was not in top 3
Epoch 2: 75%| 57856/76678 [28:55<09:24, 33.33it/s, loss=0.080, v_num=4]
Epoch 2: 83%| 63672/76678 [34:02<06:57, 31.17it/s, loss=0.129, v_num=4]
Epoch 2: 96%| 73904/76678 [36:42<01:22, 33.55it/s, loss=0.129, v_num=4]
Epoch 2: 100%| 76672/76678 [37:26<00:00, 34.12it/s, loss=0.129, v_num=4]
Epoch 00002: val_f1 reached 0.78048 (best 0.78048), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=2_v0.ckpt as top 3
Epoch 3: 25%|██████████████████████████████▏ | 18984/76678 [09:18<28:17, 33.99it/s, loss=0.119, v_num=4]
Epoch 3: 25%| 19168/76678 [09:21<28:04, 34.15it/s, loss=0.119, v_num=4]
Epoch 00003: val_f1 reached 0.76468 (best 0.78048), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=3.ckpt as top 3
Epoch 3: 36%| 27528/76678 [15:57<28:29, 28.75it/s, loss=0.084, v_num=4]
Epoch 3: 50%| 38336/76678 [18:51<18:51, 33.88it/s, loss=0.084, v_num=4]
Epoch 00003: val_f1 reached 0.76513 (best 0.78048), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=3_v0.ckpt as top 3
Epoch 3: 50%| 38424/76678 [19:07<19:02, 33.49it/s, loss=0.067, v_num=4]
Epoch 3: 59%| 45224/76678 [25:09<17:29, 29.96it/s, loss=0.065, v_num=4]
Epoch 3: 75%| 57504/76678 [28:17<09:26, 33.87it/s, loss=0.065, v_num=4]
Epoch 00003: val_f1 reached 0.77975 (best 0.78048), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=3.ckpt as top 3
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 76672/76678 [37:48<00:00, 33.79it/s, loss=0.115, v_num=4]
Epoch 00003: val_f1 reached 0.77402 (best 0.78048), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=3_v0.ckpt as top 3
Epoch 4: 25%|██████████████████████████████▍ | 19168/76678 [09:50<29:32, 32.45it/s, loss=0.141, v_num=4]
Epoch 00004: val_f1 was not in top 3
Epoch 4: 50%|████████████████████████████████████████████████████████████▉ | 38336/76678 [19:56<19:56, 32.05it/s, loss=0.162, v_num=4]
Epoch 00004: val_f1 reached 0.78181 (best 0.78181), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=4.ckpt as top 3
Epoch 4: 75%|███████████████████████████████████████████████████████████████████████████████████████████▍ | 57504/76678 [30:48<10:16, 31.11it/s, loss=0.092, v_num=4]
Epoch 00004: val_f1 reached 0.78524 (best 0.78524), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=4_v0.ckpt as top 3
Epoch 4: 77%| 58938/76678 [35:09<10:35, 27.93it/s, loss=0.122, v_num=4]
Epoch 4: 100%| 76672/76678 [41:29<00:00, 30.80it/s, loss=0.141, v_num=4]
Epoch 00004: val_f1 reached 0.78198 (best 0.78524), saving model to /home/ma-user/work/dice_loss_for_NLP-master/output/dice_loss/mrc_ner/reproduce_zhonto_dice_base_8_300_2e-5_polydecay_0.1_2_5_1.0_0.002_0.1_1_1_0.3_dice_1_0.3_0.01/epoch=4_v1.ckpt as top 3
Epoch 4: 100%| 76678/76678 [41:32<00:00, 30.76it/s, loss=0.141, v_num=4]
Saving latest checkpoint..
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:37: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the num_workers argument (try 72 which is the number of cpus on this machine) in the DataLoader init to improve performance.
warnings.warn(*args, **kwargs)
Testing: 100%| 17380/17384 [04:51<00:00, 59.82it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_span_f1': tensor(0.8080, device='cuda:0'),
'test_span_precision': tensor(0.8282, device='cuda:0'),
'test_span_recall': tensor(0.7888, device='cuda:0')}

Asking for the script for CoNLL2003 data

Thank you for sharing your code; I am very interested in dice loss, especially for the NER task.
You share the CoNLL2003 data in MRC format, but the NER script (with hyper-parameters) is for OntoNotes5 (English). Can you share the script (specifically, the hyper-parameters) suitable for the CoNLL2003 data?
When I used the OntoNotes5 script with the CoNLL2003 data, I got about 92.08 F1 (with 10 epochs), but this is a bit lower than the 93.33 F1 reported in the ACL 2020 paper. By contrast, I got 92.35 F1 with BCE loss and 5 epochs.

And can you share the OntoNotes5 data in MRC format, or at least the query sentences?

Masking

I tried replacing BCE loss with Dice loss and my model wouldn't converge. When I looked closer, I noticed that whilst the input and target are flattened, the mask isn't. So if you pass a mask that has the same shape as the (unflattened) target, the multiplication flat_input * mask effectively un-flattens flat_input.

def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input

    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)

I made the following change and my model started converging immediately

    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask.view(-1)
        flat_target = flat_target * mask.view(-1)
    else:
        mask = torch.ones_like(target)

Although I think a better fix is to actually apply the mask (keep only the unmasked positions) rather than zeroing out the masked inputs/targets, i.e.:

        mask = mask.view(-1)
        flat_input = flat_input[mask]
        flat_target = flat_target[mask]
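
A minimal sketch of that boolean-indexing variant (the values are made up, and casting the mask with .bool() is an added detail, since indexing with an integer mask would gather by index instead of masking):

import torch

flat_input = torch.tensor([0.9, 0.2, 0.7, 0.4])
flat_target = torch.tensor([1., 0., 1., 0.])
mask = torch.tensor([[1, 1, 0, 1]])   # e.g. an attention mask of 0/1 ints

keep = mask.view(-1).bool()           # boolean indexing needs a bool tensor
flat_input = flat_input[keep]         # tensor([0.9000, 0.2000, 0.4000])
flat_target = flat_target[keep]       # tensor([1., 0., 0.])
print(flat_input, flat_target)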

cls_answerable_loss and match_loss

I have questions about some details in training.
In your paper, you say the backbone model for the NER experiments is the one proposed in "A Unified MRC Framework for Named Entity Recognition". But when I looked at the code and ran the experiment to reproduce the reported results, I found that you ignore the match_loss (match_loss is introduced to teach the model which predicted start token matches which predicted end token).
So how can we run inference with this model when a sentence contains multiple NER entities?
Also, you introduce a new loss, cls_answerable_loss, to teach the model to classify whether the input contains an entity or not. Why is this loss not mentioned in your paper?
Thank you.

The mask related code in the Dice loss function is wrong

Hello,

First of all, cool work! :)

Now let me get to the point:

I found the following bug in your code:

    if mask is not None: # here is the problem!! flat_input and flat_target are already made one-hot, thus the multiplication will not work!
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)

An easy fix is the following:

    if mask is not None:
        mask = mask.float()
        flat_input = (flat_input.t() * mask).t()
        flat_target = (flat_target.t() * mask).t()
    else:
        mask = torch.ones_like(target)
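
A small sketch of why the transpose trick broadcasts correctly for the one-hot case (shapes are illustrative):

import torch

flat_input = torch.rand(4, 3)          # (N, C) one-hot-style inputs
mask = torch.tensor([1., 1., 0., 1.])  # (N,) per-token mask

masked = (flat_input.t() * mask).t()   # (C, N) * (N,) broadcasts, then transpose back
print(masked.shape)                    # torch.Size([4, 3]); rows with mask 0 are zeroed

An equivalent broadcast without the transposes would be flat_input * mask.unsqueeze(-1).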

Some questions about flat_input and flat_target

Suppose I have the following probs and labels (Binary classification):

probs = torch.FloatTensor([[0.3],
                           [0.8],
                           [0.2],
                           [0.7]])

targets = torch.LongTensor([[0],
                            [1],
                            [0],
                            [1]])

Execute the following code:

loss = DiceLoss(alpha=1, smooth=1, with_logits=False, ohem_ratio=0.0, reduction='mean')
output = loss(inputs, targets)
print(output)

No doubt, the code will enter _binary_class().

def _binary_class(self, input, target, mask=None):
        flat_input = input.view(-1)
        flat_target = target.view(-1).float()

At this time, the shape of flat_input is:

tensor([0.3000, 0.8000, 0.2000, 0.7000])

The shape of flat_target is:

tensor([0., 1., 0., 1.])

So far, there is no problem, but when I switched to multi-class classification for testing, I found that there was a problem with the shapes of flat_input and flat_target.

Suppose I have the following probs and labels (Multi classification):

probs  = torch.FloatTensor([[0.1,0.8,0.7],
                            [0.5,0.1,0.6],
                            [0.7,0.5,0.8],
                            [0.4,0.6,0.9]])

targets = torch.LongTensor([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1],
                            [0, 1, 0]])

Execute the following code:

loss = DiceLoss(alpha=1, smooth=1, with_logits=False, ohem_ratio=0.0, index_label_position=False, reduction='mean')
output = loss(inputs, targets)
print(output)

No doubt, the code will enter _multiple_class().

def _multiple_class(self, input, target, logits_size, mask=None):
        flat_input = input
        flat_target = F.one_hot(target, num_classes=logits_size).float() if self.index_label_position else target.float()

At this time, the shape of flat_input is:

tensor([0.1000, 0.5000, 0.7000, 0.4000])

The shape of flat_target is:

tensor([1., 0., 0., 0.])

But after the following code:

loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1))

The shape of flat_input is:

tensor([[0.1000],
        [0.5000],
        [0.7000],
        [0.4000]])

The shape of flat_target is:

tensor([[1.],
        [0.],
        [0.],
        [0.]])

I don't understand why, when you calculate the dice loss, flat_input and flat_target have different shapes.
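
For readers following this thread, a small sketch of the per-class column slicing that produces the tensors printed above; it assumes the loop in _multiple_class slices one class column at a time before the quoted view(-1, 1) call, and uses .reshape here to sidestep the contiguity question raised in the token-classification issue above:

import torch

probs = torch.FloatTensor([[0.1, 0.8, 0.7],
                           [0.5, 0.1, 0.6],
                           [0.7, 0.5, 0.8],
                           [0.4, 0.6, 0.9]])
targets = torch.FloatTensor([[1, 0, 0],
                             [0, 1, 0],
                             [0, 0, 1],
                             [0, 1, 0]])

label_idx = 0
flat_input_idx = probs[:, label_idx]     # tensor([0.1000, 0.5000, 0.7000, 0.4000])
flat_target_idx = targets[:, label_idx]  # tensor([1., 0., 0., 0.])

# Both columns are reshaped to (N, 1) before the per-class dice is computed,
# so they end up with the matching shape torch.Size([4, 1]).
print(flat_input_idx.reshape(-1, 1).shape, flat_target_idx.reshape(-1, 1).shape)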

an error on dice_loss.py

Hello, I want to use dice_loss in an NER task. My setup is as follows:
a = torch.rand(13,3)
b = torch.tensor([0,1,1,1,1,1,1,1,1,1,1,1,2])
f = DiceLoss(with_logits=True,smooth=1, ohem_ratio=0.3,alpha=0.01)
f(a,b)
When I run it, I get the following error:
Exception raised: TypeError
unsupported operand type(s) for &: 'int' and 'Tensor'

The error occurs in _multiple_class:
cond = (torch.argmax(flat_input, dim=1) == label_idx & flat_input[:, label_idx] >= threshold) | pos_example.view(-1)

Perhaps label_idx & flat... is being evaluated first.
Since I have not read the algorithm description in the paper carefully, I am not sure about the logic of this part or how it should be fixed, so I am asking for your help!
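
As noted in the RuntimeError issue earlier on this page, & binds more tightly than == and >=, so the error indeed comes from label_idx & flat_input[:, label_idx] being evaluated first; the parenthesized form suggested there would read:

cond = ((torch.argmax(flat_input, dim=1) == label_idx) & (flat_input[:, label_idx] >= threshold)) | pos_example.view(-1)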
