ben-hawks / pytorch-jet-classify
PyTorch implementation of the HLS4ML 3-layer jet tagging model, including both a standard PyTorch (float) implementation and a quantized implementation (via Xilinx's Brevitas library). Designed for use with the HLS4ML LHC Jet dataset (100 particles): https://zenodo.org/record/3602254

Python 97.72% Shell 2.12% Dockerfile 0.16%

pytorch-jet-classify's People

Contributors: ben-hawks, dependabot[bot], jmduarte

pytorch-jet-classify's Issues

Implement L1 Weight Regularization/Pruning

To match the behavior of the Keras model and to explore pruning (and the relationship between pruning and quantization), implement L1 weight regularization during training. Currently, L2 regularization is used because it is built into the optimizer via the weight_decay parameter, which is set to 1e-5 at the moment.
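Unlike L2 via weight_decay, an L1 penalty would have to be added to the loss by hand each training step. A minimal pure-Python sketch of the penalty term (in actual PyTorch code one would sum p.abs().sum() over model.parameters(); the function name and l1_lambda value are illustrative):

```python
def l1_penalty(weights, l1_lambda=1e-5):
    # L1 regularization term: lambda * sum(|w|). Added manually to the
    # task loss, since (unlike L2 via the optimizer's weight_decay)
    # L1 is not built into standard PyTorch optimizers.
    return l1_lambda * sum(abs(w) for w in weights)

# total_loss = criterion_loss + l1_penalty(all_model_weights)
```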

Understand brevitas.qnn.QuantReLU's `max_val` required parameter

Unlike stock PyTorch, Brevitas' qnn.QuantReLU has a required parameter specifying its maximum value. In documentation and examples it is set to 6 (making it effectively equal to a ReLU6 activation function), and at first glance it does not seem to affect performance much, but this should be better understood and, if possible and sensible, worked around to mimic the unquantized implementation.
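The likely reason max_val is required is that a fixed-point output quantizer needs a bounded range to map to. Ignoring the quantization of the output values themselves, the resulting transfer function is just a clamped ReLU; a pure-Python sketch, assuming max_val=6 (function name illustrative):

```python
def quant_relu_transfer(x, max_val=6.0):
    # Transfer function of a ReLU with an upper clamp at max_val.
    # With max_val=6 this is identical to torch.nn.ReLU6; QuantReLU
    # additionally quantizes the output values within [0, max_val].
    return min(max(x, 0.0), max_val)
```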

Understand the impact of having to use local_batch.float() in training/evaluation

When training both the float and the quantized model, the following error occurs unless current_model.double() is called before training, or the input tensors are cast to floats during training/evaluation via outputs = current_model(local_batch.float()) and criterion_loss = criterion(outputs, local_labels.float()):

current_model = models.three_layer_model_bv()
Traceback (most recent call last):
File "C:/Users/Ben/PycharmProjects/pytorch-jet-classify/train.py", line 203, in
outputs = current_model(local_batch)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\modules\module.py", line 550, in call
result = self.forward(*input, **kwargs)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\models.py", line 56, in forward
test = self.fc1(x)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\modules\module.py", line 550, in call
result = self.forward(*input, **kwargs)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\brevitas\nn\quant_linear.py", line 201, in forward
output = linear(input, quant_weight, quant_bias)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\functional.py", line 1610, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat1' in call to _th_addmm

current_model = models.three_layer_model()
Traceback (most recent call last):
File "C:/Users/Ben/PycharmProjects/pytorch-jet-classify/train.py", line 203, in
outputs = current_model(local_batch)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\modules\module.py", line 550, in call
result = self.forward(*input, **kwargs)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\models.py", line 21, in forward
x = self.act(self.fc1(x))
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\modules\module.py", line 550, in call
result = self.forward(*input, **kwargs)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\modules\linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Users\Ben\PycharmProjects\pytorch-jet-classify\venv\lib\site-packages\torch\nn\functional.py", line 1610, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat1' in call to _th_addmm

We need to understand:

  • Why this occurs, and whether we can fix it without calling current_model.double() or local_batch.float()
  • Whether calling current_model.double() affects the quantization of the network via Brevitas
  • Whether calling local_batch.float() / local_labels.float() affects the quantization of the network via Brevitas

The current implementation calls local_batch.float() / local_labels.float() within model training/evaluation, as the choice has a noticeable impact on saved model size (~45 KB with model.double() vs. ~18 KB with local_batch.float() for the unquantized model).
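The mismatch most likely arises because the dataset tensors arrive as float64 ("double") while the model parameters default to float32. A minimal NumPy sketch of the analogous cast (names illustrative; torch's local_batch.float() corresponds to .astype(np.float32)):

```python
import numpy as np

# Data loaded from HDF5/CSV typically arrives as float64 ("double"),
# while model weights default to float32 -- the mismatch behind the
# addmm RuntimeError above.
local_batch = np.random.rand(4, 16)          # float64 by default
assert local_batch.dtype == np.float64

# Casting the batch (rather than the model) keeps the weights in
# float32; model.double() would instead promote every weight to
# float64, roughly doubling the saved model size.
local_batch = local_batch.astype(np.float32)
assert local_batch.dtype == np.float32
```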

Information found so far (to look into further at some point):
https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/8
https://stackoverflow.com/questions/56741087/how-to-fix-runtimeerror-expected-object-of-scalar-type-float-but-got-scalar-typ/56741419
pytorch/pytorch#2138

Weight distribution plots sometimes show no difference between the unpruned model and the first prune iteration

The weight distribution plot for the unpruned model and the plot for the first pruning iteration can sometimes appear identical, even though other indicators show that pruning did take place and every subsequent plot shows it as well. This only seems to occur sometimes, specifically when training with the full dataset (vs. the single file and/or smaller epoch limit used for verification/testing).

For example:

[Figure: weight distributions, unpruned vs. first prune iteration]
[Figures: performance metrics for the same iterations, unpruned and pruned]

Additionally, in console output, when pruned parameters are counted, the model has clearly been pruned between iterations:

Pre-Pruning:
fc1.weight           | nonzeros =    1024 /    1024 (100.00%) | total_pruned =       0 | shape = (64, 16)
fc1.bias             | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
fc2.weight           | nonzeros =    2048 /    2048 (100.00%) | total_pruned =       0 | shape = (32, 64)
fc2.bias             | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
fc3.weight           | nonzeros =    1024 /    1024 (100.00%) | total_pruned =       0 | shape = (32, 32)
fc3.bias             | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
fc4.weight           | nonzeros =     160 /     160 (100.00%) | total_pruned =       0 | shape = (5, 32)
fc4.bias             | nonzeros =       5 /       5 (100.00%) | total_pruned =       0 | shape = (5,)
alive: 4389, pruned : 0, total: 4389, Compression rate :       1.00x  (  0.00% pruned)
Post-Pruning:
fc1.bias             | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
fc1.weight           | nonzeros =     941 /    1024 ( 91.89%) | total_pruned =      83 | shape = (64, 16)
fc2.bias             | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
fc2.weight           | nonzeros =    1827 /    2048 ( 89.21%) | total_pruned =     221 | shape = (32, 64)
fc3.bias             | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
fc3.weight           | nonzeros =     890 /    1024 ( 86.91%) | total_pruned =     134 | shape = (32, 32)
fc4.weight           | nonzeros =     160 /     160 (100.00%) | total_pruned =       0 | shape = (5, 32)
fc4.bias             | nonzeros =       5 /       5 (100.00%) | total_pruned =       0 | shape = (5,)
alive: 3951, pruned : 438, total: 4389, Compression rate :       1.11x  (  9.98% pruned)
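The per-layer counting behind this console output can be reproduced by counting nonzero entries per parameter tensor; a sketch, assuming params is a dict of name -> ndarray (a stand-in for iterating model.named_parameters()):

```python
import numpy as np

def count_nonzeros(params):
    # Report per-tensor nonzero counts, mirroring the console output above.
    # `params` maps parameter name -> weight array; in the real script this
    # would iterate over model.named_parameters() instead.
    total = alive = 0
    for name, w in params.items():
        nz = int(np.count_nonzero(w))
        total += w.size
        alive += nz
        print(f"{name:20s} | nonzeros = {nz:7d} / {w.size:7d} "
              f"({100.0 * nz / w.size:6.2f}%) | "
              f"total_pruned = {w.size - nz:7d} | shape = {w.shape}")
    pruned = total - alive
    print(f"alive: {alive}, pruned : {pruned}, total: {total}, "
          f"Compression rate : {total / alive:10.2f}x "
          f"({100.0 * pruned / total:6.2f}% pruned)")
    return alive, pruned, total
```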
