
jax-triton's Introduction

jax-triton


The jax-triton repository contains integrations between JAX and Triton.

Documentation can be found at https://jax-ml.github.io/jax-triton/.

This is not an officially supported Google product.

Quickstart

The main function of interest is jax_triton.triton_call for applying Triton functions to JAX arrays, including inside jax.jit-compiled functions. For example, we can define a kernel from the Triton tutorial:

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)

Then we can apply it to JAX arrays using jax_triton.triton_call:

import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
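
Both calls should print the same result, [ 8 10 12 14 16 18 20 22].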

See the examples directory, especially fused_attention.py and the fused attention ipynb.

Installation

$ pip install jax-triton

Make sure you have a CUDA-compatible jaxlib installed. For example, you could run:

$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Installation at HEAD

JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:

$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'

This should install compatible versions of JAX and Triton.

JAX-Triton does depend on jaxlib, but jaxlib is usually a more stable dependency, so you might be able to get away with using a recent jaxlib release:

$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]

If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly. To install a new jaxlib, you can find a link to a CUDA 11 nightly or CUDA 12 nightly. Then install it via:

$ pip install 'jaxlib @ <link to nightly>'

or to install CUDA via pip automatically, you can do:

$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'

Development

To develop jax-triton, you can clone the repo with:

$ git clone https://github.com/jax-ml/jax-triton.git

and do an editable install with:

$ cd jax-triton
$ pip install -e .

To run the jax-triton tests, you'll need pytest:

$ pip install pytest
$ pytest tests/

jax-triton's People

Contributors

apaszke, arthurbrussee, bchetioui, brianwa84, chr1sj0nes, chsigg, gflegar, giorgio-arena, gnecula, hawkinsp, hbq1, mattjj, moerafaat, sharadmv, superbobry, tonywu95, wangkuiyi, yashk2810, zhangqiaorjc

jax-triton's Issues

Hopper Support?

Hey, I am running on the latest main (d258f8f), using:
"jax @ git+https://github.com/google/jax@d872812a359a3bafcfdeba1fcdb874ec77c209db",
"triton @ git+https://github.com/openai/triton@3452615d795bc0c69a189e41f1e775904e5659be#subdirectory=python"
When running on a Hopper node, I get the following error:

    test_out = test_fn(**fn_kwargs, **extra_test_kwargs)    # My call to triton                                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback                                                                             
    return fun(*args, **kwargs)                                                                                                                                                                       
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/custom_derivatives.py", line 620, in __call__                                                                                                
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,                                                                                                                                   
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/custom_derivatives.py", line 770, in bind
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 840, in process_custom_vjp_call
    return fun.call_wrapped(*tracers)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 252, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 165, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2596, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1228, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas fatal   : PTX with .target 'sm_90a' cannot be compiled for architecture 'sm_90'
; current tracing scope: custom-call.7; current profiling annotation: XlaModule:#hlo_module=jit_attention,program_id=22#.

Is Hopper supported?

Is scan supported in pallas?

I have kernel code that contains jax.lax.map. It runs fine with interpret=True; however, lowering to Triton fails with the following error:

E         jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:
E           a:i32[6] = scan[
E           jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let
E               d:i32[] = mul c 64
E               e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E               f:i32[64] = add e d
E               g:bool[64] = lt f 384
E               h:i32[64] <- b[d:d+64]
E               i:i32[] = reduce_min[axes=(0,)] h
E             in (i,) }
E           length=6
E           linear=(False, False)
E           num_carry=0
E           num_consts=1
E           reverse=False
E           unroll=1
E         ] j k

Is it because scan is not supported, or is there some other problem? Happy to provide more details if necessary.
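
For reference, here is a minimal sketch along the lines of my kernel, reconstructed from the jaxpr above (the API usage, such as pl.load, pl.dslice, and ref[...] assignment, is illustrative; the real kernel is more involved):

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def block_min_kernel(x_ref, o_ref):
  """Takes the min of each 64-element block of a 384-element input."""
  def body(i):
    # Dynamic slice of one block; this is the b[d:d+64] read in the jaxpr.
    block = pl.load(x_ref, (pl.dslice(i * 64, 64),))
    return jnp.min(block)
  # jax.lax.map lowers to the scan shown in the error above.
  o_ref[...] = jax.lax.map(body, jnp.arange(6))

x = jnp.arange(384, dtype=jnp.int32)
out_shape = jax.ShapeDtypeStruct((6,), jnp.int32)
# interpret=True evaluates the kernel with JAX semantics and works;
# interpret=False attempts the Triton lowering that raises the exception above.
print(pl.pallas_call(block_min_kernel, out_shape=out_shape, interpret=True)(x))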

fail to install jax_triton through pip

I just run the command "pip install jax_triton" and get the following error:

Building wheels for collected packages: jax-triton
  Building wheel for jax-triton (pyproject.toml) ... error
  error: subprocess-exited-with-error
  
  × Building wheel for jax-triton (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [105 lines of output]
      running bdist_wheel
      running build
      running build_py
      creating build
      creating build/lib.linux-x86_64-cpython-310
      creating build/lib.linux-x86_64-cpython-310/jax_triton
      copying jax_triton/triton_lib.py -> build/lib.linux-x86_64-cpython-310/jax_triton
      copying jax_triton/__init__.py -> build/lib.linux-x86_64-cpython-310/jax_triton
      creating build/lib.linux-x86_64-cpython-310/jax_triton/pallas
      copying jax_triton/pallas/lowering.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas
      copying jax_triton/pallas/core.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas
      copying jax_triton/pallas/__init__.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas
      copying jax_triton/pallas/pallas_call.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas
      copying jax_triton/pallas/primitives.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas
      creating build/lib.linux-x86_64-cpython-310/jax_triton/pallas/ops
      copying jax_triton/pallas/ops/attention.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas/ops
      copying jax_triton/pallas/ops/__init__.py -> build/lib.linux-x86_64-cpython-310/jax_triton/pallas/ops
      running egg_info
      writing jax_triton.egg-info/PKG-INFO
      writing dependency_links to jax_triton.egg-info/dependency_links.txt
      writing requirements to jax_triton.egg-info/requires.txt
      writing top-level names to jax_triton.egg-info/top_level.txt
      reading manifest file 'jax_triton.egg-info/SOURCES.txt'
      adding license file 'LICENSE'
      writing manifest file 'jax_triton.egg-info/SOURCES.txt'
      /tmp/pip-build-env-jycc8uf9/overlay/lib/python3.10/site-packages/setuptools/command/build_py.py:202: SetuptoolsDeprecationWarning:     Installing 'jax_triton.experimental.fusion' as data is deprecated, please list it in `packages`.
          !!
      
      
          ############################
          # Package would be ignored #
          ############################
          Python recognizes 'jax_triton.experimental.fusion' as an importable package,
          but it is not listed in the `packages` configuration of setuptools.
      
          'jax_triton.experimental.fusion' has been automatically added to the distribution only
          because it may contain data files, but this behavior is likely to change
          in future versions of setuptools (and therefore is considered deprecated).
      
          Please make sure that 'jax_triton.experimental.fusion' is included as a package by using
          the `packages` configuration field or the proper discovery methods
          (for example by using `find_namespace_packages(...)`/`find_namespace:`
          instead of `find_packages(...)`/`find:`).
      
          You can read more about "package discovery" and "data files" on setuptools
          documentation page.
      
      
      !!
      
        check.warn(importable)
      creating build/lib.linux-x86_64-cpython-310/jax_triton/experimental
      creating build/lib.linux-x86_64-cpython-310/jax_triton/experimental/fusion
      copying jax_triton/experimental/fusion/__init__.py -> build/lib.linux-x86_64-cpython-310/jax_triton/experimental/fusion
      copying jax_triton/experimental/fusion/fusion.py -> build/lib.linux-x86_64-cpython-310/jax_triton/experimental/fusion
      copying jax_triton/experimental/fusion/jaxpr_rewriter.py -> build/lib.linux-x86_64-cpython-310/jax_triton/experimental/fusion
      copying jax_triton/experimental/fusion/lowering.py -> build/lib.linux-x86_64-cpython-310/jax_triton/experimental/fusion
      running build_ext
      building 'jax_triton.triton_kernel_call_lib' extension
      creating build/temp.linux-x86_64-cpython-310
      creating build/temp.linux-x86_64-cpython-310/lib
      x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/usr/local/cuda/include -I/tmp/pip-build-env-jycc8uf9/overlay/lib/python3.10/site-packages/pybind11/include -I/usr/local/google/home/johnqiangzhang/Documents/test_py_env/include -I/usr/include/python3.10 -c lib/triton_kernel_call.cc -o build/temp.linux-x86_64-cpython-310/lib/triton_kernel_call.o -std=c++17 -v
      Using built-in specs.
      COLLECT_GCC=/usr/bin/x86_64-linux-gnu-gcc
      OFFLOAD_TARGET_NAMES=nvptx-none:amdgcn-amdhsa
      OFFLOAD_TARGET_DEFAULT=1
      Target: x86_64-linux-gnu
      Configured with: ../src/configure -v --with-pkgversion='Debian 12.2.0-10' --with-bugurl=file:///usr/share/doc/gcc-12/README.Bugs --enable-languages=c,ada,c++,go,d,fortran,objc,obj-c++,m2 --prefix=/usr --with-gcc-major-version-only --program-suffix=-12 --program-prefix=x86_64-linux-gnu- --enable-shared --enable-linker-build-id --libexecdir=/usr/lib --without-included-gettext --enable-threads=posix --libdir=/usr/lib --enable-nls --enable-clocale=gnu --enable-libstdcxx-debug --enable-libstdcxx-time=yes --with-default-libstdcxx-abi=new --enable-gnu-unique-object --disable-vtable-verify --enable-plugin --enable-default-pie --with-system-zlib --enable-libphobos-checking=release --with-target-system-zlib=auto --enable-objc-gc=auto --enable-multiarch --disable-werror --enable-cet --with-arch-32=i686 --with-abi=m64 --with-multilib-list=m32,m64,mx32 --enable-multilib --with-tune=generic --enable-offload-targets=nvptx-none=/build/gcc-12-hWCYKv/gcc-12-12.2.0/debian/tmp-nvptx/usr,amdgcn-amdhsa=/build/gcc-12-hWCYKv/gcc-12-12.2.0/debian/tmp-gcn/usr --enable-offload-defaulted --without-cuda-driver --enable-checking=release --build=x86_64-linux-gnu --host=x86_64-linux-gnu --target=x86_64-linux-gnu
      Thread model: posix
      Supported LTO compression algorithms: zlib zstd
      gcc version 12.2.0 (Debian 12.2.0-10)
      COLLECT_GCC_OPTIONS='-Wno-unused-result' '-Wsign-compare' '-D' 'NDEBUG' '-g' '-O2' '-Wall' '-g' '-fstack-protector-strong' '-Wformat=1' '-Werror=format-security' '-g' '-fwrapv' '-O2' '-fPIC' '-I' '/usr/local/cuda/include' '-I' '/tmp/pip-build-env-jycc8uf9/overlay/lib/python3.10/site-packages/pybind11/include' '-I' '/usr/local/google/home/johnqiangzhang/Documents/test_py_env/include' '-I' '/usr/include/python3.10' '-c' '-o' 'build/temp.linux-x86_64-cpython-310/lib/triton_kernel_call.o' '-std=c++17' '-v' '-mtune=generic' '-march=x86-64' '-dumpdir' 'build/temp.linux-x86_64-cpython-310/lib/'
       /usr/lib/gcc/x86_64-linux-gnu/12/cc1plus -quiet -v -I /usr/local/cuda/include -I /tmp/pip-build-env-jycc8uf9/overlay/lib/python3.10/site-packages/pybind11/include -I /usr/local/google/home/johnqiangzhang/Documents/test_py_env/include -I /usr/include/python3.10 -imultiarch x86_64-linux-gnu -D_GNU_SOURCE -D NDEBUG lib/triton_kernel_call.cc -quiet -dumpdir build/temp.linux-x86_64-cpython-310/lib/ -dumpbase triton_kernel_call.cc -dumpbase-ext .cc -mtune=generic -march=x86-64 -g -g -g -O2 -O2 -Wno-unused-result -Wsign-compare -Wall -Wformat=1 -Werror=format-security -std=c++17 -version -fstack-protector-strong -fwrapv -fPIC -fasynchronous-unwind-tables -o /tmp/ccFv4c8g.s
      GNU C++17 (Debian 12.2.0-10) version 12.2.0 (x86_64-linux-gnu)
          compiled by GNU C version 12.2.0, GMP version 6.2.1, MPFR version 4.1.0, MPC version 1.2.1, isl version isl-0.25-GMP
      
      GGC heuristics: --param ggc-min-expand=100 --param ggc-min-heapsize=131072
      ignoring duplicate directory "/usr/include/x86_64-linux-gnu/c++/12"
      ignoring nonexistent directory "/usr/local/include/x86_64-linux-gnu"
      ignoring nonexistent directory "/usr/lib/gcc/x86_64-linux-gnu/12/include-fixed"
      ignoring nonexistent directory "/usr/lib/gcc/x86_64-linux-gnu/12/../../../../x86_64-linux-gnu/include"
      ignoring nonexistent directory "/usr/local/cuda/include"
      #include "..." search starts here:
      #include <...> search starts here:
       /tmp/pip-build-env-jycc8uf9/overlay/lib/python3.10/site-packages/pybind11/include
       /usr/local/google/home/johnqiangzhang/Documents/test_py_env/include
       /usr/include/python3.10
       /usr/include/c++/12
       /usr/include/x86_64-linux-gnu/c++/12
       /usr/include/c++/12/backward
       /usr/lib/gcc/x86_64-linux-gnu/12/include
       /usr/local/include
       /usr/include/x86_64-linux-gnu
       /usr/include
      End of search list.
      GNU C++17 (Debian 12.2.0-10) version 12.2.0 (x86_64-linux-gnu)
          compiled by GNU C version 12.2.0, GMP version 6.2.1, MPFR version 4.1.0, MPC version 1.2.1, isl version isl-0.25-GMP
      
      GGC heuristics: --param ggc-min-expand=100 --param ggc-min-heapsize=131072
      Compiler executable checksum: f18a0f32bd70b25f70021e93ee28005d
      lib/triton_kernel_call.cc:28:10: fatal error: cuda.h: No such file or directory
         28 | #include "cuda.h"
            |          ^~~~~~~~
      compilation terminated.
      error: command '/usr/bin/x86_64-linux-gnu-gcc' failed with exit code 1
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for jax-triton
Failed to build jax-triton
ERROR: Could not build wheels for jax-triton, which is required to install pyproject.toml-based projects

Any comments? It looks like something related to CUDA.
Thanks

Add switch and cond lowerings for pallas

I'd like to evaluate the performance improvement of pushing control flow into the kernel for a workload with some lax.switch statements. Can we get support for lowering these primitives?
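
For concreteness, a hypothetical sketch of the kind of kernel I have in mind (the branch functions, names, and shapes are just illustrative):

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def switch_kernel(branch_ref, x_ref, o_ref):
  x = x_ref[...]
  branch = branch_ref[...]  # scalar selecting which branch to apply
  # lax.switch inside the kernel body is what currently lacks a Triton
  # lowering rule; interpret=True works, but lowering does not.
  o_ref[...] = jax.lax.switch(branch, [jnp.sin, jnp.cos, jnp.tanh], x)

x = jnp.arange(8, dtype=jnp.float32)
branch = jnp.int32(1)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
print(pl.pallas_call(switch_kernel, out_shape=out_shape, interpret=True)(branch, x))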

Add support for runtime scalar kernel inputs (i.e. non-constexpr).

Triton supports runtime scalar inputs, but these are not exposed by jax-triton. For example, in Triton's matmul example, m, n, and k are all runtime inputs, but in the jax-triton example they are constexpr. This forces a recompile, even though there is negligible runtime performance benefit.
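
A small sketch of the recompilation concern (the kernel below is illustrative, not the matmul example itself): because n is declared tl.constexpr and passed as a keyword argument, every distinct value of n specializes and recompiles the kernel, whereas a true runtime scalar would not.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, o_ptr, n: tl.constexpr, block_size: tl.constexpr):
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  mask = offsets < n
  tl.store(o_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

def copy(x):
  return jt.triton_call(
      x,
      kernel=copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(triton.cdiv(x.size, 8),),
      n=x.size,
      block_size=8)

copy(jnp.arange(8.0))   # compiles a kernel specialized to n=8
copy(jnp.arange(16.0))  # recompiles for n=16, even though only n changed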

Incorrect dtype cast in Matmul Example?

In the example, jax-triton/examples/matmul.py on line 96,

c = accumulator.to(tl.float16)

However, on line 120, the dtype of the output is set to the same as the input, i.e. jnp.float32:

out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=a.dtype)

Surely c should just be set to accumulator (which is float32).

Next release?

The current version of the pip package is incompatible with jax>0.4.5 due to the removal of the ShapedArrayRef type from jax._src.state.types. Will the pip package be updated soon to allow compatibility with the latest version of jax again?

Example notebook with attention isn't working

I've tried running the example notebook using jax_triton and jaxlib installed from HEAD. Unfortunately, it doesn't seem to work: running the cell with test_triton_jax(2, 32, 2048, 64) hangs indefinitely.

  • Should the example work?
  • Is there a tested version of flash attention that is guaranteed to work with the latest jax_triton?

Thanks!

Problems Running Jax-Triton with an Nvidia 4090

Running the quickstart example on an Nvidia 4090 with the suggested Triton version (2.0.0.dev20221202), you receive the following error:

RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/filePpgwy4, line 6; error   : PTX .version 7.4 does not support .target sm_89
ptxas fatal   : Ptx assembly aborted due to errors

Upgrading to the latest development version of Triton, it is possible to run PyTorch-based examples, but the jax-triton example results in the following error:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 135, in <module>
    print(add(x_val, y_val))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 124, in add
    return jt.triton_call(
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 531, in triton_call
    out_flat = triton_kernel_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AttributeError: module 'triton.compiler' has no attribute '_compile'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 135, in <module>
    print(add(x_val, y_val))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/main.py", line 124, in add
    return jt.triton_call(
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 531, in triton_call
    out_flat = triton_kernel_call_p.bind(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/core.py", line 807, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 122, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/util.py", line 254, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/util.py", line 247, in cached
    return f(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 201, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 353, in _xla_callable_uncached
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/dispatch.py", line 343, in sharded_lowering
    return pxla.lower_sharding_computation(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 3082, in lower_sharding_computation
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 742, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1044, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/adam/anaconda3/envs/triton/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1179, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 320, in triton_kernel_call_lowering
    kernel, specialization = get_or_create_triton_kernel(
  File "/home/adam/Downloads/jax-triton/jax_triton/triton_lib.py", line 173, in get_or_create_triton_kernel
    asm, shared_mem, name = tc._compile(
AttributeError: module 'triton.compiler' has no attribute '_compile'. Did you mean: 'compile'?

AttributeError: module 'jaxlib.gpu_triton' has no attribute 'TritonKernel'

Following the instructions,

https://jax-ml.github.io/jax-triton/#installation-at-head

i.e. pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'

which installs jax=0.4.11, jaxlib=0.4.11, triton=2.1.0, jax-triton=0.1.4,

Running on Ubuntu 22.04, Python 3.9.16, with CUDA 12.1

I get the following error

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 15, in <module>
    from jax_triton import pallas as pl
  File "/home/adam/anaconda3/envs/triton/lib/python3.9/site-packages/jax_triton/__init__.py", line 19, in <module>
    from jax_triton.triton_lib import triton_call
  File "/home/adam/anaconda3/envs/triton/lib/python3.9/site-packages/jax_triton/triton_lib.py", line 185, in <module>
    ) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
AttributeError: module 'jaxlib.gpu_triton' has no attribute 'TritonKernel'

I also tried with jax/jaxlib=0.4.12, with the same result.

I have also tried a new environment running CUDA 11.8 with everything else the same, and hit the same problem.

Pallas Broken after the update making JAX-Triton calls serializable

The new update has broken Pallas again with the error shown below. I have tried updating jax/jaxlib to HEAD and the issue persists.

I0624 11:43:47.866600 139777320678464 xla_bridge.py:568] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
I0624 11:43:47.866749 139777320678464 xla_bridge.py:568] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0624 11:43:47.866775 139777320678464 xla_bridge.py:568] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(4, 8)
Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 231, in <module>
    app.run(main)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 226, in main
    print(pl.pallas_call(kernel1, out_shape=out_shape, grid=grid)(x, y))
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/pallas_call.py", line 352, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. jaxlib.cuda._triton.TritonKernel(arg0: str, arg1: str, arg2: int, arg3: int)

pallas_error.txt

triton dtype mapping for float inputs

I have a question about this error:

File /usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1069, in ast_to_ttir(fn, signature, specialization, constants, debug)
   1067 all_constants = constants.copy()
   1068 all_constants.update(new_constants)
-> 1069 arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
   1071 prototype = language.function_type([], arg_types)
   1072 generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
   1073                           function_name=function_name, attributes=new_attrs,
   1074                           is_kernel=True, debug=debug)

File /usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1069, in <listcomp>(.0)
   1067 all_constants = constants.copy()
   1068 all_constants.update(new_constants)
-> 1069 arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]
   1071 prototype = language.function_type([], arg_types)
   1072 generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
   1073                           function_name=function_name, attributes=new_attrs,
   1074                           is_kernel=True, debug=debug)

File /usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1036, in str_to_ty(name)
   1017     return language.pointer_type(ty)
   1018 tys = {
   1019     "fp8e5": language.float8e5,
   1020     "fp8e4": language.float8e4,
   (...)
   1034     "B": language.int1,
   1035 }
-> 1036 return tys[name]

KeyError: 'f'

Floats are mapped to 'f' in jax-triton's dtype-to-signature mapping, but 'f' is not a key in Triton's str_to_ty table (see the traceback above).

Additionally, adding tys['f'] = language.float32 results in

TypeError: create_scalar_parameter(): incompatible function arguments. The following argument types are supported:
    1. (arg0: bool, arg1: str) -> jaxlib.cuda._triton.TritonParameter
    2. (arg0: int, arg1: str) -> jaxlib.cuda._triton.TritonParameter

Attached is a minimal repro:
add.txt
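
If I understand the failure correctly, it boils down to passing a Python float as a runtime scalar kernel argument, which is what produces the 'f' entry in the signature. A simplified sketch of that pattern (not the attached add.txt itself; the names here are made up):

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, scale, o_ptr, block_size: tl.constexpr):
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  tl.store(o_ptr + offsets, x * scale)

def scale(x, s):
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  # The float scalar s is forwarded as a kernel argument, so its dtype ends
  # up in the Triton signature and hits the KeyError above.
  return jt.triton_call(
      x, s,
      kernel=scale_kernel,
      out_shape=out_shape,
      grid=(x.size // 8,),
      block_size=8)

scale(jnp.arange(8, dtype=jnp.float32), 2.0)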

Error Using Matrices For the Pallas Template Example

The Pallas templating example works fine with an input, x, of shape (1,). However, when attempting to generalise it to matrix inputs (as shown in the code below), it produces the following error:

    x = pl.load(x_ref, ())
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/primitives.py", line 477, in load
    idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/primitives.py", line 278, in from_indices_shape
    else i for i, s in zip(indices, shape))
ValueError: safe_zip() argument 2 is longer than argument 1

Minimal example:

from jax_triton import pallas as pl
import jax
import jax.numpy as jnp
import jax.random as jnr

def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

def main(unused_argv):
  # x = jnp.array(1.)

  key = jnr.PRNGKey(0)
  key, subkey = jnr.split(key)

  x = jnr.normal(key, (100, 100))
  y = jnr.normal(subkey, (100, 100))

  m, k = x.shape
  n, _ = y.shape
  out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x.dtype)

  block_size_m = 10
  block_size_n = 10
  grid = (m // block_size_m * n // block_size_n,)

  print(pl.pallas_call(kernel1, out_shape=out_shape, grid=grid)(x, y))
  print(pl.pallas_call(kernel2, out_shape=out_shape, grid=grid)(x, y))

if __name__ == "__main__":
  from absl import app
  app.run(main)

The full trace is as follows:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 53, in <module>
    app.run(main)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 48, in main
    print(pl.pallas_call(kernel1, out_shape=out_shape, grid=grid)(x, y))
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/api.py", line 300, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/pjit.py", line 499, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/pjit.py", line 961, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/pjit.py", line 914, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/pallas_call.py", line 348, in wrapped
    jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/pallas_call.py", line 308, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 22, in add
    x = pl.load(x_ref, ())
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/primitives.py", line 477, in load
    idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/primitives.py", line 278, in from_indices_shape
    else i for i, s in zip(indices, shape))
jax._src.traceback_util.UnfilteredStackTrace: ValueError: safe_zip() argument 2 is longer than argument 1
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 48, in main
    print(pl.pallas_call(kernel1, out_shape=out_shape, grid=grid)(x, y))
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/pallas_call.py", line 348, in wrapped
    jaxpr, consts, _ = _initial_style_open_jaxpr(f, jaxpr_in_tree,
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/pallas_call.py", line 308, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 22, in add
    x = pl.load(x_ref, ())
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/primitives.py", line 477, in load
    idx = NDIndexer.from_indices_shape(idx, x_ref.shape)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.10/site-packages/jax_triton/pallas/primitives.py", line 278, in from_indices_shape
    else i for i, s in zip(indices, shape))
ValueError: safe_zip() argument 2 is longer than argument 1
python-BaseException

ptxas unsupported version: Is CUDA 11.8 support broken in HEAD install?

Hey everyone, thanks for developing this library. I'd like to use block sparse matmul with jax, and this project seems to deliver just what we need 👍 Yet, I am having trouble getting examples/pallas/blocksparse_matmul.py to run. When installing from HEAD, I run into compatibility problems with ptxas. Help with this would be much appreciated.

As far as I understand, ptxas version 7.8 is shipped with CUDA 11.8, and 8.0 with CUDA 12.0. As noted below, I installed jaxlib with local CUDA 11.8. Considering the traceback below, I am wondering whether jax-triton requires CUDA 12 in its current form. If so, I would be happy to get a recommendation for jax and jax-triton commits to install from source.

Traceback

2023-07-27 13:17:05.422836: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors

2023-07-27 13:17:05.422953: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2537] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.
Traceback (most recent call last):
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 227, in <module>
    app.run(main)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 201, in main
    sdd_matmul(x, y, bn=bn, debug=True).block_until_ready()
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 2578, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 382, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/core.py", line 814, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1223, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1207, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 1163, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1344, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 227, in <module>
    app.run(main)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/filesystem/my/workspace/pythonenv/jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/filesystem/my/workspace/Code/BlockSparse/jax_triton_blocksparse.py", line 201, in main
    sdd_matmul(x, y, bn=bn, debug=True).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-taurusi8032-f5aaecd1-56106-601761aa90836, line 5; fatal   : Unsupported .version 8.0; current version is '7.8'
ptxas fatal   : Ptx assembly aborted due to errors
; current tracing scope: custom-call.0; current profiling annotation: XlaModule:#hlo_module=jit_sdd_matmul,program_id=292#.

Environment

Working on a university cluster with installed modules for Python 3.10, CUDA 11.8 and cuDNN 8.6. Upon loading the modules, they appear in the $PATH, and $CUDA_HOME is properly set to the directory (e.g. nvcc and ptxas are located here).

I installed jaxlib according to my CUDA/cuDNN versions:

pip install "jaxlib @ https://storage.googleapis.com/jax-releases/nightly/cuda118/jaxlib-0.4.14.dev20230714+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"

Then I installed jax-triton from HEAD, as recommended:

pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'

Pip list shows (selection):

jax                               0.4.14
jax-triton                        0.1.4
jaxlib                            0.4.14.dev20230714+cuda11.cudnn86
triton-nightly                    2.1.0.dev20230714011643

executing ptxas --version yields

bash$ ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:31:59_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

(Side note: using stable jaxlib (0.4.13) and stable jax-triton (0.1.3) yields the already-reported import error #157.)

Pallas Templating Example Error

When running the pallas templating example (https://github.com/jax-ml/jax-triton/blob/main/examples/pallas/templating.py) using Ubuntu 22.04 with Python 3.9, CUDA 11.8, JAX 0.4.8, and nightly Triton / jax-triton (from the main branch) on a GeForce 4090, I get the following error:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 67, in <module>
    app.run(main)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 45, in main
    dx, dw, db = f_grad(x, weight, bias)
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 37, in f
    return layer_norm.layer_norm(x, w, b).sum()
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/ops/layer_norm.py", line 100, in layer_norm_forward
    out, mean, rstd = method(x, weight, bias)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/pallas_call.py", line 397, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Cannot call @triton.jit'd outside of the scope of a kernel

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 67, in <module>
    app.run(main)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 45, in main
    dx, dw, db = f_grad(x, weight, bias)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/pjit.py", line 238, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/core.py", line 2592, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/pjit.py", line 1229, in _pjit_call_impl
    compiled = _pjit_lower(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/pjit.py", line 1315, in _pjit_lower
    return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/pjit.py", line 1374, in _pjit_lower_cached
    return pxla.lower_sharding_computation(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2551, in lower_sharding_computation
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 742, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 1037, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 1172, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/pjit.py", line 1453, in _pjit_lowering
    func = mlir.lower_jaxpr_to_fun(
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 1037, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py", line 1172, in jaxpr_subcomp
    ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 850, in pallas_call_lowering
    compilation_result = compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)),
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 817, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 173, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 212, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 646, in _reduce_lowering
    return triton_op(a, axis=axis, _builder=ctx.builder)
  File "/home/adam/Downloads/triton/python/triton/runtime/jit.py", line 395, in __call__
    raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Cannot call @triton.jit'd outside of the scope of a kernel

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 67, in <module>
    app.run(main)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/adam/anaconda3/envs/triton_test/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/adam/shared_drive/PycharmProjects/triton_test/pallas_test.py", line 45, in main
    dx, dw, db = f_grad(x, weight, bias)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 850, in pallas_call_lowering
    compilation_result = compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)),
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 817, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 173, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 212, in lower_jaxpr_to_triton_ir
    outvals = rule(rule_ctx, *invals, **eqn.params)
  File "/home/adam/Downloads/jax-triton/jax_triton/pallas/triton_ir_lowering.py", line 646, in _reduce_lowering
    return triton_op(a, axis=axis, _builder=ctx.builder)
  File "/home/adam/Downloads/triton/python/triton/runtime/jit.py", line 395, in __call__
    raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
RuntimeError: Cannot call @triton.jit'd outside of the scope of a kernel

Process finished with exit code 1

Bug in Matmul Example

In the matmul example, I believe there is a mistake in the code at line 119

m, k = a.shape
n, _ = b.shape

should be

m, k = a.shape
_, n = b.shape

Inconsistent behavior when using autotune

I've noticed that when I use autotune, kernels often start giving non-deterministic results. For example, I've adapted the following code from the Triton matmul tutorial:

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse
    # See above `L2 Cache Optimizations` section for details
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # see above `Pointer Arithmetics` section for details
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here.
        # This means that if K is not a multiple of BLOCK_SIZE_K,
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # you can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


M, K, N = 256, 256, 256
def jax_version():
    import jax
    import jax.numpy as jnp
    import jax_triton as jt
    def matmul(a, b):
        M, K = a.shape
        K, N = b.shape
        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
        )
        result = jt.triton_call(
            a, b,
            stride_am=K, stride_ak=1,
            stride_bk=N, stride_bn=1,
            stride_cm=N, stride_cn=1,
            M=M, N=N, K=K,
            kernel=matmul_kernel,
            out_shape=jax.ShapeDtypeStruct((M, N), jnp.float16),
            grid=grid
        )
        return result
    xkey, ykey = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(xkey, (M, K), dtype=jnp.float16)
    y = jax.random.normal(ykey, (K, N), dtype=jnp.float16)

    for _ in range(10):
        result = matmul(x, y)
        print(jnp.min(result), jnp.max(result))

def torch_version():
    import torch
    def matmul(a, b):
        M, K = a.shape
        K, N = b.shape
        output = torch.empty((M, N), dtype=torch.float16, device='cuda')
        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
        matmul_kernel[grid](
            a, b, output,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            output.stride(0), output.stride(1),
        )
        return output

    torch.manual_seed(0)
    x = torch.randn(M, K, dtype=torch.float16, device='cuda')
    y = torch.randn(K, N, dtype=torch.float16, device='cuda')
    for _ in range(10):
        result = matmul(x, y).cpu().numpy()
        print(result.min(), result.max())

Running torch_version() leads to the output:

-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34
-58.47 60.34

Running jax_version() gives:

-68.4 64.44
nan nan
-68.4 64.44
nan nan
-68.4 64.44
nan nan
-68.4 64.44
nan nan
-68.4 64.44
nan nan

If I remove either of the configs in the autotune call, the JAX version behaves consistently, so it must be something about autotune rather than the specific config being used. It's also possible that I'm not invoking the JAX version 100% correctly.
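One way to make the nondeterminism easier to spot (my own check, not part of the report above) is to compare every triton_call result against a plain jnp.matmul reference. A sketch, assuming the matmul, x, and y defined inside jax_version() above:

reference = jnp.matmul(x, y).astype(jnp.float16)
for i in range(10):
    result = matmul(x, y)
    max_err = jnp.max(jnp.abs(result.astype(jnp.float32) - reference.astype(jnp.float32)))
    has_nan = bool(jnp.isnan(result).any())
    print(f"run {i}: max abs error = {max_err}, contains NaN = {has_nan}")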

Problems in Installing at HEAD

Hi, thank you so much for releasing this wonderful library.

When I was trying to install jax-triton at HEAD, I got this error message: ERROR: triton_nightly-2.1.0.dev20230714011643-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl is not a supported wheel on this platform.

My machine is x86_64 Linux; my CUDA environment is shown in the attached screenshot.

Could you please help me with this? Thank you very much for your time and help!

Distributed array support?

I am using distributed arrays with PartitionSpec in JAX.
I am wondering whether this package works well with distributed arrays.

Recently installed version of jax-triton is broken

Error:

      if kernel is None:
        # TODO(sharadmv): handle multiple devices, right now we assume device 0
        # which is fine when we have multiple of the same GPU but this won't work in
        # general.
        device = 0
        arch = triton_kernel_call_lib.get_compute_capability(device)
>       module = code_gen.ast_to_ttir(
            fn, signature, specialization, constants, debug=dump, arch=arch
        )
E       TypeError: ast_to_ttir() got an unexpected keyword argument 'arch'

/usr/local/lib/python3.10/dist-packages/jax_triton/triton_lib.py:215: TypeError

I've installed jax-triton from git+https://github.com/jax-ml/jax-triton.git@56ffd00b7355a425bf1cc95c8b2aca45c33bed1f

Can you advise on how to deal with it?
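Errors like ast_to_ttir() got an unexpected keyword argument 'arch' usually indicate a jax-triton/Triton version mismatch (the signature of ast_to_ttir has changed across Triton nightlies). A small sketch for collecting the relevant versions when reporting or debugging this (jaxlib's version is also worth including):

import jax
import triton
import jax_triton
from jaxlib import version as jaxlib_version

# Mismatched Triton nightlies are the usual cause of "unexpected keyword
# argument" errors coming from code_gen.ast_to_ttir.
print("jax       ", jax.__version__)
print("jaxlib    ", jaxlib_version.__version__)
print("triton    ", triton.__version__)
print("jax-triton", jax_triton.__version__)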

Question about Purpose of Pallas

If I understand it correctly, the idea of Pallas is to provide a level of abstraction from Triton, enabling one to define a kernel using JAX functions.

Obviously, Triton doesn't provide any autodiff functionality; it is just a way of interacting with GPU memory that is more user-friendly than CUDA (and ultimately the plan is to enable it to work on non-NVIDIA hardware as well).

Is the idea of Pallas to seamlessly provide AutoDiff as well, so that any kernels defined will come with the ability to take gradients (in the way you can with normal JAX functions)?
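For what it's worth, Pallas/Triton kernels are not differentiated automatically in general; a common pattern is to wrap the kernel call in jax.custom_vjp and supply the backward pass yourself. A minimal sketch of that pattern, with plain-jnp placeholder functions standing in for the actual pl.pallas_call / jt.triton_call invocations:

import jax
import jax.numpy as jnp

# Placeholder "kernels"; in practice these would call pl.pallas_call(...)
# or jt.triton_call(...) with hand-written forward/backward kernels.
def forward_kernel(x):
  return jnp.tanh(x)

def backward_kernel(x, g):
  return g * (1.0 - jnp.tanh(x) ** 2)

@jax.custom_vjp
def my_op(x):
  return forward_kernel(x)

def my_op_fwd(x):
  return forward_kernel(x), (x,)      # save residuals for the backward pass

def my_op_bwd(residuals, g):
  (x,) = residuals
  return (backward_kernel(x, g),)

my_op.defvjp(my_op_fwd, my_op_bwd)

print(jax.grad(lambda x: my_op(x).sum())(jnp.ones(4)))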

cannot import name 'code_generator' from 'triton.compiler'

I'm attempting to run the fused_attention_kernel and it gives an import error. It is installed as:

!pip install -U git+https://github.com/jax-ml/jax-triton.git
Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/jax-ml/jax-triton.git
  Cloning https://github.com/jax-ml/jax-triton.git to /tmp/pip-req-build-ekxn8af8
  Running command git clone --filter=blob:none --quiet https://github.com/jax-ml/jax-triton.git /tmp/pip-req-build-ekxn8af8
  Resolved https://github.com/jax-ml/jax-triton.git to commit c74809e4175d1c47468093ef81c43a71011e2339
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: absl-py>=1.4.0 in ./.local/lib/python3.10/site-packages (from jax-triton==0.1.4) (1.4.0)
Requirement already satisfied: jax>=0.4.2 in ./.local/lib/python3.10/site-packages (from jax-triton==0.1.4) (0.4.2)
Requirement already satisfied: triton==2.0.0a2 in ./.local/lib/python3.10/site-packages (from jax-triton==0.1.4) (2.0.0a2)

But when I attempt to do import jax_triton as jt it gives an error:

ImportError                               Traceback (most recent call last)

Cell In [16], line 28
     25 for x in getmembers(tc, isfunction):
     26   print(x)
---> 28 import jax_triton as jt
     30 def _strides(shape):
     31   size = np.prod(shape)

File ~/.local/lib/python3.10/site-packages/jax_triton/__init__.py:19
     17 from jax_triton.utils import next_power_of_2
     18 from jax_triton.utils import strides_from_shape
---> 19 from jax_triton.triton_lib import triton_call
     20 from jax_triton.triton_lib import triton_kernel_call_lib
     21 from jax_triton.version import __version__

File ~/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py:47
     45 try:
     46   import triton
---> 47   from triton.compiler import code_generator as code_gen
     48   from triton.compiler import compiler as tc
     49   import triton.language as tl

ImportError: cannot import name 'code_generator' from 'triton.compiler'

I tried using Triton at HEAD, but it still gives the same error.

Import error encountered in jax_triton

Hello, I was running jax_triton on an A100 with CUDA 12.2, but when I run the command python -c 'import jax_triton as jt', an error occurs:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/lustre/grp/gyqlab/liyh/debugs/jax-triton/jax_triton/__init__.py", line 19, in <module>
    from jax_triton.triton_lib import triton_call
  File "/lustre/grp/gyqlab/liyh/debugs/jax-triton/jax_triton/triton_lib.py", line 50, in <module>
    from triton._C.libtriton import ir as tl_ir
ImportError: cannot import name 'ir' from 'triton._C.libtriton' (/lustre/grp/gyqlab/liyh/anaconda3/envs/jax_triton3/lib/python3.10/site-packages/triton/_C/libtriton.so)

My jax_triton was installed following google/jax#18603

Is the Pallas upstream working now?

I saw the Pallas concept in the latest official JAX docs and followed the Pallas quickstart section.

I installed the latest jaxlib and jax from GitHub HEAD.

I encountered the following error.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 20, in <module>
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 17, in add_vectors
    return pl.pallas_call(add_vectors_kernel, out_shape=out_shape)(x, y)
  File "/home/sh0416/research/venv/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 353, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cuda

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/sh0416/research/scripts/pallas_quickstart.py", line 20, in <module>
    print(add_vectors(jnp.arange(8), jnp.arange(8)))
NotImplementedError: MLIR translation rule for primitive 'pallas_call' not found for platform cuda

Do I have to install something different, or has it not been fully upstreamed yet?
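As a quick way to separate an installation problem from a kernel problem, pl.pallas_call takes an interpret=True flag (at least in recent versions) that runs the kernel through a pure-JAX interpreter instead of lowering it for the GPU. A sketch based on the quickstart kernel:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Elementwise add, as in the Pallas quickstart.
  o_ref[...] = x_ref[...] + y_ref[...]

def add_vectors(x, y):
  out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  # interpret=True avoids the GPU lowering path entirely, so it runs even
  # when the 'pallas_call' lowering rule for CUDA is unavailable.
  return pl.pallas_call(add_vectors_kernel, out_shape=out_shape,
                        interpret=True)(x, y)

print(add_vectors(jnp.arange(8), jnp.arange(8)))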

Changes required for AMD/rocm support with Triton

There is now support for FlashAttention-2 on AMD GPUs with PyTorch, which uses Triton kernels for this purpose.

https://github.com/ROCmSoftwarePlatform/flash-attention

JAX-Triton currently doesn't work. When trying the add example, I get the following error, which I suspect is due to some CUDA-specific code in triton_lib.py.

I can run any other tests requested here to help make progress on this.

(/jax_miniconda) Singularity> python add.py 
2024-01-22 13:12:41.578159: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
2024-01-22 13:12:46.635298: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: //
// Generated by LLVM NVPTX Back-End
//

.version 8.2
.target sm_90a
.address_size 64

	// .globl	add_kernel_0d1d2d

.visible .entry add_kernel_0d1d2d(
	.param .u64 add_kernel_0d1d2d_param_0,
	.param .u64 add_kernel_0d1d2d_param_1,
	.param .u64 add_kernel_0d1d2d_param_2
)
.maxntid 128, 1, 1
{
	.reg .pred 	%p<5>;
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<8>;
	.loc	1 24 0
$L__func_begin0:
	.loc	1 24 0

	ld.param.u64 	%rd4, [add_kernel_0d1d2d_param_0];
	ld.param.u64 	%rd5, [add_kernel_0d1d2d_param_1];
$L__tmp0:
	.loc	1 33 39
	mov.u32 	%r5, %tid.x;
	and.b32  	%r6, %r5, 7;
	ld.param.u64 	%rd6, [add_kernel_0d1d2d_param_2];
	.loc	1 31 22
	mov.u32 %r1, %ctaid.x;
	.loc	1 32 22
	shl.b32 	%r7, %r1, 3;
	.loc	1 33 26
	or.b32  	%r8, %r7, %r6;
	.loc	1 34 19
	setp.lt.s32 	%p1, %r8, 8;
	.loc	1 35 22
	mul.wide.s32 	%rd7, %r8, 4;
	add.s64 	%rd1, %rd4, %rd7;
	.loc	1 35 14
	mov.u32 %r2, 0x0;
	@%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ];
	.loc	1 36 22
	add.s64 	%rd2, %rd5, %rd7;
	.loc	1 36 14
	mov.u32 %r3, 0x0;
	@%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ];
	.loc	1 37 15
	add.s32 	%r4, %r3, %r2;
	.loc	1 38 24
	add.s64 	%rd3, %rd6, %rd7;
	.loc	1 38 33
	and.b32  	%r9, %r5, 120;
	setp.eq.s32 	%p4, %r9, 0;
	and.pred  	%p3, %p4, %p1;
	@%p3 st.global.b32 [ %rd3 + 0 ], { %r4 };
	.loc	1 38 2
	ret;
$L__tmp1:
$L__func_end0:

}
	.file	1 "/jax_miniconda/add.py"
	.section	.debug_abbrev
	{
.b8 1
.b8 17
.b8 1
.b8 37
.b8 8
.b8 19
.b8 5
.b8 3
.b8 8
.b8 16
.b8 6
.b8 27
.b8 8
.b8 180
.b8 66
.b8 12
.b8 17
.b8 1
.b8 18
.b8 1
.b8 0
.b8 0
.b8 2
.b8 46
.b8 0
.b8 17
.b8 1
.b8 18
.b8 1
.b8 64
.b8 10
.b8 135
.b8 64
.b8 8
.b8 3
.b8 8
.b8 58
.b8 11
.b8 59
.b8 11
.b8 63
.b8 12
.b8 0
.b8 0
.b8 0
	}
	.section	.debug_info
	{
.b32 119
.b8 2
.b8 0
.b32 .debug_abbrev
.b8 8
.b8 1
.b8 116
.b8 114
.b8 105
.b8 116
.b8 111
.b8 110
.b8 0
.b8 2
.b8 0
.b8 97
.b8 100
.b8 100
.b8 46
.b8 112
.b8 121
.b8 0
.b32 .debug_line
.b8 47
.b8 106
.b8 97
.b8 120
.b8 95
.b8 109
.b8 105
.b8 110
.b8 105
.b8 99
.b8 111
.b8 110
.b8 100
.b8 97
.b8 0
.b8 1
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 2
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 1
.b8 156
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 1
.b8 24
.b8 1
.b8 0
	}
	.section	.debug_pubnames
	{
.b32 $L__pubNames_end0-$L__pubNames_start0
$L__pubNames_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 64
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b32 0
$L__pubNames_end0:
	}
	.section	.debug_pubtypes
	{
.b32 $L__pubTypes_end0-$L__pubTypes_start0
$L__pubTypes_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 0
$L__pubTypes_end0:
	}
	.section	.debug_loc	{	}
; No such file or directory
2024-01-22 13:12:46.635704: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2716] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: //
[... same generated PTX and debug sections as above, elided ...]
; No such file or directory; current tracing scope: custom-call.3; current profiling annotation: XlaModule:#prefix=jit(triton_kernel_call)/jit(main)/triton_kernel_call[fn=JITFunction(__main__:add_kernel) scalar_args=() name= custom_call_target_name=triton_kernel_call out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=int32),) grid=(1,) num_warps=None num_stages=None num_ctas=1 enable_fp_fusion=True enable_warp_specialization=False enable_persistent=False input_output_aliases=() zeroed_outputs=() debug=False serialized_metadata=b'' block_size=8],hlo_module=jit_triton_kernel_call,program_id=2#.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/jax_miniconda/add.py", line 56, in <module>
    print(add(x_val, y_val))
  File "/jax_miniconda/add.py", line 44, in add
    return jt.triton_call(
  File "/jax_miniconda/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 681, in triton_call
    out_flat = triton_kernel_call_p.bind(
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 402, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 405, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 893, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: //
[... same generated PTX and debug sections as above, elided ...]
; No such file or directory; current tracing scope: custom-call.3; current profiling annotation: XlaModule:#prefix=jit(triton_kernel_call)/jit(main)/triton_kernel_call[fn=JITFunction(__main__:add_kernel) scalar_args=() name= custom_call_target_name=triton_kernel_call out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=int32),) grid=(1,) num_warps=None num_stages=None num_ctas=1 enable_fp_fusion=True enable_warp_specialization=False enable_persistent=False input_output_aliases=() zeroed_outputs=() debug=False serialized_metadata=b'' block_size=8],hlo_module=jit_triton_kernel_call,program_id=2#.

triton requirement is out of date

Attempting to use jax-triton nightly fails on CUDA with

 File "/home/michael/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1433, in jaxpr_subcomp
    ans = rule(rule_ctx, *rule_inputs, **eqn.params)
  File "/home/michael/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 383, in triton_kernel_call_lowering
    kernel, specialization = get_or_create_triton_kernel(
  File "/home/michael/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 252, in get_or_create_triton_kernel
    module = code_gen.ast_to_ttir(
TypeError: ast_to_ttir() got an unexpected keyword argument 'target'

This appears to be because requirements.txt lists triton_nightly-2.1.0.dev20230714011643, but a newer triton nightly is actually required.
(That version doesn't have a 'target' argument to ast_to_ttir).

I worked around this by installing triton_nightly-2.1.0.dev20231014192330, but it would be great to have correct dependencies.

Pallas Operators

I'm playing around with Pallas, but I immediately ran into a bunch of basic operators that don't seem to be supported.

Specifically, it would be nice to have the logical and arithmetic binary operators as well as sign: <<, >>, 2**, etc.

I've been attempting to port this over from Triton in an extensible way (https://github.com/fpgaminer/GPTQ-triton), but immediately ran into this issue.
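Until those operators are supported, constant power-of-two shifts (which GPTQ-style kernels rely on heavily) can be rewritten with multiplies and floor divisions, which Pallas does handle. A small workaround sketch with hypothetical helper names:

import jax.numpy as jnp

def shift_left(x, n):
  # x << n for integer x, expressed as a multiply by 2**n.
  return x * (2 ** n)

def shift_right(x, n):
  # x >> n for non-negative integer x, expressed as a floor division.
  return x // (2 ** n)

x = jnp.arange(8, dtype=jnp.int32)
print(shift_left(x, 3), shift_right(x, 1))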

How do "num_warps" and "num_stages" mean?

Dear all, I notice that triton_call() function receives num_warps and num_stages arguments. And they are different in "add.py" and "matrix_multiplication.py" examples. How do they mean? How can we provide them when customize an CUDA operator?
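My rough understanding (hedged, not official documentation): num_warps is how many warps of 32 threads execute each Triton program instance, so it controls per-block parallelism; num_stages is the software-pipelining depth Triton uses for loops (e.g. over the K dimension of a matmul), trading shared memory for latency hiding. Elementwise kernels like add.py need few warps and no pipelining, while matmuls benefit from more of both. Both are forwarded to the Triton compiler via keyword arguments to triton_call. A self-contained sketch with a made-up kernel:

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, o_ptr, block_size: tl.constexpr):
  pid = tl.program_id(axis=0)
  offs = pid * block_size + tl.arange(0, block_size)
  tl.store(o_ptr + offs, 2.0 * tl.load(x_ptr + offs))

x = jnp.arange(64, dtype=jnp.float32)
out = jt.triton_call(
    x,
    kernel=scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(x.size // 16,),
    block_size=16,
    num_warps=4,    # warps (of 32 threads) per program instance
    num_stages=2,   # pipelining depth; mostly matters for kernels with loops
)
print(out[:4])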

Internal error found in jax-triton for fused attention

Overview

We wrote a program in jax-triton derived from Triton's fused attention example (06-fused-attention.py). The Triton kernel itself contains only minimal changes. See the script main.py [1] and the environment env.txt [2] for details.

How to run

A simple run as follows replicates the errors included in error.log [3]:

python main.py

jax-triton fails at some input shapes

qkv: (2, 2, 32, 32)  # pass
qkv: (2, 2, 128, 64)  # pass
qkv: (2, 2, 256, 64)  # failed

It seems to be an internal error:

...
  File "/home/xx/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 468, in triton_call
    out_flat = triton_kernel_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: IndexError: map::at
  1. main.py
  2. env.txt
  3. error.txt

Running Pallas

Hi,

Installing jax, jaxlib, and jax-triton nightly builds cause the following error:

  File "/home/mehdi/Repos/venvs/jax-dev/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1721, in pallas_call_lowering
    backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
TypeError: to_proto(): incompatible function arguments. The following argument types are supported:
    1. (self: jaxlib.cuda._triton.TritonKernelCall, arg0: str, arg1: str) -> bytes

Invoked with: <jaxlib.cuda._triton.TritonKernelCall object at 0x7faebfce6ef0>, b''
I0000 00:00:1692795278.467295  180484 tfrt_cpu_pjrt_client.cc:469] TfrtCpuClient destroyed.

Package Version Editable project location


absl-py 1.4.0
filelock 3.12.2
jax 0.4.15
jax-triton 0.1.4 /home/mehdi/Repos/jax-triton
jaxlib 0.4.15.dev20230822+cuda12.cudnn89
ml-dtypes 0.2.0
numpy 1.25.2
opt-einsum 3.3.0
pip 22.0.2
scipy 1.11.2
setuptools 59.6.0
triton-nightly 2.1.0.dev20230714011643

Empty outputs of triton_call are corrupted

Consider this:

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def temp_kern(x_ptr,):
  pid = tl.program_id(axis=0)

@jax.jit
def temp(x):
  return jt.triton_call(
      x,
      kernel=temp_kern,
      out_shape=jax.ShapeDtypeStruct(shape=[], dtype=x.dtype),
      grid=(1,),)

x = jnp.ones((2,2))
out = temp(x)

The array out is now corrupted:

print(out.shape, out.size)
# (), 1

and len(out) throws an error: TypeError: len() of unsized object.
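A workaround sketch (my own, reusing the imports above, with a hypothetical kernel that actually writes its output): request a (1,)-shaped output instead of a rank-0 one and reshape on the JAX side:

@triton.jit
def copy_first_kernel(x_ptr, o_ptr):
  # Copy the first element of x into the single-element output.
  tl.store(o_ptr, tl.load(x_ptr))

@jax.jit
def temp_fixed(x):
  out = jt.triton_call(
      x,
      kernel=copy_first_kernel,
      out_shape=jax.ShapeDtypeStruct(shape=(1,), dtype=x.dtype),  # not shape=[]
      grid=(1,),
  )
  return out.reshape(())   # recover the scalar shape outside the kernel

print(temp_fixed(x).shape)   # ()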

Unable to install Jax-Triton

Hi there,

I am actually trying to use Pallas, but it keeps asking me to install Triton (no problems encountered) and jax-triton (problems encountered); the error is shown in the attached screenshot.

I am doing pip install with Python 3.11 and am not sure what is happening or what I can do to resolve this.

Any ideas?

Cannot install jax-triton

Dear all, we cannot install jax-triton successfully through pip or the python setup.py install command.

Errors reported are:

      gcc -pthread -B /home/brainpy/miniconda3/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /home/brainpy/miniconda3/include -I/home/brainpy/miniconda3/include -fPIC -O2 -isystem /home/brainpy/miniconda3/include -fPIC -I/usr/local/cuda/include -I/tmp/pip-build-env-chiaeyab/overlay/lib/python3.9/site-packages/pybind11/include -I/home/brainpy/miniconda3/include/python3.9 -c lib/custom_call.cc -o build/temp.linux-x86_64-cpython-39/lib/custom_call.o
      lib/custom_call.cc: In function ‘void do_custom_call(CUstream, void**, char*, size_t)’:
      lib/custom_call.cc:52:48: error: invalid conversion from ‘const void*’ to ‘void*’ [-fpermissive]
         52 |     CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
            |                                     ~~~~~~~~~~~^~
            |                                                |
            |                                                const void*
      error: command '/usr/bin/gcc' failed with exit code 1
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for jax-triton
Failed to build jax-triton
ERROR: Could not build wheels for jax-triton, which is required to install pyproject.toml-based projects

Using jax_triton breaks jax compilation_cache

The jax compilation_cache uses a hash of the serialized MLIR as a key for the cache (compilation_cache.py#L152). When using jax_triton, the MLIR module becomes non-deterministic. This leads to cache misses and the compilation_cache no longer works.

I dumped the MLIR of two runs that should be identical and compared the differences. There were only a handful of differences and these were caused by jax_triton.

Most of these issues could be fixed by adapting the logic in compilation_cache.py#L176. I'm happy to send a PR that ignores the backend_config field and removes the function objects from the serialized IR when computing the hash.


First difference

The value of backend_config is different even though no settings were changed.

Run A:

    %0 = mhlo.custom_call @triton_kernel_call(%arg0, %arg1, %arg2, %arg3) {api_version = 2 : i32, backend_config = "\C0\15\C7\17kU\00\00", operand_layouts = [dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[2, 1, 0]> : tensor<3xindex>, dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<1x8192x24x128xbf16>, tensor<1x8192x24x128xbf16>, tensor<1x8192x24x128xbf16>, tensor<1x8192xi32>) -> tuple<tensor<1x8192x24x128xbf16>, tensor<1x24x8192xf32>, tensor<1x24x8192xf32>> loc(#loc482)

Run B:

    %0 = mhlo.custom_call @triton_kernel_call(%arg0, %arg1, %arg2, %arg3) {api_version = 2 : i32, backend_config = "\E0\AB\80\CE}U\00\00", operand_layouts = [dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[3, 2, 1, 0]> : tensor<4xindex>, dense<[2, 1, 0]> : tensor<3xindex>, dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<1x8192x24x128xbf16>, tensor<1x8192x24x128xbf16>, tensor<1x8192x24x128xbf16>, tensor<1x8192xi32>) -> tuple<tensor<1x8192x24x128xbf16>, tensor<1x24x8192xf32>, tensor<1x24x8192xf32>> loc(#loc482)

Second difference

Stringified function pointers are different. The ir.Module might be identical but the hash of the string will not be.

Run A:

#loc284 = loc("pjit(update)/jit(main)/transpose(jvp(my_model))/transformer/decoder_layer/remat2[prevent_cse=True differentiated=True policy=<function nothing_saveable at 0x7ff6b2344040>]"(#loc48))

Run B:

#loc284 = loc("pjit(update)/jit(main)/transpose(jvp(my_model))/transformer/decoder_layer/remat2[prevent_cse=True differentiated=True policy=<function nothing_saveable at 0x7fbf8a450040>]"(#loc48))

Third difference

Stringified object pointers are different.

Run A:

#loc501 = loc("pjit(update)/jit(main)/transpose(jvp(my_model))/transformer/decoder_layer/remat/multi_head_attention/jit(shmap_body)/jit(wrapped)/pallas_call[name=mha_forward which_linear=(False, False, False, False) in_shapes=(ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 8192), dtype=int32)) out_shapes=(ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 21, 8192), dtype=float32), ShapeDtypeStruct(shape=(1, 21, 8192), dtype=float32)) debug=False interpret=False grid_spec=GridSpec(grid=(64, 1, 24), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 133, <jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 133, <jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 133, <jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 8192), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 133, <jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, <jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 8192), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, <jax_triton.pallas.core.Mapped object at 0x7ff09c3bb7c0>, 8192), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, c, 0) })), mapped_dims=()) input_output_aliases=() num_warps=8 num_stages=1]"(#loc62))

Run B:

#loc501 = loc("pjit(update)/jit(main)/transpose(jvp(my_model))/transformer/decoder_layer/remat/multi_head_attention/jit(shmap_body)/jit(wrapped)/pallas_call[name=mha_forward which_linear=(False, False, False, False) in_shapes=(ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 8192), dtype=int32)) out_shapes=(ShapeDtypeStruct(shape=(1, 133, 21, 123), dtype=bfloat16), ShapeDtypeStruct(shape=(1, 21, 8192), dtype=float32), ShapeDtypeStruct(shape=(1, 21, 8192), dtype=float32)) debug=False interpret=False grid_spec=GridSpec(grid=(64, 1, 24), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 133, <jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 133, <jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 133, <jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 8192), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 133, <jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 123), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, <jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 8192), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, <jax_triton.pallas.core.Mapped object at 0x7fb9942bb820>, 8192), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let  in (b, c, 0) })), mapped_dims=()) input_output_aliases=() num_warps=8 num_stages=1]"(#loc62))
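Until a proper fix lands, here is a rough sketch of the idea (my own workaround, not the actual compilation_cache logic): canonicalize the stringified module by blanking out the backend_config payloads and the 0x... object addresses shown above before hashing it:

import hashlib
import re

def canonical_cache_key(mlir_text: str) -> str:
  # Replace the opaque backend_config payload (a serialized pointer) and
  # stringified Python object addresses so that logically identical modules
  # produce identical hashes.
  text = re.sub(r'backend_config = "[^"]*"', 'backend_config = "<elided>"',
                mlir_text)
  text = re.sub(r"at 0x[0-9a-fA-F]+", "at 0x0", text)
  return hashlib.sha256(text.encode("utf-8")).hexdigest()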

Internal bounds-check failed when calling FlashAttention kernel from jax-triton

Hello. I'm relatively new to Triton and trying to call the HazyResearch flash attention Triton kernel from JAX using jax-triton. I've taken the implementation from here and tried to do some simple modifications to make it compatible with jax-triton. I made sure the inputs are of jax.Array type and changed the order of the kernel parameters so that the output parameters are last. I also changed bias_type to an enum (maybe this is the issue?) because of some issues I had with string parameters in JAX.

I'm getting this error after calling jit():

python: /root/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/ArrayRef.h:441: T& llvm::MutableArrayRef<T>::operator[](size_t) const [with T = mlir::BlockArgument; size_t = long unsigned int]: Assertion `Index < this->size() && "Invalid index!"' failed.

I've tried this with several input shapes, including headdim=64, with the same results.

Here is the file:

flash-attn-jax-triton.py.txt

Env:
jax: 0.4.9
jax-triton: 0.1.4
triton: 2.1.0 (HEAD)

ImportError for jax_triton

When importing jax_triton, I got the following error in colab:

ImportError: cannot import name 'ShapedArrayRef' from 'jax._src.state' (/usr/local/lib/python3.10/dist-packages/jax/_src/state/__init__.py)

Unimplemented Pallas operations

Remaining Triton ops to implement in Pallas:

  • Programming model operations
    • num_programs
  • Atomic operations
    • atomic_cas
    • atomic_max
    • atomic_min
    • atomic_add
    • atomic_and
    • atomic_or
    • atomic_xor
    • atomic_xchg
  • Bitwise operations
    • bitwise_and
    • bitwise_or
    • bitwise_xor
    • shift_left
    • shift_right
    • bitwise_not
  • General operations
    • neg
    • greater_than
    • less_than
    • less_equal
    • equal
    • not_equal
  • Array operations
    • cat
    • reshape
  • Reductions
    • argmin
    • min
    • argmax
    • xor_sum
  • Math operations
    • umulhi
    • log
    • sqrt
  • Misc. operations
    • dequantize
    • bitcast
    • clock
    • globaltimer
    • debug_barrier
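In the meantime, some of these can be emulated with operations Pallas already lowers; for example, min can be expressed through max (which is not on the list above). A small sketch:

import jax.numpy as jnp

def emulated_min(x, axis=None):
  # min(x) == -max(-x); usable inside a kernel body while min is missing.
  return -jnp.max(-x, axis=axis)

x = jnp.array([[3.0, 1.0, 2.0], [0.0, 5.0, 4.0]])
print(emulated_min(x, axis=1))   # [1. 0.]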

`jax-triton` in Google Colab

I'm trying to run the add_kernel example in Google Colab, using an NVIDIA Tesla T4 GPU with CUDA v11.8.89.

Every time I call add(x_val, y_val), I get an XlaRuntimeError with the following message:

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-b4023cedd8a0-bb9aba3b-452-601567fa49fa3, line 5; fatal : Unsupported .version 8.0; current version is '7.8'
ptxas fatal : Ptx assembly aborted due to errors
; current tracing scope: custom-call.3; current profiling annotation: XlaModule:#hlo_module=jit_triton_kernel_call,program_id=4#.

Please see this notebook for more information:
https://gist.github.com/leandrolcampos/fbdac327dbc7399c1d58d8b929dbeb67

Is it possible to run jax-triton in Google Colab? If yes, how? I tried different installation options presented in the documentation, but nothing worked.

All the best,

Inconsistent NaN results on Triton matmul kernel

I've found a behavior in which the output of jt.triton_call differs depending on when/where certain metaparameters (I suspect the metaparameters related to the grid) are defined.

Specifically, for the Triton repo's matmul kernel (source):

(1) jt.triton_call returns a matrix of NaNs from the second call onwards (first call is correct), if the metaparams BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K are directly passed into the function call
(2) jt.triton_call returns correct results when those metaparameters are selected via triton.autotune (and not directly passed into jt.triton_call)

Also, simply importing Triton's matmul_perf_model (source) further affects this; with the import, the jt.triton_call fails (NaN outputs, as described in (1)) on the second call and beyond; if the import is commented out, then it fails on the third call and beyond.

I am attaching a script that reproduces this behavior.

I'm wondering if this is expected behavior, and if so, what jax_triton conventions I should be following regarding metaparameter/tl.constexpr passing. In general, the boundary between args and metaparams seems a bit vague; is a parameter a metaparameter if and only if it is a constexpr?

Thanks for the help!

matmul_repro.txt

Unable to Install

Hi there. I'm attempting to follow the installation directions. First, I install jax-triton using pip install jax-triton.

Then, I install JAX and JAXLib using the Install JAX Docs:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Currently, jax-triton installs JAX version 0.4.14, but there are no compatible jaxlib versions available. Thus, I install the most up-to-date JAX/jaxlib using the command above.

I keep getting the following error message when running the example code:
ValueError: triton_call is only available when triton is installed.

What is the correct, up-to-date way to install jax-triton? It would be great to have clearer docs; I'm happy to make a PR once I can get it working.
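That ValueError is raised when jax_triton's guarded import of Triton fails (the tracebacks elsewhere on this page show triton_lib.py wrapping its Triton imports in a try block). Reproducing those imports directly usually surfaces the real problem, typically a Triton/jax-triton version mismatch. A quick sketch:

# jax_triton disables triton_call when its Triton imports fail; importing
# the same modules by hand shows the underlying error.
try:
    import triton
    import triton.language as tl
    print("Triton", triton.__version__, "imported OK")
except Exception as exc:
    print("Triton import failed:", type(exc).__name__, exc)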

XlaRuntimeError when applying FusedAttention

Description

Hi, while applying FusedAttention with jax-triton, we got the following XLA error on an NVIDIA A100:

2023-08-28 03:06:51.319566: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.319790: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.319901: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320187: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320240: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.320386: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.320465: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320846: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/workspace/paxml/paxml/main.py", line 510, in
app.run(main, flags_parser=absl_flags.flags_parser)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/workspace/paxml/paxml/main.py", line 445, in main
_main(argv)
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/main.py", line 487, in _main
run(experiment_config=experiment_config,
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/main.py", line 420, in run
run_experiment(
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/main.py", line 285, in run_experiment
train.train_and_evaluate(
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/train.py", line 274, in train_and_evaluate
executor.start()
File "/workspace/paxml/paxml/executors.py", line 269, in start
_train_and_evaluate_common(
File "/workspace/paxml/paxml/executors.py", line 406, in _train_and_evaluate_common
program_output = train_program.run(partitioned_train_state, step_i)
File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
result = func(*args, **kwargs)
File "/workspace/paxml/paxml/programs.py", line 332, in run
new_step, new_state, train_outputs = self.train_step(
File "/workspace/paxml/paxml/programs.py", line 620, in train_step
return step + 1, *train_step(state, prng_key, inputs, static_args)
File "/workspace/paxml/paxml/trainer_lib.py", line 1634, in call
return pjitted_fn(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

Steps for reproducing:
Add model variants to /root/.local/lib/python3.10/site-packages/paxml/tasks/lm/params/nvidia.py

--- a/paxml/tasks/lm/params/nvidia.py
+++ b/paxml/tasks/lm/params/nvidia.py
@@ -350,6 +350,20 @@ class NVIDIA70BProxy(NVIDIA5B):
   MODEL_DIMS = 8192
   HIDDEN_DIMS = 4 * 8192

+@experiment_registry.register
+class test7B(NVIDIA70BProxy):
+  PERCORE_BATCH_SIZE = 16
+  MICROBATCH_SIZE = 1
+  USE_FLASH_ATTENTION = False
+  USE_TRITON_LAYER_NORM = False
+  NUM_LAYERS = 8
+  NUM_STAGES = 4
+  ICI_MESH_SHAPE = [4, 1, 1, 1]
+
+@experiment_registry.register
+class test7BFA(test7B):
+  USE_FLASH_ATTENTION = True
+  USE_TRITON_LAYER_NORM = True

 @experiment_registry.register
 class NVIDIA116BProxy(NVIDIA5B):

Run w/o FusedAttention (PASS case):
python3 -u -m paxml.main --noenable_checkpoint_saving --job_log_dir=./jax_tmp --exp=paxml.tasks.lm.params.nvidia.test7B

Run w FusedAttention (FAILED case):
python3 -u -m paxml.main --noenable_checkpoint_saving --job_log_dir=./jax_tmp --exp=paxml.tasks.lm.params.nvidia.test7BFA

Versions:

python3 -m pip install git+https://github.com/google/paxml orbax==0.1.6 --user
python3 -m pip install git+https://github.com/google/praxis --user
python3 -m pip install git+https://github.com/google/flax --user
python3 -m pip uninstall orbax orbax-checkpoint -y
python3 -m pip install git+https://github.com/google/orbax/#subdirectory=checkpoint --user

python3 -m pip uninstall triton -y
python3 -m pip install git+https://github.com/openai/triton@b24dc19##subdirectory=python

python3 -m pip uninstall jax-triton -y
python3 -m pip install git+https://github.com/jax-ml/jax-triton@4f97b83 --no-deps

python3 -m pip uninstall jax jaxlib -y
git clone https://github.com/google/jax
pushd jax
git checkout 8d80e25
#build JAXLIB
apt update -y;apt install g++ -y
python3 -m pip install numpy wheel build
python3 build/build.py --enable_cuda
#install JAX
python3 setup.py develop --user
#install JAXLIB
python3 -m pip install dist/*.whl
popd
 
## Change the used source of pallas.ops in praxis
sed -i 's/jax.experimental.pallas.ops/jax_triton.pallas.ops/g' /root/.local/lib/python3.10/site-packages/praxis/layers/gpu_fast_attention.py

NVIDIA GPU info

4 A100-SXM-80GB GPUs

Kernel compilation hangs with a particular dtype

Here's the pallas kernel from the repo that I've slightly modified by introducing control over accumulator dtype:

# Imports assumed by this kernel (the pallas import path here follows the
# jax-triton repo the kernel comes from):
import jax
import jax.numpy as jnp
import jax_triton as jt
from jax_triton import pallas as pl


def mha_forward_kernel(
    q_ref,
    k_ref,
    v_ref,
    o_ref,
    *residual_refs,
    dot_product_scale: float,
    block_q: int,
    block_d: int,
    block_kv: int
):
    dtype = jnp.float32  # HANGS IF I REPLACE THIS WITH BFLOAT16 !!!

    seq_len = q_ref.shape[0]
    start_q = pl.program_id(0)

    neg_inf = -1e20

    # acc is the buffer where we accumulate the output on sram.
    # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
    m_i = jnp.full(block_q, dtype=dtype, fill_value=neg_inf)
    l_i = jnp.zeros(block_q, dtype=dtype)
    # acc is the buffer where we accumulate the output on sram.
    acc = jnp.zeros((block_q, block_d), dtype=dtype)

    # Load q: it will stay in L1 throughout. Indices form a matrix because we
    # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
    # q tile has shape [block_q, block_d], block_d == head_dim.
    q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)))

    # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
    # (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
    # Here we only loop over blocks of kv to process entire seq_len, the loop over
    # blocks of q is carried out by the grid.
    def body(start_k, carry):
        acc, m_prev, l_prev = carry

        k = pl.load(k_ref, (pl.dslice(start_k * block_kv, block_kv), slice(None)))

        qk = jnp.zeros([block_q, block_kv], dtype=dtype)
        qk += pl.dot(q, k.T)  # [block_q, block_k]
        qk *= dot_product_scale  # [block_q, block_k]

        m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev)
        l_prev *= jnp.exp(m_prev - m_curr)
        p = jnp.exp(qk - m_curr[:, None])
        l_curr = jnp.sum(p, axis=1) + l_prev

        l_rcp = jnp.ones((), dtype=dtype) / l_curr
        p = p * l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]

        v = pl.load(
            v_ref, (pl.dslice(start_k * block_kv, block_kv), pl.dslice(block_d))
        )
        acc = acc + pl.dot(p.astype(v.dtype), v)
        return acc.astype(dtype), m_curr.astype(dtype), l_curr.astype(dtype)

    upper_bound = jt.cdiv(seq_len, block_kv)
    acc, m_i, l_i = jax.lax.fori_loop(0, upper_bound, body, (acc, m_i, l_i))

    if residual_refs:
        l_ref, m_ref = residual_refs
        pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i)
        pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i)

    # Write output to dram.
    acc = acc.astype(o_ref.dtype)
    pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)

Surprisingly, the compilation of this kernel hangs (!) if I set the dtype to bfloat16. I suspect there's a bug somewhere.
