
Comments (9)

epeterson12 commented on July 26, 2024

Using checkpointing in the unetsmall net increases the speed of training. Tests were performed using the following parameters:

| # Training Samples | # Validation Samples | Sample Size | # Classes | # Epochs |
| --- | --- | --- | --- | --- |
| 781 | 495 | 256 | 11 | 200 |

| Learning Rate | Weight Decay | Step Size | Gamma | Class Weights | Dropout |
| --- | --- | --- | --- | --- | --- |
| 0.0001 | 0 | 4 | 0.9 | False | False |

(Figures: memory usage by batch size; processing time by batch size.)

Best results

|   | Original | Checkpointed |
| --- | --- | --- |
| Max batch size | 32 | 50 |
| Time to complete training over 200 epochs | 350 min | 316 min |

Using checkpoints in the net design does seem to affect the results of training. Tests were done on the original and on the checkpointed nets with the random seed set to 7, and the resulting models gave similar but slightly different results. In the first test, the original algorithm gave results closer to the ground truth; in the second test, the checkpointed version of the net yielded better results.


epeterson12 commented on July 26, 2024

Validation of Checkpointed results

Results when using checkpointing are slightly different from those of the original unetsmall model because cuDNN has non-deterministic kernels. I ran tests using suggestions from the PyTorch discussions https://discuss.pytorch.org/t/non-reproducible-result-with-gpu/1831 and https://discuss.pytorch.org/t/deterministic-non-deterministic-results-with-pytorch/9087.

Using the same sample files, I ran train_model.py twice using the unetsmall model (batch_size = 32) and I ran it once using the checkpointed_unet model (batch_size = 50). Then, I classified some images with the resulting models.

The settings to try to get reproducible results were set as follows at the beginning of the code:

import random
import torch

# Reproducibility settings applied at the beginning of the script
torch.backends.cudnn.deterministic = True
torch.manual_seed(999)
torch.cuda.manual_seed(999)
torch.cuda.manual_seed_all(999)
random.seed(0)

Also, the DataLoaders were instantiated with num_workers = 0 and shuffle = False.
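As a rough sketch of that instantiation (the dataset here is a placeholder, not the project's actual sample reader):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset standing in for the project's sample reader
trn_dataset = TensorDataset(torch.randn(8, 3, 256, 256),
                            torch.randint(0, 11, (8, 256, 256)))

# num_workers=0 keeps loading in the main process and shuffle=False fixes the
# sample order, both of which help make runs comparable.
trn_loader = DataLoader(trn_dataset, batch_size=4, num_workers=0, shuffle=False)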
Running the original unetsmall configuration without checkpoints twice yielded two slightly different results. Here are some examples of the results obtained when running image_classification.py on one of the training images with each trained model. Sections of the images that weren't classified were left white.

| Ground Truth | Current Code | Current Code 2 | Checkpointed |
| --- | --- | --- | --- |
| 1_rgb_8000_8000_ground_truth | 1_rgb_8000_8000_original | 1_rgb_8000_8000_original2 | 1_rgb_8000_8000_checkpoint |
| 1_rgb_0_0_ground_truth | 1_rgb_0_0_original | 1_rgb_0_0_original2 | 1_rgb_0_0_checkpoint |
| on_5297_1_ground_truth | on_5297_1_original | on_5297_1_original2 | on_5297_1_checkpointed |

Please note that the configuration and the number of samples weren't set to yield optimal results; verifying reproducibility was the goal of these tests. The number of training samples was set to the number of samples produced during sample creation.

global:
  samples_size: 256
  num_classes: 5
  data_path: /my/data/path
  number_of_bands: 3
  model_name: unetsmall     # One of unet, unetsmall, checkpointed_unet or ternausnet

sample:
  prep_csv_file: /my/prep/csv/file
  samples_dist: 200
  remove_background: True
  mask_input_image: False

training:
  output_path: /my/output/path
  num_trn_samples: 3356
  num_val_samples: 1370
  batch_size: 32
  num_epochs: 100
  learning_rate: 0.0001
  weight_decay: 0
  step_size: 4
  gamma: 0.9
  class_weights: False

models:
  unet:   &unet001
    dropout: False
    probability: 0.2    # Set with dropout
    pretrained: False   # optional
  unetsmall:
    <<: *unet001
  ternausnet:
    pretrained: ./models/TernausNet.pt    # Mandatory
  checkpointed_unet: 
    <<: *unet001

I think that the results of the checkpointed_unet are similar enough to the unetsmall's results for us to consider it a good memory- and time-optimised version of our unetsmall net architecture. I have added it as a model choice for our program.

Throughout my tests, I observed that the models produced by training are more accurate when the random number generators aren't seeded. The checkpointed_unet seems, from what I observed, to be more affected by this than the unetsmall.


ymoisan commented on July 26, 2024

Can some kind of co-routines/generators approach be useful here? Are there unnecessary copies of data structures in memory? Check how to generate your data in parallel with PyTorch.
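Something along these lines (illustrative dataset and parameters, not our current settings), where worker processes prepare the next batches while the GPU trains on the current one:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative dataset; the real one reads our image samples.
dataset = TensorDataset(torch.randn(64, 3, 256, 256),
                        torch.randint(0, 11, (64, 256, 256)))

# num_workers > 0 spawns worker processes that load batches in parallel;
# pin_memory speeds up host-to-GPU transfers.
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    num_workers=4, pin_memory=True)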


ymoisan commented on July 26, 2024

See pytorch/pytorch#5210 and https://pytorch.org/docs/stable/bottleneck.html
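(The bottleneck tool is run as a module wrapping the training script, e.g. python -m torch.utils.bottleneck train_model.py followed by its usual arguments.)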


epeterson12 commented on July 26, 2024

Tests were performed using the unet network rather than unetsmall in order to force a memory error. The error occurred on line 102 of unet_pytorch.py under these conditions. Here is a summary of the observations thus far:

  • Deleting variables as they were no longer needed had no impact on memory usage (e.g. del maxpool1 after line 91 in unet_pytorch.py)
  • Adding torch.cuda.empty_cache() to the code gives back approximately 10 MB of GPU memory, which is too little to be an applicable solution
  • Convolution operations, which are the heaviest operations, are performed by Nvidia's closed-source cuDNN library, which is optimized for performance on Nvidia GPUs. It is unlikely that we will be able to reduce memory use without building new modules. There seems to be a way to set a memory limit for convolutions in the Caffe C++ source code used by PyTorch, but further investigation is needed to see how to use this from PyTorch.
  • Alternative: we could trade memory for computation and execution speed, as described here, by using gradient checkpointing to limit the size of the stored backpropagation graph (a rough sketch follows this list). This isn't a good option since we want to keep training time as low as we can.
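A rough sketch of that idea (toy blocks and sizes, not our unet_pytorch.py modules): torch.utils.checkpoint re-runs the wrapped segment during the backward pass instead of storing its activations.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy convolution blocks standing in for the heavy encoder stages
block1 = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU())
block2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())

# requires_grad=True so the checkpointed segments still produce gradients
x = torch.randn(2, 3, 256, 256, requires_grad=True)

# Activations inside the checkpointed calls are freed after the forward pass
# and recomputed during backward, trading compute time for GPU memory.
out = checkpoint(block2, checkpoint(block1, x))
out.sum().backward()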


ymoisan commented on July 26, 2024

@epeterson12 - interesting concept that we might want to look at: tensor comprehensions

Also, this ticket might not be all that important after all. PyTorch 1.0 will have a compiler to make it faster, which probably also means better on memory.


ymoisan commented on July 26, 2024

@epeterson12: the figures for checkpointing don't seem to be that bad. In fact, for one of the models, checkpointing actually decreases GPU time. I suggest we create some benchmarks to test. import torch.utils.checkpoint works just fine in our environment, so we could try the following:

  • confirm that a batch size > 32 with 256 X 256 samples does indeed trigger the error
  • try with checkpoint to see if the batch size goes through; increase batch size until memory crashes again

Memory consumption and processing times could be monitored for all tests. What do you think?
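For the monitoring itself, something along these lines would do (sketch only; the training call is a placeholder):

import time
import torch

start = time.time()

# ... run training for one configuration (model, batch size) here ...

elapsed = time.time() - start
# max_memory_allocated reports the peak GPU memory used by tensors so far
peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
print("peak GPU memory: %.0f MB, wall time: %.0f s" % (peak_mb, elapsed))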


epeterson12 commented on July 26, 2024

@ymoisan I think that it would be a good idea to test our net using checkpointing. I will confirm that a batch size > 32 with 256 x 256 samples causes the error. I have started modifying our code in order to test checkpointing using the methodology I mentioned in my previous comment.


epeterson12 commented on July 26, 2024

It turns out that the maximum batch size that we can currently handle with a sample size of 256x256 is 33 when using the unetsmall net and 15 when using the unet net with our current hardware and version of the code. Beyond these quantities, we get the out of memory error.

