Comments (3)
I've succeeded in training 6 tags at the same time. In my experiments, I found that 50k iterations per tag is enough (i.e., 20k for 6 tags).
HiSD supports varying numbers of tags, but you should increase the number of training iterations and the model capacity as the number of tags grows.
Using gradient accumulation and training all tags in one iteration is also important (so you need to change the code a little).
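To make "all tags in one iteration" concrete, here is a minimal sketch of the scheduling idea only, not HiSD's actual training loop. It assumes tags are visited round-robin (the function name `accumulation_schedule` and the round-robin order are illustrative assumptions), so that every window of `tag_num` consecutive update calls covers each tag once and ends with a single optimizer step.

```python
def accumulation_schedule(total_calls, tag_num):
    """Yield (call_index, tag_index, do_step) for each update call.

    Assumption: tags are visited round-robin instead of being sampled at
    random, so each window of `tag_num` calls touches every tag exactly once.
    """
    for call in range(total_calls):
        tag = call % tag_num                     # which tag this call trains
        do_step = (call + 1) % tag_num == 0      # step only at the end of a window
        yield call, tag, do_step


if __name__ == "__main__":
    for call, tag, do_step in accumulation_schedule(total_calls=6, tag_num=3):
        action = "accumulate, then optimizer step" if do_step else "accumulate only"
        print(f"call {call}: tag {tag} -> {action}")
```

With this kind of schedule, the number of real optimizer steps is the number of update calls divided by `tag_num`, which is one reason the total iteration count has to grow with the number of tags.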
from hisd.
Thanks for your reply.
Is my understanding of "gradient accumulation and all tags in one iteration" right?
And is '20k for 6 tags' a typo? The official repo uses 200k for 3 tags with 7 attributes.
Also, is the performance better when we train fewer tags?
from hisd.
Sorry for the typo, it should be 200k for 3 tags with 7 attributes.
You have the right idea about gradient accumulation. You can modify the update code like this:
# Assumes `import torch.nn as nn` and HiSD's `update_average` helper are in scope.
def update(self, x, y, i, j, j_trg, iterations):
    this_model = self.models.module if self.multi_gpus else self.models

    # gen: freeze the discriminator, train the generator
    for p in this_model.dis.parameters():
        p.requires_grad = False
    for p in this_model.gen.parameters():
        p.requires_grad = True

    self.loss_gen_adv, self.loss_gen_sty, self.loss_gen_rec, \
    x_trg, x_cyc, s, s_trg = self.models((x, y, i, j, j_trg), mode='gen')

    self.loss_gen_adv = self.loss_gen_adv.mean()
    self.loss_gen_sty = self.loss_gen_sty.mean()
    self.loss_gen_rec = self.loss_gen_rec.mean()

    # dis: freeze the generator, train the discriminator
    for p in this_model.dis.parameters():
        p.requires_grad = True
    for p in this_model.gen.parameters():
        p.requires_grad = False

    self.loss_dis_adv = self.models((x, x_trg, x_cyc, s, s_trg, y, i, j, j_trg), mode='dis')
    self.loss_dis_adv = self.loss_dis_adv.mean()

    # No zero_grad() above, so gradients keep accumulating across calls.
    # Step (and reset) only once every `tag_num` calls.
    if (iterations + 1) % self.tag_num == 0:
        nn.utils.clip_grad_norm_(this_model.gen.parameters(), 100)
        nn.utils.clip_grad_norm_(this_model.dis.parameters(), 100)
        self.gen_opt.step()
        self.dis_opt.step()
        self.gen_opt.zero_grad()
        self.dis_opt.zero_grad()
        update_average(this_model.gen_test, this_model.gen)

    return self.loss_gen_adv.item(), \
           self.loss_gen_sty.item(), \
           self.loss_gen_rec.item(), \
           self.loss_dis_adv.item()
You also need to decrease the learning rate before backward (maybe lr/tag_num), since the accumulated gradients are combined by 'sum' rather than 'average'.
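The sum-versus-average point can be checked with a toy calculation. The sketch below is only an illustration under stated assumptions: it uses a hypothetical one-parameter quadratic loss per tag and a plain SGD update (all names such as `grad_for_tag` and `targets` are made up for the example), and shows that summed gradients with lr/tag_num give the same step as averaged gradients with the original lr.

```python
import math


def grad_for_tag(weight, target):
    """Gradient of the toy loss 0.5 * (weight - target)**2 w.r.t. weight."""
    return weight - target


def sgd_step(weight, grad, lr):
    """One plain SGD update (assumption: no momentum, no adaptive scaling)."""
    return weight - lr * grad


targets = [1.0, 2.0, 6.0]   # hypothetical: one target per tag
tag_num = len(targets)
lr = 0.1
w0 = 0.0

# Accumulating without zero_grad() corresponds to summing per-tag gradients.
summed = sum(grad_for_tag(w0, t) for t in targets)
averaged = summed / tag_num

w_sum_full_lr = sgd_step(w0, summed, lr)              # step is tag_num times larger
w_sum_scaled_lr = sgd_step(w0, summed, lr / tag_num)  # lr reduced as suggested
w_avg_full_lr = sgd_step(w0, averaged, lr)            # reference: averaged gradient

print("sum, lr          :", w_sum_full_lr)
print("sum, lr / tag_num:", w_sum_scaled_lr)
print("mean, lr         :", w_avg_full_lr)
print("scaled sum matches mean:", math.isclose(w_sum_scaled_lr, w_avg_full_lr))
```

This equivalence is exact only for plain SGD. Adaptive optimizers such as Adam rescale updates by running gradient statistics, so lr/tag_num is better treated as a starting point to tune than as an exact correction.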
from hisd.