Thanks for the question! I realized the code in the repository wasn't from the latest commit. Here is the updated code for loading the embeddings:
```python
import torch

# Directory of the trained diffusion model, e.g.
# '/u/scr/nlp/xlisali/predictability/diffusion_models_v3/diff_e2e-tgt_block_rand16_transformer_lr0.0001_2000_cosine_Lsimple_h128_s2_sd101'
filename = model_args.init_emb
path_save = '{}/random_emb.torch'.format(filename)         # random (non-learned) embedding
path_learned = '{}/ema_0.9999_200000.pt'.format(filename)  # EMA checkpoint holding the learned embedding

if model_args.experiment == 'e2e-tgt-pos' and model_args.learned_emb == 'no':
    # Restore the saved random embedding, then freeze it.
    model.transformer.embeddings.word_embeddings.load_state_dict(torch.load(path_save))
    model.transformer.embeddings.word_embeddings.weight.requires_grad = False
elif model_args.experiment == 'e2e-tgt-pos' and model_args.learned_emb == 'yes':
    # Copy 'word_embedding.weight' from the diffusion checkpoint, then freeze it.
    print('loading the learned embeddings')
    learned_embeddings = torch.load(path_learned)['word_embedding.weight']
    model.transformer.embeddings.word_embeddings.weight.data = learned_embeddings.clone()
    model.transformer.embeddings.word_embeddings.weight.requires_grad = False
elif model_args.experiment == 'e2e-tgt-tree' and model_args.learned_emb == 'no':
    model.transformer.embeddings.word_embeddings.load_state_dict(torch.load(path_save))
    model.transformer.embeddings.word_embeddings.weight.requires_grad = False
elif model_args.experiment == 'e2e-tgt-tree' and model_args.learned_emb == 'yes':
    print('loading the learned embeddings')
    learned_embeddings = torch.load(path_learned)['word_embedding.weight']
    model.transformer.embeddings.word_embeddings.weight.data = learned_embeddings.clone()
    model.transformer.embeddings.word_embeddings.weight.requires_grad = False
elif model_args.experiment.startswith('e2e-back') and model_args.learned_emb == 'no':
    # GPT-2-style backbones name the token embedding `wte`.
    model.transformer.wte.load_state_dict(torch.load(path_save))
    model.transformer.wte.weight.requires_grad = False
elif model_args.experiment.startswith('e2e-back') and model_args.learned_emb == 'yes':
    print('loading the learned embeddings')
    learned_embeddings = torch.load(path_learned)['word_embedding.weight']
    model.transformer.wte.weight.data = learned_embeddings.clone()
    model.transformer.wte.weight.requires_grad = False
```
I will push a new commit.
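The six branches in the snippet above differ only in which embedding module they target and which file they read. As a minimal sketch (assuming PyTorch, and assuming the two files have exactly the formats the snippet implies: `random_emb.torch` is an `nn.Embedding` state dict, and `ema_0.9999_200000.pt` is a dict with a `'word_embedding.weight'` tensor), the same logic can be written once. `select_embedding` and `load_frozen_embedding` are hypothetical helper names, not functions from the Diffusion-LM repository.

```python
# Hypothetical refactoring sketch of the embedding-loading snippet; not the
# Diffusion-LM repository's own code.
import os
import tempfile
from types import SimpleNamespace

import torch
import torch.nn as nn


def select_embedding(model, experiment):
    """Return the token-embedding module the snippet targets, or None."""
    if experiment in ('e2e-tgt-pos', 'e2e-tgt-tree'):
        return model.transformer.embeddings.word_embeddings
    if experiment.startswith('e2e-back'):
        return model.transformer.wte  # GPT-2-style name for the token embedding
    return None  # the snippet leaves other experiments untouched


def load_frozen_embedding(emb, init_dir, learned_emb):
    """Fill `emb` from the diffusion model directory and freeze it.

    Assumed layout (taken from the file names in the snippet):
      random_emb.torch      -> state dict of an nn.Embedding     (learned_emb == 'no')
      ema_0.9999_200000.pt  -> dict with 'word_embedding.weight' (learned_emb == 'yes')
    """
    if learned_emb == 'no':
        emb.load_state_dict(torch.load(os.path.join(init_dir, 'random_emb.torch')))
    elif learned_emb == 'yes':
        ckpt = torch.load(os.path.join(init_dir, 'ema_0.9999_200000.pt'))
        learned = ckpt['word_embedding.weight']
        if learned.shape != emb.weight.shape:
            raise ValueError(
                f'shape mismatch: {tuple(learned.shape)} vs {tuple(emb.weight.shape)}')
        with torch.no_grad():
            emb.weight.copy_(learned)  # in-place copy instead of reassigning .data
    else:
        raise ValueError(f"learned_emb must be 'yes' or 'no', got {learned_emb!r}")
    emb.weight.requires_grad = False  # keep the diffusion embeddings fixed
    return emb


if __name__ == '__main__':
    vocab, dim = 10, 4  # toy sizes for the demo
    model = SimpleNamespace(transformer=SimpleNamespace(wte=nn.Embedding(vocab, dim)))
    with tempfile.TemporaryDirectory() as d:
        # Write a stand-in checkpoint with random contents (no real model involved).
        torch.save({'word_embedding.weight': torch.randn(vocab, dim)},
                   os.path.join(d, 'ema_0.9999_200000.pt'))
        emb = select_embedding(model, 'e2e-back')
        load_frozen_embedding(emb, d, 'yes')
        print(emb.weight.requires_grad)  # False
```

One design difference from the snippet: `weight.data = ...` accepts whatever tensor the checkpoint holds, while `copy_` inside `torch.no_grad()` plus the explicit shape check fails loudly on a vocabulary-size mismatch and casts to the embedding's existing dtype and device.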
Hi Lisa, thank you for your response!