Comments (5)
@zcong17huang This looks like a memory leak. When host memory grows until it blows up, what does GPU memory usage look like — does it grow along with host memory, or stay basically flat? Please also share your PyTorch version, plus the OS version, CUDA version, and cuDNN version. If possible, the relevant training code would help as well.
from ms_ssim_pytorch.
GPU memory does not change; it is the host memory that eventually blows up.
I tried both Windows and a server. On Windows I used Python 3.5.6, PyTorch 1.2.0, CUDA 10.1, cuDNN 7.6.0.
The training code is as follows:
start_full_time = time.time()
model.train()
last_loss = 1000000
for epoch in range(epoch_start, args.epochs+1):
    log.info('This is %d-th epoch, learning rate : %f '%(epoch, scheduler.get_lr()[0]))
    total_train_loss = 0
    length_loader = len(TrainImgLoader)
    start_time = time.time()
    ## training ##
    for batch_idx, (data_img, data_label, data_gt, data_input) in enumerate(TrainImgLoader):
        # --------------------------------------------------
        batch_time = time.time()
        optimizer.zero_grad()
        data_img = data_img.float().to(device)
        data_label = data_label.float().to(device)
        data_gt = data_gt.float().to(device)
        data_input = data_input.float().to(device)
        # print(data_img.shape, data_label.shape, data_gt.shape, data_input.shape)
        _, output1, output2, output3 = model(data_img, data_label, data_input)
        data_gt = data_gt*255.0  # rescale to the 0-255 depth range
        output1 = output1*255.0
        output2 = output2*255.0
        output3 = output3*255.0
        loss1 = nn.SmoothL1Loss()
        loss2 = losses.SSIM(data_range=255., channel=1).to(device)  # this is your SSIM loss
        loss3 = losses.SemanticBoundaryLoss(device)
        loss_1 = loss1(data_gt, output1) \
                 + 0.1 * (1 - loss2(data_gt, output1)) \
                 + 0.1 * loss3(data_label, output1)
        loss_2 = loss1(data_gt, output2) \
                 + 0.1 * (1 - loss2(data_gt, output2)) \
                 + 0.1 * loss3(data_label, output2)
        loss_3 = loss1(data_gt, output3) \
                 + 0.1 * (1 - loss2(data_gt, output3)) \
                 + 0.1 * loss3(data_label, output3)
        loss = 0.6*loss_1 + 0.8*loss_2 + loss_3
        loss.backward()
        optimizer.step()
        loss = loss.item()
        EPE_error = torch.mean(torch.abs(data_gt - output3))  # end-point error
        writer_name = args.logpath.split('/')[-1]
        writer.add_scalar(writer_name, loss, (batch_idx + (epoch * length_loader)))
        writer.close()
        # -----------------------------------------------
        total_train_loss += loss
    log.info('epoch %d total training loss = %.5f, time = %.4f Hours' %(epoch, total_train_loss/length_loader, (time.time() - start_time)/3600))
    scheduler.step()
My losses are all reduced to Python scalars with .item(), so memory should not blow up. I can't see where the problem is — could you take a look?
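To confirm that it really is host RAM (not GPU memory) that climbs, one option is to log the process's peak resident set size once per epoch. A minimal sketch using only the Python standard library on Unix-like systems (the `peak_rss_mb` helper name is mine, not from the original code):

```python
import resource
import sys

def peak_rss_mb():
    """Peak resident set size of this process, in MiB.

    Linux reports ru_maxrss in KiB, macOS in bytes; normalize both.
    """
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return rss / 1024 if sys.platform != "darwin" else rss / (1024 * 1024)

# Call this at the end of each epoch; a value that climbs steadily across
# epochs (while GPU memory stays flat) points at a host-memory leak.
print(f"peak RSS: {peak_rss_mb():.1f} MiB")
```

Note that `resource` is not available on Windows; there a tool like a task-manager trace over epochs serves the same purpose.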
@zcong17huang Having read the code, I suspect a leak when PyTorch 1.2's jit module is released. Have you tried the latest PyTorch? If you must stay on PyTorch 1.2, you can initialize the SSIM module just once, outside the training loop, and reuse it to work around the problem. Here is the modified code — see whether the issue persists.
start_full_time = time.time()
model.train()
last_loss = 1000000
# Initialize this once out here; it can be reused across iterations.
loss2 = losses.SSIM(data_range=255., channel=1).to(device)  # this is your SSIM loss
for epoch in range(epoch_start, args.epochs+1):
    log.info('This is %d-th epoch, learning rate : %f '%(epoch, scheduler.get_lr()[0]))
    total_train_loss = 0
    length_loader = len(TrainImgLoader)
    start_time = time.time()
    ## training ##
    for batch_idx, (data_img, data_label, data_gt, data_input) in enumerate(TrainImgLoader):
        # --------------------------------------------------
        batch_time = time.time()
        optimizer.zero_grad()
        data_img = data_img.float().to(device)
        data_label = data_label.float().to(device)
        data_gt = data_gt.float().to(device)
        data_input = data_input.float().to(device)
        # print(data_img.shape, data_label.shape, data_gt.shape, data_input.shape)
        _, output1, output2, output3 = model(data_img, data_label, data_input)
        data_gt = data_gt*255.0  # rescale to the 0-255 depth range
        output1 = output1*255.0
        output2 = output2*255.0
        output3 = output3*255.0
        loss1 = nn.SmoothL1Loss()
        # Do not create a new module on every training iteration.
        # loss2 = losses.SSIM(data_range=255., channel=1).to(device)  # this is your SSIM loss
        loss3 = losses.SemanticBoundaryLoss(device)
        loss_1 = loss1(data_gt, output1) \
                 + 0.1 * (1 - loss2(data_gt, output1)) \
                 + 0.1 * loss3(data_label, output1)
        loss_2 = loss1(data_gt, output2) \
                 + 0.1 * (1 - loss2(data_gt, output2)) \
                 + 0.1 * loss3(data_label, output2)
        loss_3 = loss1(data_gt, output3) \
                 + 0.1 * (1 - loss2(data_gt, output3)) \
                 + 0.1 * loss3(data_label, output3)
        loss = 0.6*loss_1 + 0.8*loss_2 + loss_3
        loss.backward()
        optimizer.step()
        loss = loss.item()
        EPE_error = torch.mean(torch.abs(data_gt - output3))  # end-point error
        writer_name = args.logpath.split('/')[-1]
        writer.add_scalar(writer_name, loss, (batch_idx + (epoch * length_loader)))
        writer.close()
        # -----------------------------------------------
        total_train_loss += loss
    log.info('epoch %d total training loss = %.5f, time = %.4f Hours' %(epoch, total_train_loss/length_loader, (time.time() - start_time)/3600))
    scheduler.step()
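The core of the fix is hoisting construction of the jit-backed module out of the inner loop. As a rough analogy (not the actual torch.jit internals — `ScriptedLoss` and its `compile_cache` are invented for illustration), a module whose construction feeds a process-wide cache grows that cache without bound when built per-iteration, but not when built once and reused:

```python
class ScriptedLoss:
    """Stand-in for a module whose construction triggers an expensive,
    cached compilation step (roughly what jit-scripting does)."""
    compile_cache = []  # simulates a process-wide cache that never shrinks

    def __init__(self):
        ScriptedLoss.compile_cache.append(object())  # one entry per construction

    def __call__(self, a, b):
        return abs(a - b)

# Anti-pattern: building the module every iteration grows the cache.
for _ in range(100):
    bad = ScriptedLoss()
assert len(ScriptedLoss.compile_cache) == 100

# Fix: build once, reuse across iterations -- the cache gains one entry.
ScriptedLoss.compile_cache.clear()
good = ScriptedLoss()
for _ in range(100):
    _ = good(3, 1)
assert len(ScriptedLoss.compile_cache) == 1
```

The same hoisting would also apply to `loss1` and `loss3` above; they are cheap plain modules, but creating them once is still the idiomatic pattern.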
My machine's environment at the time was PyTorch 1.2, and I haven't re-run the experiment on another version. I don't really understand how jit works, but the main difference between your code and other SSIM implementations seems to be the jit part, so my guess is an incompatibility between the jit module and the PyTorch version I was using. If I run the experiment again I will report back. Thanks for the quick answer!
@zcong17huang Then I'll close this issue for now; reopen it when you have an update.