
airaria commented on June 24, 2024

First, let me explain the distillation configuration above, in case it is confusing.

For BERT-base, there are 12 layers, which we number 1, 2, ..., 12 (layer 0 is the embedding layer).
The example matches the student's layers to the teacher's layers evenly:
teacher's 0th layer goes to student's 0th layer (the embedding layers)
teacher's 3rd layer goes to student's 1st layer
teacher's 6th layer goes to student's 2nd layer
teacher's 9th layer goes to student's 3rd layer
teacher's 12th layer (the last layer of the teacher) goes to student's 4th layer (the last layer of the student)

Since the hidden dimensions of the teacher and the student differ, we use a linear mapping from 312 (the student's dim) to 768 (the teacher's dim) to project the student's hidden states into the higher-dimensional space.

For the above mappings, we take the 'hidden' features (which must be defined in the adaptor by the user; it is the user's responsibility to tell TextBrewer what 'hidden' means) from each matched layer and compute the 'hidden_mse' loss (defined in losses.py) between the student's and the teacher's features.
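Putting the even mapping above into the config format used in this thread, the matches can be written in one comprehension (a sketch; the dict keys follow the snippets discussed here):

```python
# Even mapping of a 12-layer teacher onto a 4-layer student.
# Layer 0 is the embedding output; layer k is the k-th
# transformer layer's hidden state.
# 'proj': ['linear', 312, 768] adds a trainable linear map from
# the student's 312-dim hidden states up to the teacher's 768.
intermediate_matches = [
    {'layer_T': t, 'layer_S': s,
     'feature': 'hidden', 'loss': 'hidden_mse',
     'weight': 1, 'proj': ['linear', 312, 768]}
    for s, t in enumerate([0, 3, 6, 9, 12])
]
```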

The following lines

{'layer_T' : [0,0], 'layer_S': [0,0], ....}
...

use a different loss, 'nst', which requires two similarity matrices.
For example, {'layer_T' : [0,0], 'layer_S': [0,0], ....} means:

  1. calculate the similarity matrix between the 'hidden' feature from the teacher's 0th layer and the 'hidden' feature from the teacher's 0th layer (self-similarity)
  2. calculate the similarity matrix between the 'hidden' feature from the student's 0th layer and the 'hidden' feature from the student's 0th layer (self-similarity)
  3. compute the 'nst' loss on the above two similarity matrices.
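A small sketch of why these similarity matrices make the dimensions compatible (toy shapes, not TextBrewer's actual internals):

```python
import numpy as np

# Toy hidden states: (sequence_length, hidden_dim).
H_teacher = np.random.rand(4, 768)
H_student = np.random.rand(4, 312)

# Each side's self-similarity (Gram) matrix is
# (sequence_length x sequence_length), so the hidden dimension
# cancels out and no 'proj' is needed for the 'nst' loss.
G_teacher = H_teacher @ H_teacher.T
G_student = H_student @ H_student.T
assert G_teacher.shape == G_student.shape == (4, 4)
```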

For a thinner three-layer BERT (T3-small), you can map the layers 0-0, 4-1, 8-2, 12-3 and use 'proj': ['linear', 384, 768] to match the dimensions.
The lines that contain the 'nst' loss can be removed if you want to keep the configuration simple.

from textbrewer.

kaliaanup commented on June 24, 2024

Thank you so much for the detailed explanation. If you can add these to your docs, it will be super useful.

I am following up on the conll2003 example. I changed the distill_config as follows. (I am using Transformers 4.17.0.)

distill_config = DistillationConfig(
    temperature=8,
    # intermediate_matches = [{'layer_T': 10, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}]
    intermediate_matches=[
        {'layer_T': 0,  'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 4,  'layer_S': 1, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 8,  'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
        {'layer_T': 12, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 384, 768]},
    ]
)

The run_conll2003_distill_T3.sh file looks as follows.

export OUTPUT_DIR="resource/taggers/T3-small-bert-finetuned"
export BATCH_SIZE=32
export NUM_EPOCHS=3
export SAVE_STEPS=750
export SEED=42
export MAX_LENGTH=128
export BERT_MODEL_TEACHER="resource/taggers/bert-finetuned"
python run_ner_distill.py \
--data_dir english_dataset \
--model_type bert \
--labels label_prod.txt \
--model_name_or_path $BERT_MODEL_TEACHER \
--output_dir $OUTPUT_DIR \
--max_seq_length  $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--per_gpu_train_batch_size $BATCH_SIZE \
--num_hidden_layers 3 \
--save_steps $SAVE_STEPS \
--learning_rate 1e-4 \
--warmup_steps 0.1 \
--seed $SEED \
--do_distill \
--do_train \
--do_eval \
--do_predict

I am getting an index out of range error. Can you please check?

Traceback (most recent call last):
  File "/Users/akalia/Research_Projects/NER-EL-Evaluation/textbrewer_ner_distiller/run_ner_distill.py", line 531, in
    main()
  File "/Users/akalia/Research_Projects/NER-EL-Evaluation/textbrewer_ner_distiller/run_ner_distill.py", line 460, in main
    train(args, train_dataset, model_T, model, tokenizer, labels, pad_token_label_id, predict_callback)
  File "/Users/akalia/Research_Projects/NER-EL-Evaluation/textbrewer_ner_distiller/textbrewer/distiller_basic.py", line 283, in train
    self.train_with_num_epochs(optimizer, scheduler, tqdm_disable, dataloader, max_grad_norm, num_epochs, callback, batch_postprocessor, **args)
  File "/Users/akalia/Research_Projects/NER-EL-Evaluation/textbrewer_ner_distiller/textbrewer/distiller_basic.py", line 212, in train_with_num_epochs
    total_loss, losses_dict = self.train_on_batch(batch, args)
  File "/Users/akalia/Research_Projects/NER-EL-Evaluation/textbrewer_ner_distiller/textbrewer/distiller_general.py", line 79, in train_on_batch
    total_loss, losses_dict = self.compute_loss(results_S, results_T)
  File "/Users/akalia/Research_Projects/NER-EL-Evaluation/textbrewer_ner_distiller/textbrewer/distiller_general.py", line 143, in compute_loss
    inter_S = inters_S[feature][layer_S]
IndexError: list index out of range


airaria commented on June 24, 2024

Did you set the model to return hidden states via model.config.output_hidden_states = True (if you distill with hidden states)?

If it is still not working, could you print the lengths of inters_S[feature] and inters_T[feature] by inserting a print statement around line 140 of textbrewer/distiller_general.py?
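The exact internals aside, the failure mode usually looks like this (hypothetical values, not the real adaptor output):

```python
# Without model.config.output_hidden_states = True, the adaptor's
# 'hidden' entry is typically empty, so indexing layer_S fails.
inters_S = {'hidden': []}
layer_S = 3
try:
    inters_S['hidden'][layer_S]
except IndexError as e:
    print(e)  # list index out of range

# With the flag set, a 3-layer student yields 4 hidden states
# (embeddings + 3 transformer layers), so index 3 is valid.
inters_S = {'hidden': ['emb', 'h1', 'h2', 'h3']}
assert inters_S['hidden'][layer_S] == 'h3'
```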


stale commented on June 24, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.


stale commented on June 24, 2024

Closing the issue, since no updates were observed. Feel free to re-open if you need any further assistance.

