tclresearcheurope / ptdeco

ptdeco is a library for model optimization by decomposition built on top of PyTorch

License: Apache License 2.0

ptdeco

ptdeco is a library for model optimization by decomposition built on top of PyTorch.

Introduction

Currently, ptdeco implements the following methods:

  • lockd - method based on local knowledge distillation, tested on vision models (lockd = LOCal Knowledge Distillation)

  • falor - method based on low-rank decomposition of features, inspired by "Compressing Transformers: Features Are Low-Rank, but Weights Are Not!" by Yu Hao and Wu Jianxin (2023), tested on vision models (falor = Features Are LOw Rank)

  • dwain - iterative method based on low-rank decomposition of features, tested on Large Language Models (dwain = Decomposing Weights Algorithm - an Iterative techNique)

The lockd method requires a short knowledge distillation pretraining (about 10 ImageNet epochs) before the decomposition is performed. It can decompose linear layers and convolutions.

The falor method does not require pretraining. Model decomposition takes less than 1 GPU hour (depending on model size and parameters). It can decompose linear layers and 1x1 convolutions.

The dwain method does not require pretraining. It can decompose linear layers and 1x1 convolutions.
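To make "decomposing a linear layer" concrete, the sketch below replaces a single `torch.nn.Linear` with two smaller linear layers. It uses a plain truncated SVD of the weight matrix, which is a simpler stand-in and not the feature-based procedure that falor or dwain implement. The helper name `decompose_linear_svd` and the chosen shapes are illustrative assumptions, not part of ptdeco.

```python
import torch


def decompose_linear_svd(layer: torch.nn.Linear, rank: int) -> torch.nn.Sequential:
    """Replace one Linear layer with two smaller ones via truncated SVD.

    Illustrative only: ptdeco's methods decompose based on features,
    not on a direct SVD of the weights.
    """
    weight = layer.weight.data  # shape: (out_features, in_features)
    u, s, vh = torch.linalg.svd(weight, full_matrices=False)
    # Keep only the `rank` largest singular values and their vectors.
    u_r = u[:, :rank] * s[:rank]  # shape: (out_features, rank)
    vh_r = vh[:rank, :]           # shape: (rank, in_features)

    has_bias = layer.bias is not None
    first = torch.nn.Linear(layer.in_features, rank, bias=False)
    second = torch.nn.Linear(rank, layer.out_features, bias=has_bias)
    with torch.no_grad():
        first.weight.copy_(vh_r)
        second.weight.copy_(u_r)
        if has_bias:
            second.bias.copy_(layer.bias)
    return torch.nn.Sequential(first, second)


def count_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


torch.manual_seed(0)
# Build a layer whose weight has exact rank 4, so a rank-4 factorization is lossless.
original = torch.nn.Linear(32, 16)
with torch.no_grad():
    original.weight.copy_(torch.randn(16, 4) @ torch.randn(4, 32))

decomposed = decompose_linear_svd(original, rank=4)

x = torch.randn(8, 32)
max_abs_diff = (original(x) - decomposed(x)).abs().max().item()
params_before = count_params(original)
params_after = count_params(decomposed)
print(f"parameters: {params_before} -> {params_after}, max abs diff: {max_abs_diff:.2e}")
```

The parameter count drops because a 16x32 matrix is stored as a 4x32 factor and a 16x4 factor. On real models the weights are rarely exactly low-rank, which is why feature-based methods such as those listed above are of interest.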

Installation

pip install ptdeco

Saving and loading a decomposed model

Saving a decomposed model

Decomposition produces a decompose_config dictionary. Serialize it, e.g. to JSON, so that you can later recreate the structure of the decomposed model. In addition, save the state_dict so that you can recover the weights of the decomposed model. The code below illustrates the procedure:

import json
import pathlib

import torch

# Your decomposition code (produces `model` and `decompose_config`)

output_path = pathlib.Path("YOUR/CHECKPOINT/DIRECTORY")

# Save the structure of the decomposed model
out_decompose_config_path = output_path / "decompose_config.json"
with open(out_decompose_config_path, "wt") as f:
    json.dump(decompose_config, f)

# Save the weights of the decomposed model
out_decompose_state_dict_path = output_path / "decompose_state_dict.pt"
torch.save(model.state_dict(), out_decompose_state_dict_path)
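The saving steps can also be wrapped in a small helper that creates the checkpoint directory if it does not exist yet. This is a minimal sketch: `save_decomposed` is a hypothetical name, and the config used in the demo is a placeholder dictionary, since the real decompose_config comes from ptdeco and its exact contents are not shown here.

```python
import json
import pathlib
import tempfile

import torch


def save_decomposed(model, decompose_config, output_dir):
    """Write the decomposition config (JSON) and the weights side by side."""
    output_path = pathlib.Path(output_dir)
    # Create the checkpoint directory, including missing parents.
    output_path.mkdir(parents=True, exist_ok=True)

    config_path = output_path / "decompose_config.json"
    with open(config_path, "wt") as f:
        json.dump(decompose_config, f)

    state_dict_path = output_path / "decompose_state_dict.pt"
    torch.save(model.state_dict(), state_dict_path)
    return config_path, state_dict_path


with tempfile.TemporaryDirectory() as tmp_dir:
    model = torch.nn.Linear(4, 2)
    # Placeholder only: a real config is returned by the decomposition step.
    placeholder_config = {"example_module": {"note": "stand-in, not real ptdeco output"}}

    config_path, state_dict_path = save_decomposed(
        model, placeholder_config, pathlib.Path(tmp_dir) / "checkpoint"
    )
    files_written = config_path.exists() and state_dict_path.exists()

    # The config must survive a JSON round trip to be usable at load time.
    with open(config_path, "rt") as f:
        reloaded_config = json.load(f)
```

Checking the JSON round trip right after saving is a cheap way to catch values in the config that are not JSON-serializable.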

Loading a decomposed model

To load the model, first recreate the original model. Next, load and apply the decompose_config. Finally, load the state_dict. The state dict "fits" the decomposed model, not the original one, so this must be the last step. The code below illustrates the procedure:

import json
import pathlib

import torch

import ptdeco

model = ...   # Build the original model
device = ...  # Specify the device the original model uses

output_path = pathlib.Path("YOUR/CHECKPOINT/DIRECTORY")

with open(output_path / "decompose_config.json", "rt") as f:
    decompose_config = json.load(f)

# Recreate the decomposed structure before loading the weights
ptdeco.utils.apply_decompose_config_in_place(model, decompose_config)

# Map the saved tensors onto the target device while loading
sd = torch.load(output_path / "decompose_state_dict.pt", map_location=device)

model.load_state_dict(sd)

# Now `model` is decomposed and contains appropriate weights
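The ordering requirement can be seen with plain PyTorch. In the sketch below, a nested `Sequential` of two small linear layers stands in for a decomposed layer; this structure is an assumption made for illustration, and the modules ptdeco actually inserts may look different. Loading the decomposed state dict into the original structure fails, while loading it into a matching structure succeeds.

```python
import torch


def build_original() -> torch.nn.Sequential:
    return torch.nn.Sequential(torch.nn.Linear(8, 4))


def build_decomposed_stand_in() -> torch.nn.Sequential:
    # Hypothetical decomposed structure: one Linear replaced by two smaller ones.
    return torch.nn.Sequential(
        torch.nn.Sequential(
            torch.nn.Linear(8, 2, bias=False),
            torch.nn.Linear(2, 4),
        )
    )


decomposed_sd = build_decomposed_stand_in().state_dict()

# Wrong order: the original structure does not have the decomposed parameters.
original = build_original()
try:
    original.load_state_dict(decomposed_sd)
    load_failed = False
except RuntimeError as err:
    load_failed = True
    print(f"Loading into the original structure failed: {type(err).__name__}")

# Right order: recreate the decomposed structure first, then load the weights.
rebuilt = build_decomposed_stand_in()
result = rebuilt.load_state_dict(decomposed_sd)
```

With the default `strict=True`, `load_state_dict` raises a `RuntimeError` on missing or unexpected keys, which is why the decompose_config has to be applied before the weights are loaded.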

Contributors

michal-lopuszynski-tcl
