
aqt's Issues

Performance of MNIST example

Hi everyone - thanks for your work on this, very exciting!

I've been playing around a bit with the Flax MNIST example (https://github.com/google/aqt/blob/main/aqt/jax/v2/examples/mnist.py). I've benchmarked the training (as well as eval) on TPU v4 and v5 and can't see a performance improvement compared to bfloat16/float32 training. Both training and eval are around 4% slower when using int8 quantized operations.

Am I doing something wrong or is this expected? I could imagine that the overhead of converting from float32 to int8 and back is non-negligible at this small scale.
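For context, a minimal sketch of the kind of timing loop I have in mind, separating compilation from steady-state step time; train_step, state, and batch are placeholders for the example's objects, and I'm assuming a jitted train_step(state, batch) -> (state, metrics):

import time
import jax

def time_steps(train_step, state, batch, n=100):
  state, _ = train_step(state, batch)  # warm-up step triggers compilation
  jax.block_until_ready(state.params)
  start = time.perf_counter()
  for _ in range(n):
    state, _ = train_step(state, batch)
  jax.block_until_ready(state.params)
  return (time.perf_counter() - start) / n  # average seconds per step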

Refactor config/code classes to follow Flax.

For every piece of logic (like Numerics, Calibration, Tensor, DotGeneralRaw, DotGeneral), we should have a single class with dataclass fields that configure it and methods that execute the logic.
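A hypothetical sketch of the proposed shape, with illustrative field names (not the actual AQT classes):

import flax.struct
import jax.numpy as jnp

@flax.struct.dataclass
class IntNumerics:
  # dataclass fields hold the configuration ...
  bits: int
  preserve_zero: bool

  # ... and methods on the same class execute the logic.
  def quantize(self, x):
    # toy body: assumes x is already scaled into range
    bound = 2 ** (self.bits - 1) - 1
    return jnp.clip(jnp.round(x), -bound, bound)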

TypeError: dataclass() got an unexpected keyword argument 'frozen'

I ran into this problem when I tried to use the package.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/lwm/vision_chat.py", line 18, in <module>
    from lwm.vision_llama import VideoLLaMAConfig, FlaxVideoLLaMAForCausalLM
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/lwm/vision_llama.py", line 21, in <module>
    from lwm.llama import LLaMAConfig, LLAMA_STANDARD_CONFIGS, FlaxLLaMABlockCollection, RMSNorm
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/lwm/llama.py", line 34, in <module>
    import aqt.jax.v2.flax.aqt_flax as aqt
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/flax/aqt_flax.py", line 23, in <module>
    from aqt.jax.v2 import aqt_dot_general
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/aqt_dot_general.py", line 29, in <module>
    from aqt.jax.v2 import aqt_tensor
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/aqt_tensor.py", line 28, in <module>
    from aqt.jax.v2.numerics import no_numerics
  File "/mnt/g/Seto/GitHub/LargeWorldModel/LWM/venv/lib/python3.10/site-packages/aqt/jax/v2/numerics/no_numerics.py", line 23, in <module>
    class NoNumerics(numerics.AqtNumerics):
TypeError: dataclass() got an unexpected keyword argument 'frozen'

flax_e2e_model.py example fails

I'm getting this error when running python3 flax_e2e_model.py. I think it comes from the lhs quant mode being QuantMode.CONVERT, which pushes the lhs freezer to store the lhs scale during serving.

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 490, in <module>
    app.run(main)
  File "/var/tmp/aqt/aqt_env/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/var/tmp/aqt/aqt_env/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 485, in main
    loss = serve(state, weight_only=False)
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 440, in serve
    logits = serve_fn(
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 84, in __call__
    x = nn.Dense(features=256, dot_general_cls=aqt_dg)(x)
  File "/var/tmp/aqt/aqt_env/lib/python3.10/site-packages/flax/linen/linear.py", line 276, in __call__
    y = dot_general(
  File "/var/tmp/aqt/aqt/jax/v2/flax/aqt_flax.py", line 515, in __call__
    return ret_dg(
  File "/var/tmp/aqt/aqt/jax/v2/tiled_dot_general.py", line 527, in tiled_dot_general
    return tiled_dot_general_with_tiling_states(
  File "/var/tmp/aqt/aqt/jax/v2/tiled_dot_general.py", line 419, in tiled_dot_general_with_tiling_states
    out = dot_general(
  File "/var/tmp/aqt/aqt/jax/v2/flax/aqt_flax.py", line 459, in ret_dg
    lhs_freezer.set(out_lhs_qt)
  File "/var/tmp/aqt/aqt/jax/v2/flax/freezer.py", line 100, in set
    return self._get_or_set(inputs, is_set=True)
  File "/var/tmp/aqt/aqt/jax/v2/flax/freezer.py", line 63, in _get_or_set
    s.value = inputs
flax.errors.ModifyScopeVariableError: Cannot update variable "frozen" in "/Dense_0/AqtDotGeneral_0/qlhs" because collection "aqt" is immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)
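A possible workaround sketch, not verified against the example: Flax raises ModifyScopeVariableError when a variable is written into a collection that apply() did not mark as mutable, so if the lhs freezer has to write its quantized tensor at serving time, the "aqt" collection would need to be declared mutable. model, variables, and images below are placeholders:

logits, updated_state = model.apply(
    variables,
    images,
    mutable=['aqt'],  # allow the freezer to update the aqt collection
)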

Publish updated version?

The current version of this package on PyPI (0.0.9) doesn't include the fix in #56, leading to this ImportError:

  File "/opt/conda/lib/python3.7/site-packages/aqt/jax_legacy/jax/compute_cost_utils.py", line 27, in <module>
    from jax.interpreters import masking
ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/conda/lib/python3.7/site-packages/jax/interpreters/__init__.py)

Would you consider releasing a new version?

Port static quantization from AQTv1 to AQTv2

Right now in AQTv2, we have only dynamic quantization.
It is great for backprop quantization, but we can have much better inference quality (and performance) with static quantization.
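For concreteness, an illustrative (non-AQT) contrast between the two: dynamic quantization derives the scale from the live batch on every call, while static quantization reuses a scale fixed during offline calibration.

import jax.numpy as jnp

def quantize_int8(x, scale):
  return jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)

def dynamic_scale(x):
  return jnp.max(jnp.abs(x)) / 127.0  # recomputed on every call

STATIC_SCALE = 0.042  # hypothetical value fixed after offline calibration

# dynamic: quantize_int8(x, dynamic_scale(x))
# static:  quantize_int8(x, STATIC_SCALE)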

Broken `aqtp-0.1.1` package: missing `aqt` package prefix

The most recent aqtp-0.1.1 package is missing the aqt prefix in the installed packages:

$ pip install aqtp==0.1.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting aqtp==0.1.1
  Downloading aqtp-0.1.1-py3-none-any.whl (405 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 405.5/405.5 kB 7.9 MB/s eta 0:00:00
Installing collected packages: aqtp
Successfully installed aqtp-0.1.1


$ pip show -f aqtp
Name: aqtp
Version: 0.1.1
Summary: AQT: Accurate Quantized Training
Home-page: https://github.com/google/aqt
Author: Cerebra Catalyst team
Author-email: [email protected]
License: 
Location: /usr/local/lib/python3.10/dist-packages
Requires: 
Required-by: 
Files:
  aqtp-0.1.1.dist-info/INSTALLER
  aqtp-0.1.1.dist-info/LICENSE
  aqtp-0.1.1.dist-info/METADATA
  aqtp-0.1.1.dist-info/RECORD
  aqtp-0.1.1.dist-info/REQUESTED
  aqtp-0.1.1.dist-info/WHEEL
  aqtp-0.1.1.dist-info/top_level.txt
  common/__init__.py
  common/__pycache__/__init__.cpython-310.pyc
  common/__pycache__/aqt_common.cpython-310.pyc
  common/__pycache__/aqt_config.cpython-310.pyc
  common/__pycache__/aqt_config_schedule_test.cpython-310.pyc
  common/__pycache__/aqt_config_utils.cpython-310.pyc
  common/__pycache__/emulated_floating_points.cpython-310.pyc
  common/__pycache__/emulation_utils.cpython-310.pyc
  common/aqt_common.py
  common/aqt_config.py
  common/aqt_config_schedule_test.py
  common/aqt_config_utils.py
  common/emulated_floating_points.py
  common/emulation_utils.py
  jax/__init__.py
  jax/__pycache__/__init__.cpython-310.pyc
...

Note that these both work as expected:

  • pip install aqtp==0.1.0
  • pip install aqtp@git+https://github.com/google/aqt.git

The problem should be easy to fix:

  1. yank the faulty 0.1.1 version
  2. increase the version to 0.1.2
  3. create a new wheel
  4. verify the wheel installs correctly (I checked at HEAD and it seems to work as expected; not sure what exactly went wrong when uploading the 0.1.1 package)
  5. upload the new version to PyPI

Note that the faulty 0.1.1 package breaks all downstream users, e.g. google-research/vision_transformer#271

To avoid these problems in the future, it might be a good idea to set up an automated Python publish workflow (e.g. like this example)
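For what it's worth, the symptom above (subpackages installed without the aqt prefix) is what an unconstrained find_packages() produces when the build is rooted one level too deep. A hypothetical setup.py sketch (not the repo's actual one) that constrains discovery to the aqt prefix, so a flattened layout can't be shipped silently:

from setuptools import setup, find_packages

setup(
    name='aqtp',
    packages=find_packages(include=['aqt', 'aqt.*']),  # only ship aqt-prefixed packages
)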

Can AQT be used to calculate qk score?

I see that the sample code all covers the attention block or the MLP block. Can AQT int8 be used only for operations involving parameters? For example, can the QK score computation and the score * V computation also use AQT int8?
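If I understand the API from the other issues on this tracker, AqtEinsum is not limited to parameter matmuls, so an activation-activation contraction like the QK score could look roughly like this sketch (shapes and config are illustrative, and I haven't verified it):

import flax.linen as nn
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax

class QKScore(nn.Module):
  @nn.compact
  def __call__(self, q, k):
    # both operands are activations; no parameter is involved
    einsum = aqt_flax.AqtEinsum(aqt_config.dot_general_make(8, 8))
    # q, k: [batch, heads, seq, head_dim]
    return einsum('bhqd,bhkd->bhqk', q, k)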

Binary quantization?

The README mentions the Binarized Neural Machine Translation paper but does not really elaborate on how one can use AQT to implement one-bit weights and activations. For example, will the library take care of using LayerNorm as a replacement for the scaling factor, as mentioned in the paper?
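If dot_general_make accepts arbitrary bit widths (I have not verified that it goes down to one bit), a binary config might mirror the int8 examples from the other issues here:

from aqt.jax.v2 import config

# untested sketch: 1-bit lhs (activations) and 1-bit rhs (weights)
binary_dg = config.dot_general_make(1, 1)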

NormalFloat4 support

Could anyone add NormalFloat4 (NF4) support for kernel fusion? It seems effective in QLoRA.

Quantized Batch Normalization?

Hi:

Thanks for your great work and the open-sourced quantization code!
I read your ResNet-4bit and PokeBNN papers and am interested in some GPU acceleration research based on your models.
Here I am a bit confused about the data flow of the model.

If I understand correctly, your batch normalization operator is not quantized, which means it operates in bf16 during inference. So in a block with Conv+BN, the Conv layer outputs 8/4-bit data, which then goes through a bf16 BN operator, and the output is bf16 in the end. The bf16 activation data then goes to the next Conv layer, where it is first quantized to 8/4 bits before the quantized Conv operation.

If we consider two blocks of [Conv + BN], the data flow will be something like: bf16 -> int8/int4 -> (Conv) -> int8/int4 -> bf16 -> (BN) -> bf16 -> next block -> bf16.

I am not sure if I understand correctly. Do you have any comments?

Thanks a lot and thanks for your excellent project!

How to use it with jnp.einsum?

I tried to use this package (version 0.7.2), but I encountered an error with the following code.

import jax
import jax.numpy as jnp
from aqt.jax.v2 import config

dot_general = config.dot_general_make(8, 8)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 4))

print(jnp.einsum('ij,jk->ik', x, y))
print(jnp.einsum('ij,jk->ik', x, y, _dot_general=dot_general))
ValueError: Non-hashable static arguments are not supported. An error occurred during a call to '_einsum' while trying to hash an object of type <class 'aqt.jax.v2.aqt_dot_general.DotGeneral'>, DotGeneral(fwd=DotGeneralRaw(lhs=Tensor(use_fwd_q

How can I use it?
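One workaround sketch, based on the AqtEinsum usage that appears in other issues on this tracker (unverified): route the contraction through the Flax wrapper instead of passing the non-hashable DotGeneral object to jnp.einsum as a static argument.

import flax.linen as nn
from aqt.jax.v2 import config
from aqt.jax.v2.flax import aqt_flax

class QuantEinsum(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # AqtEinsum carries the DotGeneral config internally, so nothing
    # non-hashable is handed to jnp.einsum's static arguments.
    einsum = aqt_flax.AqtEinsum(config.dot_general_make(8, 8))
    return einsum('ij,jk->ik', x, y)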

ckpt of 8-bit ResNet-50 teacher model

Hello, thanks for the great project.

Can you share the ckpt file of the trained 8-bit resnet50 teacher model?

I think reproduction would be much more reliable if the teacher model's checkpoint were shared.

Thank you.

AqtEinsum 'not enough values to unpack'

Hi, I was working with AqtEinsum and in this particular case I got a ValueError, although the same operation works fine with jnp.einsum.

This works fine:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, [1, 2, 4])
w = jax.random.normal(key, [2, 4, 4])

z = jnp.einsum('...ij,hjk->...ik', x, w)

This one doesn't:

import flax.linen as nn
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax as aqt

class SimpleDense(nn.Module):
    features: int
    config = aqt_config.fully_quantized()

    @nn.compact
    def __call__(self, x):
        d = x.shape[-1]

        kernel = self.param('kernel', nn.initializers.normal(), (2, d, self.features))
        einsum = aqt.AqtEinsum(self.config)

        return einsum('...ij,hjk->...ik', x, kernel)

model = SimpleDense(features=4)
params = model.init(key, x)
ValueError                                Traceback (most recent call last)
<ipython-input-41-bf39ae22f96a> in <cell line: 2>()
      1 model = SimpleDense(features = 4)
----> 2 params = model.init(key, x)

    [... skipping hidden 9 frame]

<ipython-input-40-29c53684ec5e> in __call__(self, x)
     10         einsum = aqt.AqtEinsum(self.config)
     11 
---> 12         return einsum('...ij,hjk->...ik', x, kernel)

    [... skipping hidden 2 frame]

/usr/local/lib/python3.10/dist-packages/aqt/jax/v2/flax/aqt_flax.py in __call__(self, eqn, lhs_g, rhs_g)
    315     einsum = functools.partial(aqt_dot_general.einsum, eqn=eqn)
    316     a = jax.make_jaxpr(einsum)(lhs=lhs_in, rhs=rhs_in)
--> 317     [lhs_g_id, rhs_g_id] = a.eqns[0].invars
    318     [lhs_l_id, rhs_l_id] = a.jaxpr.invars
    319     not_swap = lhs_g_id == lhs_l_id and rhs_g_id == rhs_l_id

ValueError: not enough values to unpack (expected 2, got 1)

Also, if the einsum subscripts and the kernel dimensions are the following:

...
kernel = self.param('kernel', nn.initializers.normal(), (d, self.features))
einsum = aqt.AqtEinsum(self.config)

return einsum('...ij,jk->...ik', x, kernel)
...

the code works as expected without any errors.

For reference, I'm using AQT version 0.5.0 and random seed 42.

Does AQT's JAX v2 allow for arbitrary quantization?

Hi everyone! I would like to use AQT to quantize deep learning models that I would then run on my own hardware (FPGAs). Does AQT's JAX v2 support arbitrary bit widths (e.g., INT4)? I am asking because I have only seen examples using the INT8 data type.
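Given the 4-bit ResNet work mentioned in another issue here, INT4 seems to be in scope; based on the config API used elsewhere on this tracker, other bit widths appear to be configurable the same way as int8. An unverified sketch:

from aqt.jax.v2 import config

int4_dg = config.dot_general_make(4, 4)  # 4-bit lhs and rhs instead of 8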

generalized einsum or matmul api for pure jax

I want to use this package in the following way (pseudocode for the API I'd like):

x = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (4, 3), dtype=jnp.bfloat16)
w_q = quantize(w)  # might return a QTensor in this package
y = einsum('ij,jk->ik', x, w_q, lhs=jnp.bfloat16, rhs=QTensor)  # might return jnp.bfloat16 under a type-promotion rule

Is this possible today? I tried, but couldn't find a way. I'd like a solution with a fused kernel so that the overhead is minimized.
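To illustrate the semantics being asked for (plain JAX, not an AQT API): pre-quantize the weights once, then run the contraction in int8 with int32 accumulation and a float rescale back to bf16.

import jax
import jax.numpy as jnp

def prequantize(w):
  # quantize once, offline: int8 values plus a per-column scale
  scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
  return jnp.round(w / scale).astype(jnp.int8), scale

def matmul_q(x, w_q, w_scale):
  # dynamic int8 activations x static int8 weights, accumulated in int32
  x_scale = jnp.max(jnp.abs(x)) / 127.0
  x_q = jnp.round(x / x_scale).astype(jnp.int8)
  acc = jax.lax.dot(x_q, w_q, preferred_element_type=jnp.int32)
  return (acc * (x_scale * w_scale)).astype(jnp.bfloat16)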
