
Comments (5)

One-sixth commented on July 22, 2024

@zcong17huang This looks like a memory leak. When host memory grows until it blows up, what does GPU memory usage look like — does it grow along with host memory, or stay roughly constant? Please also provide your PyTorch version, as well as your OS version, CUDA version, and cuDNN version. If possible, the relevant training code would help too.
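To tell host-memory growth apart from GPU-memory growth, it helps to log both each epoch. A minimal stdlib-only sketch for the host side (on the GPU side one would additionally log `torch.cuda.memory_allocated()`; the `host_rss_mb` helper here is illustrative, not part of the repo):

```python
import resource

def host_rss_mb():
    """Peak resident set size of this process in MB (Linux reports ru_maxrss in KB)."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

# Inside the training loop one could log both numbers each epoch, e.g.:
#   log.info('host RSS %.1f MB, GPU %.1f MB'
#            % (host_rss_mb(), torch.cuda.memory_allocated() / 2**20))

before = host_rss_mb()
held = [bytearray(10**6) for _ in range(50)]  # allocate ~50 MB to make RSS move
after = host_rss_mb()
print('peak RSS grew by roughly %.0f MB' % (after - before))
```

If the logged host number climbs steadily across epochs while the GPU number stays flat, the leak is on the CPU side.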

from ms_ssim_pytorch.

zcong17huang commented on July 22, 2024

GPU memory stays flat; it is the host memory that eventually blows up.
I have tried both Windows and a server. On Windows I used Python 3.5.6, PyTorch 1.2.0, CUDA 10.1, and cuDNN 7.6.0.
The training code is as follows:

start_full_time = time.time()
model.train()
last_loss = 1000000

for epoch in range(epoch_start, args.epochs+1):
    log.info('This is %d-th epoch, learning rate : %f '%(epoch, scheduler.get_lr()[0]))
    total_train_loss = 0
    length_loader = len(TrainImgLoader)
    start_time = time.time()

    ## training ##
    for batch_idx, (data_img, data_label, data_gt, data_input) in enumerate(TrainImgLoader):
        # --------------------------------------------------
        batch_time = time.time()
        optimizer.zero_grad()

        data_img = data_img.float().to(device)
        data_label = data_label.float().to(device)
        data_gt = data_gt.float().to(device)
        data_input = data_input.float().to(device)
        # print(data_img.shape, data_label.shape, data_gt.shape, data_input.shape)

        _, output1, output2, output3 = model(data_img, data_label, data_input)

        data_gt = data_gt*255.0         # scale to the 0~255 depth range
        output1 = output1*255.0
        output2 = output2*255.0
        output3 = output3*255.0

        loss1 = nn.SmoothL1Loss()
        loss2 = losses.SSIM(data_range=255., channel=1).to(device)   # this is your SSIM loss
        loss3 = losses.SemanticBoundaryLoss(device)

        loss_1 = loss1(data_gt, output1) \
               + 0.1 * (1-loss2(data_gt, output1)) \
               + 0.1 * loss3(data_label, output1)
        loss_2 = loss1(data_gt, output2) \
               + 0.1 * (1-loss2(data_gt, output2)) \
               + 0.1 * loss3(data_label, output2)
        loss_3 = loss1(data_gt, output3) \
                 + 0.1 * (1 - loss2(data_gt, output3)) \
                 + 0.1 * loss3(data_label, output3)

        loss = 0.6*loss_1 + 0.8*loss_2 + loss_3

        loss.backward()
        optimizer.step()
        loss = loss.item()

        EPE_error = torch.mean(torch.abs(data_gt-output3))  # end-point-error

        writer_name = args.logpath.split('/')[-1]
        writer.add_scalar(writer_name, loss, (batch_idx + (epoch * length_loader)))
        writer.close()
        # -----------------------------------------------
        total_train_loss += loss

    log.info('epoch %d total training loss = %.5f, time = %.4f Hours' %(epoch, total_train_loss/length_loader, (time.time() - start_time)/3600))
    scheduler.step()

All my losses are reduced to plain Python values with .item(), so host memory should not blow up. I cannot figure out where the problem is — could you please take a look?


One-sixth commented on July 22, 2024

@zcong17huang Having read the code, I suspect the jit module in PyTorch 1.2 may leak memory when it is released. Have you tried the latest PyTorch? If you can only use PyTorch 1.2, you can construct the SSIM module once, outside the training loop, and reuse it to work around the problem. Here is the modified code — see whether the issue persists:

start_full_time = time.time()
model.train()
last_loss = 1000000

# Initialize this once outside the loop; it can be reused.
loss2 = losses.SSIM(data_range=255., channel=1).to(device)   # this is your SSIM loss

for epoch in range(epoch_start, args.epochs+1):
    log.info('This is %d-th epoch, learning rate : %f '%(epoch, scheduler.get_lr()[0]))
    total_train_loss = 0
    length_loader = len(TrainImgLoader)
    start_time = time.time()

    ## training ##
    for batch_idx, (data_img, data_label, data_gt, data_input) in enumerate(TrainImgLoader):
        # --------------------------------------------------
        batch_time = time.time()
        optimizer.zero_grad()

        data_img = data_img.float().to(device)
        data_label = data_label.float().to(device)
        data_gt = data_gt.float().to(device)
        data_input = data_input.float().to(device)
        # print(data_img.shape, data_label.shape, data_gt.shape, data_input.shape)

        _, output1, output2, output3 = model(data_img, data_label, data_input)

        data_gt = data_gt*255.0         # scale to the 0~255 depth range
        output1 = output1*255.0
        output2 = output2*255.0
        output3 = output3*255.0

        loss1 = nn.SmoothL1Loss()
        # Do not construct a new module on every iteration of the training loop.
        # loss2 = losses.SSIM(data_range=255., channel=1).to(device)   # this is your SSIM loss
        loss3 = losses.SemanticBoundaryLoss(device)

        loss_1 = loss1(data_gt, output1) \
               + 0.1 * (1-loss2(data_gt, output1)) \
               + 0.1 * loss3(data_label, output1)
        loss_2 = loss1(data_gt, output2) \
               + 0.1 * (1-loss2(data_gt, output2)) \
               + 0.1 * loss3(data_label, output2)
        loss_3 = loss1(data_gt, output3) \
                 + 0.1 * (1 - loss2(data_gt, output3)) \
                 + 0.1 * loss3(data_label, output3)

        loss = 0.6*loss_1 + 0.8*loss_2 + loss_3

        loss.backward()
        optimizer.step()
        loss = loss.item()

        EPE_error = torch.mean(torch.abs(data_gt-output3))  # end-point-error

        writer_name = args.logpath.split('/')[-1]
        writer.add_scalar(writer_name, loss, (batch_idx + (epoch * length_loader)))
        writer.close()
        # -----------------------------------------------
        total_train_loss += loss

    log.info('epoch %d total training loss = %.5f, time = %.4f Hours' %(epoch, total_train_loss/length_loader, (time.time() - start_time)/3600))
    scheduler.step()


zcong17huang commented on July 22, 2024

My environment at the time was PyTorch 1.2, and I have not re-run the experiment on another version. I do not really understand how jit works, but the main difference between your code and other SSIM implementations seems to be the jit part, so the problem is probably an incompatibility between the jit module and the PyTorch version I was using. If I run the experiment again I will report back. Thanks for the quick answer!
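The fix suggested above — build the loss module once and reuse it — can be illustrated without PyTorch; `ExpensiveLoss` here is a hypothetical stand-in for `losses.SSIM`:

```python
class ExpensiveLoss:
    """Stand-in for a module whose construction allocates resources
    (e.g. compiles a JIT graph or precomputes convolution windows)."""
    instances = 0  # count how many times construction work was done

    def __init__(self):
        ExpensiveLoss.instances += 1
        self.window = [x * 0.1 for x in range(11)]  # pretend precomputation

    def __call__(self, a, b):
        return abs(a - b)

# Anti-pattern: a new module per iteration -> N constructions, and if any
# construction-time resource is not released, memory grows with N.
for step in range(100):
    bad = ExpensiveLoss()(1.0, 0.5)
print(ExpensiveLoss.instances)  # 100 so far

# Fix: construct once outside the loop and reuse.
loss_fn = ExpensiveLoss()
for step in range(100):
    good = loss_fn(1.0, 0.5)
print(ExpensiveLoss.instances)  # 101 total
```

The loss values are identical either way; only the number of constructions (and therefore any leaked per-construction state) differs.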


One-sixth commented on July 22, 2024

@zcong17huang I will close this issue for now; feel free to reopen it when you have an update.

