
pennylaneai / catalyst

A JIT compiler for hybrid quantum programs in PennyLane

Home Page: https://docs.pennylane.ai/projects/catalyst

License: Apache License 2.0

Makefile 0.82% Python 46.11% CMake 1.18% C 0.49% C++ 41.53% MLIR 9.15% LLVM 0.44% Shell 0.03% TeX 0.17% Dockerfile 0.08%
jit mlir qir quantum-compiler autodiff automatic-differentiation jax llvm python quantum

catalyst's People

Contributors

co9olguy, dependabot[bot], dime10, doctorperceptron, erick-xanadu, grwlf, josh146, lillian542, maliasadi, mandrenkov, multiphasecfd, muzammiluddin-syed-ece, paul0403, pengmai, rashidnhm, rauletorresc, rmoyard, tzunghanjuang, vincentmr

catalyst's Issues

Incorrect return value data-type with functions returning `qml.counts`

Functions returning qml.counts() as a measurement process return its components as a flat list instead of as a tuple of samples and counts.

Example:

import pennylane as qml
from catalyst import qjit

backend = "lightning.qubit"  # device under test

@qjit()
@qml.qnode(qml.device(backend, wires=1, shots=1000))
def circuit(x):
    qml.RX(x, wires=0)
    return (qml.counts(), qml.state())

result = circuit(0.2)
assert isinstance(result, tuple)  # PASS
assert isinstance(result[0], tuple)  # FAIL
assert len(result) == 2  # FAIL: len(result) == 3
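For reference, the expected result regroups the flat outputs so that the two counts components form one nested tuple next to the state. Below is a minimal pure-Python sketch of that regrouping, assuming a hypothetical structure encoding in which `"*"` marks a single leaf; in JAX-based code this job is normally done by `jax.tree_util.tree_unflatten` with the tree definition recorded at trace time.

```python
from typing import Any, List

LEAF = "*"           # hypothetical marker for a single output value
_MISSING = object()  # sentinel so that None is still a valid leaf


def unflatten(structure: Any, leaves: List[Any]) -> Any:
    """Rebuild nested tuples from a flat list of leaves (illustrative only)."""
    it = iter(leaves)

    def build(node: Any) -> Any:
        if node == LEAF:
            value = next(it, _MISSING)
            if value is _MISSING:
                raise ValueError("not enough leaves for the given structure")
            return value
        return tuple(build(child) for child in node)

    result = build(structure)
    if next(it, _MISSING) is not _MISSING:
        raise ValueError("too many leaves for the given structure")
    return result


# Return structure of `(qml.counts(), qml.state())`: counts has two components.
structure = ((LEAF, LEAF), LEAF)
flat = ["samples", "counts", "state"]  # the three top-level values observed in the issue
result = unflatten(structure, flat)
print(result)  # (('samples', 'counts'), 'state')
```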

[Frontend] JAX integration fails if qjit function parameters change shape

Consider the following qjitted function:

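import jax
import jax.numpy as jnp
import pennylane as qml
import catalyst
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)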

@qjit
@qml.qnode(dev)
def circuit(params, n):
    
    def ansatz(i, x):
        qml.RX(x[i, 0], wires=0)
        qml.RY(x[i, 1], wires=1)
        qml.CNOT(wires=[0, 1])
        return x

    catalyst.for_loop(0, n, 1)(ansatz)(jnp.reshape(params, (-1, 2)))

    return qml.expval(qml.PauliZ(1))

This works well with JAX integration:

>>> params = jnp.array([0.54, 0.3154, 0.654, 0.123])
>>> circuit(params, 2)
0.7612754362314241
>>> jax.grad(circuit, argnums=0)(params, 2)
[ 0.07954928 -0.32372842 -0.50406511  0.00828534]

However, if called with a different input parameter shape, it now fails:

>>> params = jnp.array([0.54, 0.3154, 0.654, 0.123, 0.1, 0.2])
>>> circuit(params, 3)
0.36152373616486333
>>> jax.grad(circuit, argnums=0)(params, 3)
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (6,).

Note that if the circuit is re-qjitted, everything now works:

>>> circuit = qjit(circuit.qfunc)
>>> jax.grad(circuit, argnums=0)(params, 3)
[-0.42763436 -0.46892463  0.13741047 -0.54825337 -0.42408931 -0.07123154]
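The error message mentions the shapes (4,) and (6,), which suggests that something compiled for the first call's parameter shape is being reused for the second call. Assuming that diagnosis, a fix would key compiled variants on the argument shapes. Below is a minimal pure-Python sketch of that caching strategy; `ShapeKeyedCache` and `fake_compile` are hypothetical names, not part of Catalyst.

```python
from typing import Any, Callable, Dict, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a (possibly nested) list; () for scalars."""
    if isinstance(x, (list, tuple)):
        return (len(x),) + (shape_of(x[0]) if x else ())
    return ()


class ShapeKeyedCache:
    """Keep one compiled variant per argument-shape signature (hypothetical)."""

    def __init__(self, compile_fn: Callable[[Tuple], Callable]):
        self.compile_fn = compile_fn
        self.variants: Dict[Tuple, Callable] = {}

    def __call__(self, *args):
        key = tuple(shape_of(a) for a in args)
        if key not in self.variants:
            # A new shape signature triggers a fresh compilation instead of
            # reusing a variant that was built for other shapes.
            self.variants[key] = self.compile_fn(key)
        return self.variants[key](*args)


compiled_signatures = []


def fake_compile(signature):
    # Stand-in for compiling the derivative; records which shapes it saw.
    compiled_signatures.append(signature)
    return lambda params, n: sum(params)


grad_fn = ShapeKeyedCache(fake_compile)
grad_fn([0.54, 0.3154, 0.654, 0.123], 2)
grad_fn([0.54, 0.3154, 0.654, 0.123, 0.1, 0.2], 3)
grad_fn([0.1, 0.2, 0.3, 0.4], 2)  # reuses the (4,) variant
print(compiled_signatures)  # [((4,), ()), ((6,), ())]
```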

Incorrect gradient return type deduction in finite-diff method

The following example fails compilation:

import pennylane as qml
from catalyst import grad, qjit

@qml.qnode(qml.device("lightning.qubit", wires=1))
def func(p: float):
    x = qml.probs()
    y = p**2
    return x, y

qjit(grad(func, method="fd"))(0.1)

with the error message:

/tmp/tmpgwi9cyw0/func.nohlo.mlir:4:12: error: 'gradient.grad' op invalid result type: grad result at position 1 must be 'tensor<f64>' but got 'tensor<2xf64>'
    %0:2 = "gradient.grad"(%arg0) {callee = @func, diffArgIndices = dense<0> : tensor<1xi64>, finiteDiffParam = 9.9999999999999995E-8 : f64, method = "fd"} : (tensor<f64>) -> (tensor<2xf64>, tensor<2xf64>)
           ^
/tmp/tmpgwi9cyw0/func.nohlo.mlir:4:12: note: see current operation: %0:2 = "gradient.grad"(%arg0) {callee = @func, diffArgIndices = dense<0> : tensor<1xi64>, finiteDiffParam = 9.9999999999999995E-8 : f64, method = "fd"} : (tensor<f64>) -> (tensor<2xf64>, tensor<2xf64>)
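The error suggests that the second gradient result was given the type of the first. For a scalar argument, the derivative of each output should keep that output's own shape: `tensor<2xf64>` for the probabilities and `tensor<f64>` for `y`. Below is a minimal pure-Python forward-difference sketch of that shape rule; `fd_grad` is a hypothetical helper, and `func` hard-codes `qml.probs()` of a single qubit in |0> (no gates applied) as `[1.0, 0.0]`.

```python
from typing import Callable, List, Tuple, Union

Value = Union[float, List[float]]


def fd_grad(func: Callable[[float], Tuple[Value, ...]], p: float,
            h: float = 1e-7) -> Tuple[Value, ...]:
    """Forward-difference derivative of every output w.r.t. a scalar input.

    Each derivative keeps the shape of its own output: a list of length n
    stays a list of length n, and a scalar stays a scalar.
    """
    base, shifted = func(p), func(p + h)
    grads = []
    for b, s in zip(base, shifted):
        if isinstance(b, list):
            grads.append([(si - bi) / h for bi, si in zip(b, s)])
        else:
            grads.append((s - b) / h)
    return tuple(grads)


def func(p: float):
    x = [1.0, 0.0]  # stand-in for qml.probs() of one qubit in |0>
    y = p ** 2
    return x, y


dx, dy = fd_grad(func, 0.1)
print(len(dx), type(dy).__name__)  # 2 float
```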

Add support for `if - else if - else` conditional chains

Context

Currently, Catalyst has native support for conditionals via the catalyst.cond function, a decorator for defining functional if-else conditionals with semantics similar to JAX's lax.cond:

from catalyst import qjit, cond

@qjit
def test(x: float):

    @cond(x > 1.4)  # provide Boolean predicate to decorator
    def multiply():
        return x * 2

    @multiply.otherwise  # define an alternative branch
    def multiply():
        return x

    return multiply()  # call the conditional function and obtain return values from the appropriate branch
>>> test(0.1)
Array(0.1, dtype=float64)
>>> test(2.1)
Array(4.2, dtype=float64)

The decorator form is meant to resemble Python block conditionals; however, branching into more than two options can become quite cumbersome, since it requires nesting conditionals. For example, if we add another range (2.7, ∞) to the example above:

@qjit
def test(x: float):

    @cond(x > 2.7)
    def multiply():
        return x * 4

    @multiply.otherwise
    def multiply():

        @cond(x > 1.4)
        def inner():
            return x * 2

        @inner.otherwise
        def inner():
            return x

        return inner()

    return multiply()

we can see that the readability suffers.

Goal

We would like to extend the catalyst.cond function to support branching into an arbitrary number of cases without having to resort to nesting conditionals, but rather via if - else if - ... - else chains. The following example should then be supported:

@qjit
def test(x: float):

    @cond(x > 2.7)
    def multiply():
        return x * 4

    @multiply.else_if(x > 1.4)
    def multiply():
        return x * 2
    
    @multiply.otherwise
    def multiply():
        return x

    return multiply()
>>> test(0.1)
Array(0.1, dtype=float64)
>>> test(2.1)
Array(4.2, dtype=float64)
>>> test(3.1)
Array(12.4, dtype=float64)

Technical details

The implementation of catalyst.cond relies on several elements:

  • the user-facing function:

    def cond(pred):
        """A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst."""

    This function exists to define the user interface of the Catalyst conditional. It generates an object that can be further modified by the user and then called (CondCallable).

  • the representation stored on PennyLane's tape class:

    class Cond(Operation):
        """PennyLane's conditional operation."""

    As part of the program tracing step in the Catalyst frontend, quantum operations (as well as Catalyst control flow operations) are stored in a data structure called the quantum tape. We store the conditional here and retrieve it later to convert it to its JAXPR form.

  • the representation stored in the JAX Program Representation (JAXPR):

    qcond_p = jax.core.AxisPrimitive("qcond")
    qcond_p.multiple_results = True

    Elements (operations) of the JAXPR are called primitives. They store the input and output types of operations, and define new values to be used by other operations in the JAXPR.

  • and its lowering to MLIR:

    def _qcond_lowering(
        jax_ctx: mlir.LoweringRuleContext,
        pred: ir.Value,
        *branch_args_plus_consts: tuple,
        true_jaxpr: jax.core.ClosedJaxpr,
        false_jaxpr: jax.core.ClosedJaxpr,
    ):
        ...

    JAXPR is converted to MLIR to be used as input to a compiler. In JAX's case this is usually XLA, and PennyLane is able to use the Catalyst compiler to produce binary code for Python functions. Custom conversion functions can be registered in the JAX framework for primitives, which need to generate MLIR from Python via the use of MLIR's Python bindings. These bindings are automatically generated during Catalyst's build process.

In order to implement the proposed feature on the catalyst.cond function:

  • the UI needs to be modified by adding a new else_if member to the CondCallable class, which will allow the user to successively add new branches to a @cond decorated function,
  • the lowering strategy towards MLIR needs to be updated to express the new if - else if - ... - else structure using MLIR operations, preferably from the SCF dialect,
  • and other classes/functions may need adjusting as appropriate.
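One way to lower such a chain is to fold it from the last branch backwards into nested two-way conditionals, which maps naturally onto nested `scf.if` operations. Below is a pure-Python mock of the proposed UI and that folding strategy. `CondChain` and `binary_cond` are hypothetical names, the predicates here are plain Python booleans rather than traced values, and `scale` mirrors the `test` function from the Goal section.

```python
from typing import Any, Callable, List, Optional, Tuple

Branch = Callable[[], Any]


def binary_cond(pred: bool, true_fn: Branch, false_fn: Branch) -> Branch:
    """A two-way conditional; in MLIR terms, one scf.if with two regions."""
    return lambda: true_fn() if pred else false_fn()


class CondChain:
    """Pure-Python mock of an if / else-if / else chain (hypothetical)."""

    def __init__(self, pred: bool, branch: Branch):
        self.branches: List[Tuple[bool, Branch]] = [(pred, branch)]
        self.otherwise_fn: Optional[Branch] = None

    def else_if(self, pred: bool):
        def decorator(branch: Branch) -> "CondChain":
            self.branches.append((pred, branch))
            return self  # rebinding the name keeps the same chain object
        return decorator

    def otherwise(self, branch: Branch) -> "CondChain":
        self.otherwise_fn = branch
        return self

    def __call__(self) -> Any:
        if self.otherwise_fn is None:
            raise ValueError("an otherwise branch is required")
        # Fold from the end: if p0: b0 else (if p1: b1 else (... else otherwise)).
        nested = self.otherwise_fn
        for pred, branch in reversed(self.branches):
            nested = binary_cond(pred, branch, nested)
        return nested()


def cond(pred: bool):
    return lambda branch: CondChain(pred, branch)


def scale(x: float) -> float:
    @cond(x > 2.7)
    def multiply():
        return x * 4

    @multiply.else_if(x > 1.4)
    def multiply():
        return x * 2

    @multiply.otherwise
    def multiply():
        return x

    return multiply()


print(scale(0.1), scale(2.1), scale(3.1))  # 0.1 4.2 12.4
```

Folding from the end keeps each generated conditional binary, so the existing two-region lowering could in principle be reused at every level.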

[BUG] Classical pre-processing not working when using `grad` with enzyme

The following code, which has classical post-processing but no classical pre-processing in the function whose gradient we are computing, works correctly:

@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))
    return grad(lambda y: g(y) ** 2)(x)
>>> f(0.4)
array(-0.99994172)

However, if we introduce classical pre-processing on the QNode argument, this no longer works:

@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))
    return grad(lambda y: g(jnp.cos(y)) ** 2)(x)
>>> f(0.4)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-17-86870092fac7> in <cell line: 1>()
----> 1 f(0.4)

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in __call__(self, *args, **kwargs)
    645             return self.user_function(*args, **kwargs)
    646 
--> 647         function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
    648             self.compiled_function, *args
    649         )

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args)
    620             if not self.compiling_from_textual_ir:
    621                 self.mlir_module = self.get_mlir(*r_sig)
--> 622             function = self.compile()
    623         else:
    624             assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in compile(self)
    579             qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
    580 
--> 581             shared_object, llvm_ir, inferred_func_data = self.compiler.run(
    582                 self.mlir_module, pipelines=self.compile_options.pipelines
    583             )

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run(self, mlir_module, *args, **kwargs)
    399         """
    400 
--> 401         return self.run_from_ir(
    402             mlir_module.operation.get_asm(
    403                 binary=False, print_generic_op_form=False, assume_verified=True

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run_from_ir(self, ir, module_name, pipelines, lower_to_llvm)
    356             print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)
    357 
--> 358         compiler_output = run_compiler_driver(
    359             ir,
    360             workspace,

RuntimeError: Compilation failed:
operand #0 does not dominate this use
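To check a future fix numerically, the expected value can be derived by hand. This assumes that `dev` (not shown above) is a single-qubit simulator starting in |0>, so that `g(y)` evaluates <Z> after `RX(y)`, which is `cos(y)`. The differentiated function is then `cos(cos x)**2`, and the chain rule gives `sin(2 cos x) * sin x`. Below is a minimal pure-Python sketch comparing that formula with a central difference.

```python
import math


def expval_z_after_rx(y: float) -> float:
    # Assumed device: one qubit starting in |0>, so <Z> after RX(y) is cos(y).
    return math.cos(y)


def cost(x: float) -> float:
    # Classical pre-processing (cos) feeds the circuit; the result is then squared.
    return expval_z_after_rx(math.cos(x)) ** 2


def cost_grad_analytic(x: float) -> float:
    # d/dx cos(cos x)^2 = 2 cos(cos x) sin(cos x) sin x = sin(2 cos x) sin x
    return math.sin(2 * math.cos(x)) * math.sin(x)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    return (f(x + h) - f(x - h)) / (2 * h)


print(cost_grad_analytic(0.4), central_difference(cost, 0.4))
```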

grad returns non-deterministic values for some multi-dimensional inputs

Consider the following program:

import jax.numpy as jnp
from catalyst import qjit, grad as grad

def func(p1):
  return jnp.stack([1*p1, 2*p1, 3*p1])

a = jnp.zeros([2], dtype=float)
r = qjit(grad(func, argnum=[0]))(a)
print(type(r), r.shape)
print(r)

Although this is not clearly documented, we expected r to be a deterministic tensor representing the Jacobian of func. In fact, it contains some non-deterministic values:

<class 'jaxlib.xla_extension.Array'> (2, 3, 2)
[[[1.00000000e+000 1.00000000e+000]
  [0.00000000e+000 0.00000000e+000]
  [4.66239210e-303 4.66239210e-303]]

 [[0.00000000e+000 0.00000000e+000]
  [2.00000000e+000 2.00000000e+000]
  [4.66239175e-303 4.66239175e-303]]]

where the imprecise numbers change from run to run, and sometimes contain nans (uninitialized?).

Note that jax.jit(jax.jacfwd(func, argnums=[0]))(a) returns a deterministic tensor of shape (3,2,2).
The Catalyst version used is commit bd8eb11.
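For reference, `func` maps a `(2,)` input to a `(3, 2)` output, so with the output-then-input layout used by `jax.jacfwd` the Jacobian has shape `(3, 2, 2)` and entries `d out[i][j] / d p1[k] = (i + 1)` if `j == k`, and `0` otherwise; no entry should depend on uninitialized memory. Below is a minimal pure-Python central-difference sketch that reproduces this expected tensor (`jacobian_fd` is a hypothetical helper).

```python
from typing import Callable, List

Matrix = List[List[float]]


def func(p: List[float]) -> Matrix:
    # Pure-Python stand-in for jnp.stack([1*p1, 2*p1, 3*p1]); output shape (3, 2).
    return [[k * v for v in p] for k in (1, 2, 3)]


def jacobian_fd(f: Callable[[List[float]], Matrix], p: List[float], h: float = 1e-6):
    """Central-difference Jacobian laid out as output shape + input shape."""
    base = f(p)
    jac = [[[0.0] * len(p) for _ in row] for row in base]
    for k in range(len(p)):
        plus, minus = list(p), list(p)
        plus[k] += h
        minus[k] -= h
        f_plus, f_minus = f(plus), f(minus)
        for i, row in enumerate(base):
            for j in range(len(row)):
                jac[i][j][k] = (f_plus[i][j] - f_minus[i][j]) / (2 * h)
    return jac


jac = jacobian_fd(func, [0.0, 0.0])
print(len(jac), len(jac[0]), len(jac[0][0]))  # 3 2 2
```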

Canonicalize MLIR output from JAX before lit testing

The MLIR obtained after tracing and stored in the CompilationPipeline.mlir attribute contains many redundant/mhlo-specific operations, which make it harder to read and, more importantly, harder to test.

We should run the -canonicalize pass on this representation to obtain a compact and consistent representation for testing and further processing.

[Frontend] Remove special handling for Hamiltonian primitives

Once this issue in PL is resolved, apply the following patch:

diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 0205381..f7bddb6 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -352,24 +352,6 @@ def trace_quantum_tape(
     return out, qreg, qubit_states
 
 
-# TODO: remove once fixed upstream
-def trace_hamiltonian(coeffs, *nested_obs):
-    """Trace a hamiltonian.
-
-    Args:
-        coeffs: a list of coefficients
-        nested_obs: a list of the nested observables
-
-    Returns:
-        a hamiltonian JAX primitive used for tracing
-    """
-    # jprim.hamiltonian cannot take a python list as input
-    # only as *args can a list be passed as an input.
-    # Instead cast it as a JAX array.
-    coeffs = jax.numpy.asarray(coeffs)
-    return jprim.hamiltonian(coeffs, *nested_obs)
-
-
 def trace_observables(obs, qubit_states, p, num_wires, qreg):
     """Trace observables.
 
@@ -405,7 +387,7 @@ def trace_observables(obs, qubit_states, p, num_wires, qreg):
         jax_obs = jprim.tensorobs(*nested_obs)
     elif isinstance(obs, qml.Hamiltonian):
         nested_obs = [trace_observables(o, qubit_states, p, num_wires, qreg)[0] for o in obs.ops]
-        jax_obs = trace_hamiltonian(op_args, *nested_obs)
+        jax_obs = jprim.hamiltonian(op_args, *nested_obs)
     else:
         raise RuntimeError(f"unknown observable in measurement process: {obs}")
     return jax_obs, qubits

Add `jax.jit` and JAX transformation integration

Context

There are two ways we could support 'external' JAX integration - that is, the ability to:

  • include @qjit inside @jax.jit hybrid functions, and
  • call JAX transformations such as jax.grad and jax.vmap directly on qjit functions.

To do this, we could either use:

  1. Custom derivative rules to register the @qjit decorator with JAX, alongside pure_callback to allow for an XLA compiled program to 'callback' to Python to execute the qjit compiled program.

  2. Custom JAX primitives that can directly make use of XLA CustomCalls to call arbitrary C++ code (see https://dfm.io/posts/extending-jax/)

Longer term, (2) is likely the better solution: since Option (1) requires a Python callback, it may introduce unnecessary overhead.

However, Option (1) is quite easy to do, and might make sense to do first, since it enables the same user-facing behaviour.

Below is a prototype/proof-of-concept of Option (1):
import catalyst
import jax
import jax.numpy as jnp
def _make_jax_jit(func, shapes):
    """Use jax.pure_callback to make a JAX-JIT compatible QJIT function.

    Args:
        func (QJIT): the QJIT function to make jax-jit compatible
        shapes (Sequence[ShapeDtypeStruct]): list of shapes and dtypes to
            expect from the output of the QJIT function
    """
    qjit_func = catalyst.qjit(lambda args: func(*args))

    def wrapper(*args):
        # NOTE: at the moment, it appears that QJIT is not vectorized.
        # Can we support vectorization?

        if args[0].ndim > 1:
            # batch dimension is present. Hardcoding in only *one* batch dim.
            res = [jax.pure_callback(qjit_func, shapes, [a], vectorized=False)[0] for a in args[0]]
            return jnp.stack(res)

        return jax.pure_callback(qjit_func, shapes, args, vectorized=False)[0]

    return wrapper

def qjit(func):
    """JAX compatible qjit decorator"""

    @jax.custom_jvp
    def jax_func(*args):
        # Hardcoding in a qnode function with array(2,) output
        # and dtype float64. Questions:
        # - how do we determine this dynamically?
        # - how do we determine the batch dimensions?
        shapes = [jax.ShapeDtypeStruct((2,), jax.numpy.float64)]
        return _make_jax_jit(func, shapes)(*args)

    @jax_func.defjvp
    def f_jvp(primals, tangents):
        # hardcoding in a qnode function with tensor input of shape (3,)
        # and dtype float64. Questions:
        # - how do we determine this dynamically?
        # - how do we determine the batch dimensions?
        shapes = [jax.ShapeDtypeStruct((3, 2), jax.numpy.float64)]

        jac_fn = _make_jax_jit(catalyst.grad(func), shapes)
        jac = jac_fn(*primals)

        primals_out = jax_func(*primals)

        # Compute the vector jacobian products.
        # NOTE: this is assuming the QJIT function only has single array
        # output, and does not take into account returning multiple measurements
        tangents = jnp.squeeze(jnp.stack(tangents))

        if primals[0].ndim > 1:
            # batch dimension is present. Hardcoding in only *one* batch dim.
            tangents_out = jnp.stack([jnp.tensordot(j, t, axes=[0, 0]) for j, t in zip(jac, tangents)])
        else:
            tangents_out = jnp.tensordot(jac, tangents, axes=[[0], [0]])

        return primals_out, tangents_out

    return jax_func

We can see how seamlessly this allows jax.jit, jax.grad, and qjit to work together:

import pennylane as qml
dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(jnp.pi * x[0], wires=0)
    qml.RY(x[1] ** 2, wires=0)
    qml.RX(x[1] * x[2], wires=0)
    return qml.probs(wires=0)

@jax.jit
def cost_fn(weights):
    x = jnp.sin(weights)
    return jnp.sum(jnp.cos(circuit(x)) ** 2)

Executing the hybrid cost function and computing the gradient:

>>> weights = jnp.array([0.543, 0.123, 0.743])
>>> cost_fn(weights)
1.535381488575644
>>> jax.grad(cost_fn)(weights)
[-0.19387385 -0.04838516 -0.00651345]

The above also works for (one) batch dimension:

>>> weights = jnp.array([[0.543, 0.123, 0.743], [0.1, 0.2, 0.3]])
>>> cost_fn(weights)
2.8580734398397474
>>> jax.grad(cost_fn)(weights)
[[-0.19387385 -0.04838516 -0.00651345]
 [ 0.49249037  0.05197949  0.02991883]]

Additional questions

  • This example shows why it would be important to add catalyst.vjp and catalyst.jvp transformations! It would be much more efficient to compute these quantities directly, rather than computing the full Jacobian and then post-processing it. This may also be easier to do with Enzyme integration.

  • I was surprised that vectorization doesn't seem to work; we should see if we can enable it. That is,

    @catalyst.qjit
    @qml.qnode(dev)
    def circuit(x):
        qml.RX(jnp.pi * x[0], wires=0)
        qml.RY(x[1] ** 2, wires=0)
        qml.RX(x[1] * x[2], wires=0)
        return qml.probs(wires=0)
    
    weights = jnp.array([[0.543, 0.123, 0.743], [0.1, 0.2, 0.3]])
    print(circuit(weights))
    File "/home/josh/xanadu/catalyst/frontend/catalyst/jax_primitives.py", line 471, in qinst_lowering
      if not ir.F64Type.isinstance(baseType):
    UnboundLocalError: local variable 'baseType' referenced before assignment
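To illustrate the first point: forward-mode differentiation produces a Jacobian-vector product in a single pass, without ever materializing the Jacobian. Below is a minimal pure-Python sketch using dual numbers. This is a generic technique shown for illustration, not how Catalyst or Enzyme implement `jvp`, and `Dual`, `dsin` and `jvp` here are hypothetical names.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple


@dataclass
class Dual:
    """A primal value paired with its tangent (forward-mode autodiff)."""
    val: float
    tan: float

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule: (a * b)' = a' * b + a * b'
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def dsin(x: Dual) -> Dual:
    return Dual(math.sin(x.val), math.cos(x.val) * x.tan)


def jvp(f: Callable[[List[Dual]], List[Dual]], primals: Sequence[float],
        tangents: Sequence[float]) -> Tuple[List[float], List[float]]:
    """One forward pass returns f(x) and J(x) @ t without building J."""
    outs = f([Dual(p, t) for p, t in zip(primals, tangents)])
    return [o.val for o in outs], [o.tan for o in outs]


def f(x: List[Dual]) -> List[Dual]:
    # Example map from R^3 to R^2.
    return [x[0] * x[1], dsin(x[2]) + x[0]]


primals_out, tangents_out = jvp(f, [1.0, 2.0, 0.0], [1.0, 0.0, 0.0])
print(primals_out, tangents_out)  # [2.0, 1.0] [2.0, 1.0]
```

The cost of one such pass is a small constant multiple of evaluating `f` once, whereas building the full Jacobian first needs one pass per input dimension.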
    

macOS 0.3.0 wheels require `zstd` to be installed

When executing Catalyst @qjit programs after pip installing the macOS 0.3.0 wheels, I get an error indicating that zstd could not be found or loaded:

MHLOPass Error: Library not loaded: /usr/local/opt/zstd/lib/libzstd.1.dylib

By installing zstd manually (via brew install zstd), the error disappears, and @qjit programs are able to be compiled and executed.

Could it be that we are accidentally including a zstd dependency when building the wheel on macOS?
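The absolute path in the error message is Homebrew's default prefix on Intel Macs, which hints that the path was recorded in one of the binaries at build time; running `otool -L` on the binary that reports the error would confirm which libraries it links against. Below is a minimal diagnostic sketch using only the standard library. It checks the reported path and asks the loader whether it can find `zstd` anywhere else.

```python
import ctypes.util
import os

# Absolute path copied from the error message above.
REPORTED_PATH = "/usr/local/opt/zstd/lib/libzstd.1.dylib"


def zstd_status() -> dict:
    """Check the hard-coded path and whether the loader finds zstd elsewhere."""
    return {
        "reported_path_exists": os.path.exists(REPORTED_PATH),
        # find_library returns a library name/path, or None if nothing is found.
        "found_by_loader": ctypes.util.find_library("zstd"),
    }


print(zstd_status())
```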

Counts does not return the same type as PL

  • Expected behavior: In PL, qml.counts() returns a dictionary. I would expect it to return a dictionary in Catalyst as well.

  • Actual behavior: In Catalyst qml.counts() returns a tuple of arrays.
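Below is a minimal sketch of the conversion a user (or the frontend) could apply in the meantime. It assumes that Catalyst's tuple holds integer-valued basis-state outcomes alongside their counts. `counts_to_dict` is hypothetical; the zero-padded bitstring keys and the `all_outcomes` flag are meant to mirror PennyLane's `qml.counts` behaviour for computational-basis counts.

```python
from typing import Dict, Sequence


def counts_to_dict(outcomes: Sequence[float], counts: Sequence[int], num_wires: int,
                   all_outcomes: bool = False) -> Dict[str, int]:
    """Turn an (outcomes, counts) pair into a PennyLane-style counts dict.

    Assumes each outcome is an integer-valued basis-state index. Keys are
    zero-padded bitstrings; zero-count entries are dropped unless all_outcomes.
    """
    result = {}
    for outcome, count in zip(outcomes, counts):
        if count or all_outcomes:
            result[format(int(outcome), f"0{num_wires}b")] = int(count)
    return result


print(counts_to_dict([0.0, 1.0, 2.0, 3.0], [10, 0, 5, 985], num_wires=2))
# {'00': 10, '10': 5, '11': 985}
```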

Update the project documentation with the testing guidelines

As a reminder, I suggest updating the project docs with the following guidelines:

  • each piece of logic being tested should live in a separate test function
  • only parametrize tests if the same logic is tested with different input values
  • prefer self-contained test functions with duplicated code if necessary (within reason)
  • avoid test functions that contain too much logic, are overly modular or parametrized, or make heavy use of fixtures or other setup code

Originally posted by @dime10 in #98 (comment)

The closest existing guidelines are https://github.com/PennyLaneAI/guidance-docs/blob/master/development/code-review.md#code-author-and-reviewers-checklist

QJIT state machine raises recompilation warning unexpectedly

Sometimes when we enter the following branch of the QJIT state machine we shouldn't be raising a warning:

elif next_action == TypeCompatibility.NEEDS_COMPILATION:
    if self.user_typed:
        msg = "Provided arguments did not match declared signature, recompiling..."
        warnings.warn(msg, UserWarning)
    self.mlir_module = self.get_mlir(*r_sig)
    function = self.compile()

Analysis:

  1. self.user_typed is true whenever the user provided type annotations or the function has no arguments.
  2. Additionally, this branch of the state machine is entered whenever we need to compile the user function.

I think the assumption here is that whenever (1) holds, the user function must already have been compiled at least once; hence, if (2) occurs, the user must have changed the signature (which is what the warning message claims).

However, the following example breaks that assumption:

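from catalyst import qjit

@qjit(target="mlir")
def f():
    return 1

f()  # first call: nothing has been compiled yet, but the warning is still raised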
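One possible fix is to warn only when an earlier compilation is being replaced, rather than whenever `user_typed` is set. Below is a minimal sketch of that condition; `ToyQJIT` and its attributes are hypothetical and only model this branch of the state machine.

```python
import warnings


class ToyQJIT:
    """Hypothetical model of the NEEDS_COMPILATION branch only."""

    def __init__(self, user_typed: bool):
        self.user_typed = user_typed
        self.compiled_function = None  # nothing has been compiled yet
        self.compile_count = 0

    def handle_needs_compilation(self) -> None:
        # Warn only if a previous compilation is being replaced, i.e. the
        # declared signature was really contradicted by the new arguments.
        if self.user_typed and self.compiled_function is not None:
            warnings.warn(
                "Provided arguments did not match declared signature, recompiling...",
                UserWarning,
            )
        self.compile_count += 1
        self.compiled_function = object()  # stand-in for the compiled artifact


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    q = ToyQJIT(user_typed=True)
    q.handle_needs_compilation()  # first compilation: no warning
    first_call_warnings = len(caught)
    q.handle_needs_compilation()  # genuine recompilation: warning
    total_warnings = len(caught)

print(first_call_warnings, total_warnings)  # 0 1
```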

Some tests fail when tensorflow is imported before catalyst

The imports

import tensorflow
import catalyst

do not always work, and some tests fail if tensorflow is imported first. We can see this when testing the wheels.

FAILED frontend/test/pytest/test_autograph.py::TestConditionals::test_qubit_manipulation_cond[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_conditionals.py::TestCond::test_qubit_manipulation_cond[lightning.qubit]
FAILED frontend/test/pytest/test_jax_integration.py::TestJAXJIT::test_multiple_calls[lightning.qubit]
FAILED frontend/test/pytest/test_mid_circuit_measurement.py::TestMidCircuitMeasurement::test_basic[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_mid_circuit_measurement.py::TestMidCircuitMeasurement::test_more_complex[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_pytree_args.py::TestPyTreesReturnValues::test_return_value_float[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_pytree_args.py::TestPyTreesReturnValues::test_return_value_tuples[lightning.qubit]
FAILED frontend/test/pytest/test_measurements_results.py::TestOtherMeasurements::test_multiple_return_values[lightning.qubit]
FAILED frontend/test/pytest/test_jit_behaviour.py::TestCallsiteCompileVsFunctionDefinitionCompile::test_equivalence[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_jit_behaviour.py::TestDecorator::test_function_is_cached[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_pytree_args.py::TestPyTreesFuncArgs::test_args_used_in_measure[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_jit_behaviour.py::TestCaching::test_function_is_cached[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_pytree_args.py::TestPyTreesFuncArgs::test_args_used_indirectly_in_measure[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_jit_behaviour.py::TestShots::test_shots_in_decorator_in_sample[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_jit_behaviour.py::TestShots::test_shots_in_callsite_in_sample[lightning.qubit]
FAILED frontend/test/pytest/test_qnode.py::test_variable_capture[0-False]
FAILED frontend/test/pytest/test_pytree_args.py::TestPyTreesReturnValues::test_return_value_dict[lightning.qubit]
FAILED frontend/test/pytest/test_loops.py::TestWhileLoops::test_alternating_loop[lightning.qubit]
FAILED frontend/test/pytest/test_conditionals.py::TestCond::test_identical_branch_names[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_contexts.py::TestTracing::test_fixed_tracing[lightning.qubit] - RuntimeError: random_device could not be read
FAILED frontend/test/pytest/test_contexts.py::TestTracing::test_cond_inside_while_loop[lightning.qubit]
FAILED frontend/test/pytest/test_variable_wires.py::TestBasicCircuits::test_measurements[lightning.qubit-args0]

Improve error handling / add type promotion to AutoGraph loops

Two issues are uncovered in the referenced discussion:

  • Which errors are raised during a fallback to Python? Sometimes it is clear the user intended to perform an autograph conversion, but the fallback to Python can obscure the error messaging.
  • Can we support type promotion for autograph captured control flow, on pre-initialized variables?

Is type promotion supported here?
I suppose the following code should pass:

def f(pred: bool):
  x = 1
  while pred:
    x = 3.2
  return x

Originally posted by @maliasadi in #352 (comment)
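When a `while` loop is captured as a single operation, the loop-carried `x` needs one type for every iteration, so the initial `int` would have to be promoted to `float` before entering the loop. Below is a minimal sketch of that step, assuming a simplified `bool < int < float < complex` promotion order for Python scalars. JAX's actual promotion lattice is richer, and `promote_types` / `promote_loop_carry` are hypothetical helpers.

```python
from typing import Any, Type

# Simplified promotion order for Python scalars (an assumption for illustration).
_ORDER = [bool, int, float, complex]


def promote_types(a: Type, b: Type) -> Type:
    return _ORDER[max(_ORDER.index(a), _ORDER.index(b))]


def promote_loop_carry(initial: Any, body_value: Any) -> Any:
    """Cast the pre-loop value to the common type of init and body values."""
    target = promote_types(type(initial), type(body_value))
    return target(initial)


x0 = promote_loop_carry(1, 3.2)
print(type(x0).__name__, x0)  # float 1.0
```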

[Frontend] Compute the JVP of only specific argnums.

When using catalyst.jvp, you may sometimes want the JVP with respect to specific argnums only. catalyst.jvp supports an argnum argument; however, I cannot seem to get it to work:

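import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
import catalyst
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)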

@qjit
@qml.qnode(dev)
def circuit(params, n):
    
    def ansatz(i, x):
        qml.RX(x[i, 0], wires=0)
        qml.RY(x[i, 1], wires=1)
        qml.CNOT(wires=[0, 1])
        return x

    catalyst.for_loop(0, n, 1)(ansatz)(params)

    return qml.expval(qml.PauliZ(1))

@qjit
def jvp(primals, tangents):
    return catalyst.jvp(circuit, [primals, 2], [tangents, np.zeros([])], argnum=0)
>>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
>>> dy = jnp.ones((2, 2))
>>> jvp(params, dy)
CalledProcessError: Command '['/home/ec2-user/anaconda3/envs/Braket/lib/python3.10/site-packages/catalyst/bin/quantum-opt', '--lower-gradients', '/tmp/tmpp09x_7_z/jvp.nohlo.mlir', '-o', '/tmp/tmpp09x_7_z/jvp.nohlo.opt.mlir']' returned non-zero exit status 1.

When using jax.jvp, there is no argnum argument; as far as I can tell, you simply pass a zero tangent (with dtype float0 if the corresponding primal is an integer) to avoid computing the derivative with respect to that argument:

>>> jax.jvp(circuit.qfunc, (params, 2), (jnp.ones((2, 2)), np.zeros([], dtype=jax.float0)))
(Array(0.76127544, dtype=float64), Array(-0.73995881, dtype=float64))
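An argnum-restricted JVP is equivalent to a full JVP in which every other argument receives a zero tangent, which is what the `jax.jvp` call above does. Below is a minimal pure-Python sketch using central differences. `fd_jvp` and `jvp_argnum` are hypothetical helpers, and non-differentiated arguments (such as the integer `n`) are simply held fixed.

```python
import math
from typing import Callable, List, Sequence


def fd_jvp(f: Callable[..., float], primals: Sequence[float],
           tangents: Sequence[float], h: float = 1e-6) -> float:
    """Central-difference directional derivative of f along `tangents`."""
    plus = [p + h * t for p, t in zip(primals, tangents)]
    minus = [p - h * t for p, t in zip(primals, tangents)]
    return (f(*plus) - f(*minus)) / (2 * h)


def jvp_argnum(f: Callable[..., float], primals: Sequence[float],
               tangent: float, argnum: int) -> float:
    """Hypothetical argnum-style JVP: every other argument gets a zero tangent."""
    tangents: List[float] = [0.0] * len(primals)
    tangents[argnum] = tangent
    return fd_jvp(f, primals, tangents)


def g(a: float, b: float) -> float:
    return math.sin(a) * b


print(jvp_argnum(g, [0.3, 2.0], 1.0, argnum=0))  # close to 2 * cos(0.3)
```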

[Frontend] Refactor the process of adding a new device in the frontend

The frontend, and more precisely QFunc and QJITDevice in pennylane_extensions.py, requires refactoring to address the following issues:

  • The installation and support of every device shouldn't be dependent on PL a priori. Catalyst must have its own device checking mechanism to verify the installed backend devices. Currently, the frontend heavily relies on PL's device infrastructure, requiring every backend device to be installed via pip, otherwise, PL throws an error, complaining that the plugin/device isn't installed.
  • Device options should be communicated through Catalyst and to the runtime, without the frontend needing to know hardcoded details about the devices.
  • If the options must be cached/stored in the frontend, we could use a dataclass, which would also facilitate the serialization of device_kwargs in the backend:
from dataclasses import fields, dataclass

@dataclass
class LightningOptions:
    shots: int = 1000
    mcmc: bool = False
    backend: str = "kokkos"

    def __str__(self):
        t = lambda f: f"{f.name}={getattr(self, f.name)!r}"
        return ", ".join(t(f) for f in fields(self) if getattr(self, f.name) != f.default)
which allows:

>>> kwargs = LightningOptions(shots=100, mcmc=True)
>>> kwargs.shots
100
>>> print(kwargs)
shots=100, mcmc=True
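Building on the dataclass above, the sketch below shows how the same class could serialize its non-default options for the runtime and read them back. The JSON payload and the `to_device_kwargs` / `from_device_kwargs` method names are assumptions for illustration, not an existing Catalyst interface.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class LightningOptions:
    shots: int = 1000
    mcmc: bool = False
    backend: str = "kokkos"

    def to_device_kwargs(self) -> str:
        """Serialize only the options that differ from their defaults."""
        changed = {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) != f.default
        }
        return json.dumps(changed, sort_keys=True)

    @classmethod
    def from_device_kwargs(cls, payload: str) -> "LightningOptions":
        # Omitted keys fall back to the dataclass defaults.
        return cls(**json.loads(payload))


payload = LightningOptions(shots=100, mcmc=True).to_device_kwargs()
print(payload)  # {"mcmc": true, "shots": 100}
```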

[BUG] Passing observables as parameters triggers an exception

import pennylane as qml
import pytest
from catalyst import qjit


def test_observable_as_parameter(backend):
    """Test to see if we can pass an observable parameter to qfunc."""

    coeffs0 = [0.3, -5.1]
    H0 = qml.Hamiltonian(qml.math.array(coeffs0), [qml.PauliZ(0), qml.PauliY(1)])

    @qjit
    def circuit(obs):
        return qml.expval(obs)

    circuit(H0)

Taking H0 as an observable will cause line 254 in compilation_pipelines.py to fail.
I suspect this is related to pytrees.

237     @staticmethod
238     def get_runtime_signature(*args):
239         """Get signature from arguments.
240 
241         Args:
242             *args: arguments to the compiled function
243 
244         Returns:
245             a list of JAX shaped arrays
246         """
247         args_data, args_shape = tree_flatten(args)
248 
249         try:
250             r_sig = []
251             for arg in args_data:
252                 r_sig.append(jax.api_util.shaped_abstractify(arg))
253             # Unflatten JAX abstracted args to preserve the shape
254             return tree_unflatten(args_shape, r_sig)
255         except Exception as exc:
256             arg_type = type(arg)
257             raise TypeError(f"Unsupported argument type: {arg_type}") from exc

This is the exception that is triggered (before being caught immediately after in line 255):

TypeError: float() argument must be a string or a real number, not 'ShapedArray'

This is because the Hamiltonian unflatten function will attempt to build a Hamiltonian object with a ShapedArray.
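The failure mode can be reproduced without JAX or PennyLane: flattening a container, replacing its leaves with abstract placeholders, and unflattening it fails as soon as the constructor coerces those placeholders to `float`. Below is a minimal sketch; `ToyHamiltonian` and `AbstractValue` are hypothetical stand-ins for `qml.Hamiltonian` and `ShapedArray`.

```python
from typing import Any, List, Tuple


class ToyHamiltonian:
    """Hypothetical stand-in whose constructor coerces coefficients to float."""

    def __init__(self, coeffs: List[Any], ops: List[str]):
        self.coeffs = [float(c) for c in coeffs]
        self.ops = ops

    def flatten(self) -> Tuple[List[float], List[str]]:
        # Leaves are the coefficients; the operators are static auxiliary data.
        return self.coeffs, self.ops

    @classmethod
    def unflatten(cls, aux: List[str], leaves: List[Any]) -> "ToyHamiltonian":
        return cls(leaves, aux)  # runs __init__ again, including float()


class AbstractValue:
    """Stand-in for an abstract array: it has a shape and dtype, but no value."""

    def __init__(self, shape: Tuple[int, ...], dtype: str):
        self.shape, self.dtype = shape, dtype


def abstractify_roundtrip(obj: ToyHamiltonian) -> ToyHamiltonian:
    leaves, aux = obj.flatten()
    abstract = [AbstractValue((), "float64") for _ in leaves]
    return type(obj).unflatten(aux, abstract)


H0 = ToyHamiltonian([0.3, -5.1], ["PauliZ(0)", "PauliY(1)"])
try:
    abstractify_roundtrip(H0)
except TypeError as exc:
    print("TypeError:", exc)
```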

[BUG] Array slicing with strides causes gradients to crash

Consider

def f(x):
    return jnp.sum(x[::2])

This function returns correct results:

>>> x = jnp.array([0.1, 0.2, 0.3, 0.4])
>>> qjit(f)(x)
array(0.4)

But when I attempt to compute the gradient, I either get a compilation pass error (I think; I haven't been able to reproduce it without a crash lately) or the Python kernel crashes:

>>> qjit(grad(f))(x)
*crashes*

Interestingly, it also fails with jax.grad:

>>> qjit(jax.grad(f))(x)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-11-8ec72a4dd125> in <cell line: 1>()
----> 1 qjit(jax.grad(f))(jnp.array([0.1, 0.2, 0.3, 0.4]))

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in __call__(self, *args, **kwargs)
    645             return self.user_function(*args, **kwargs)
    646 
--> 647         function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
    648             self.compiled_function, *args
    649         )

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args)
    620             if not self.compiling_from_textual_ir:
    621                 self.mlir_module = self.get_mlir(*r_sig)
--> 622             function = self.compile()
    623         else:
    624             assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in compile(self)
    579             qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
    580 
--> 581             shared_object, llvm_ir, inferred_func_data = self.compiler.run(
    582                 self.mlir_module, pipelines=self.compile_options.pipelines
    583             )

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run(self, mlir_module, *args, **kwargs)
    399         """
    400 
--> 401         return self.run_from_ir(
    402             mlir_module.operation.get_asm(
    403                 binary=False, print_generic_op_form=False, assume_verified=True

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run_from_ir(self, ir, module_name, pipelines, lower_to_llvm)
    356             print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)
    357 
--> 358         compiler_output = run_compiler_driver(
    359             ir,
    360             workspace,

RuntimeError: Compilation failed:
'tensor.extract' op incorrect number of indices for extract_element

[BUG] ImportError of GLIBCXX in Catalyst frontend

Issue description

I am experiencing issues when importing Catalyst that seem to depend on where and how the library is imported. It might be specific to my installation (I had to hack around a bit to get it to pick up clang), but I thought I should report it because the failure seems erratic. I think it is related to something in the frontend.

  • Expected behavior:

Importing Catalyst from a script, the interpreter, or a Jupyter notebook should work without errors.

  • Actual behavior:

When I fire up an interpreter, I can do

>>> from catalyst import qjit

or

>>> import catalyst

without issue.

If I have a Python script, and begin it with

from catalyst import qjit

It works when this is the only line in the script, but in other contexts (such as my Shor implementation) I receive the following traceback:

Traceback (most recent call last):
  File "/home/olivia/Code/shortalyst-dev/full_jit_working.py", line 10, in <module>
    from catalyst import cond, measure, qjit, for_loop, while_loop
  File "/home/olivia/Code/catalyst/frontend/catalyst/__init__.py", line 64, in <module>
    from catalyst.compilation_pipelines import QJIT, CompileOptions, qjit
  File "/home/olivia/Code/catalyst/frontend/catalyst/compilation_pipelines.py", line 40, in <module>
    from catalyst.compiler import CompileOptions, Compiler
  File "/home/olivia/Code/catalyst/frontend/catalyst/compiler.py", line 30, in <module>
    from mlir_quantum.compiler_driver import run_compiler_driver
ImportError: /home/olivia/.conda/envs/catalyst/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/olivia/Code/catalyst/frontend/catalyst/../../mlir/build/python_packages/quantum/mlir_quantum/compiler_driver.so)

However, adding `import catalyst` on a line above this import solves the problem.

In a Jupyter Notebook, neither version works, and I get a similar but more detailed error:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 import catalyst
      3 from catalyst import qjit

File ~/Code/catalyst/frontend/catalyst/__init__.py:64
     58 sys.modules["mlir_quantum._mlir_libs._quantumDialects.quantum"] = types.ModuleType(
     59     "mlir_quantum._mlir_libs._quantumDialects.quantum"
     60 )
     63 from catalyst.ag_utils import AutoGraphError, autograph_source
---> 64 from catalyst.compilation_pipelines import QJIT, CompileOptions, qjit
     65 from catalyst.pennylane_extensions import (
     66     adjoint,
     67     cond,
   (...)
     74     while_loop,
     75 )
     76 from catalyst.utils.exceptions import CompileError

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:40
     38 import catalyst
     39 from catalyst.ag_utils import run_autograph
---> 40 from catalyst.compiler import CompileOptions, Compiler
     41 from catalyst.jax_tracer import trace_to_mlir
     42 from catalyst.pennylane_extensions import QFunc

File ~/Code/catalyst/frontend/catalyst/compiler.py:30
     27 from io import TextIOWrapper
     28 from typing import Any, List, Optional
---> 30 from mlir_quantum.compiler_driver import run_compiler_driver
     32 from catalyst._configuration import INSTALLED
     33 from catalyst.utils.exceptions import CompileError

ImportError: /home/olivia/.conda/envs/catalyst/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/olivia/Code/catalyst/frontend/catalyst/../../mlir/build/python_packages/quantum/mlir_quantum/compiler_driver.so)
  • Reproduces how often:

Erratic (see description above)

  • System information:

Running on Ubuntu 22.04.

Catalyst is installed from source off of main branch commit 422dc14.

Clang++ is version 14.0.0-1ubuntu1.1.

Output of qml.about() (packages installed from Catalyst's requirements.txt file):

Name: PennyLane
Version: 0.32.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/olivia/.conda/envs/catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: pennylane-catalyst, PennyLane-Lightning

Platform info:           Linux-6.2.0-33-generic-x86_64-with-glibc2.35
Python version:          3.11.0
Numpy version:           1.23.5
Scipy version:           1.10.0
Installed devices:
- default.gaussian (PennyLane-0.32.0)
- default.mixed (PennyLane-0.32.0)
- default.qubit (PennyLane-0.32.0)
- default.qubit.autograd (PennyLane-0.32.0)
- default.qubit.jax (PennyLane-0.32.0)
- default.qubit.tf (PennyLane-0.32.0)
- default.qubit.torch (PennyLane-0.32.0)
- default.qutrit (PennyLane-0.32.0)
- null.qubit (PennyLane-0.32.0)
- lightning.qubit (PennyLane-Lightning-0.32.0)

Source code and tracebacks

N/A

Additional information

Let me know what else would be useful!

[Frontend] JAX integration fails to correctly compute `argnums` in certain cases

Below are two bugs I've encountered in the JAX integration involving custom_jvp and the computation of argnums. I'm not sure whether they share a cause, so I have combined them for now.

Note: they may be addressed by #125.

Bug 1

Consider

dev = qml.device("lightning.qubit", wires=2)

@qjit
@qml.qnode(dev)
def circuit(x, y):
    qml.RX(x, wires=0)
    qml.RY(y, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

x = jnp.array(0.5)
y = jnp.array(0.1)

We see that some combinations of argnums do not work:

>>> jax.grad(circuit, argnums=0)(x, y)  # works
Array(-0.47703045, dtype=float64, weak_type=True)
>>> jax.grad(circuit, argnums=1)(x, y)  # fails
File ~/anaconda3/envs/Braket/lib/python3.10/site-packages/catalyst/compilation_pipelines.py:654, in JAX_QJIT.compute_jvp(self, primals, tangents)
    652         deriv_idx = arg_idx * len(results) + res_idx
    653         num_axes = 0 if tangent.ndim == 0 else 1
--> 654         jvp = jnp.tensordot(jnp.transpose(derivatives[deriv_idx]), tangent, axes=num_axes)
    655         jvps[res_idx] = jvps[res_idx] + jvp
    657 if len(results) == 1:

IndexError: list index out of range

Using jax.jit, it always works:

>>> jax.grad(jax.jit(circuit), argnums=0)(x, y)
Array(-0.47703045, dtype=float64, weak_type=True)
>>> jax.grad(jax.jit(circuit), argnums=1)(x, y)
Array(-0.08761211, dtype=float64, weak_type=True)

Bug 2

Consider the following circuit, which has two arguments: a differentiable float array and a non-differentiable integer:

dev = qml.device("lightning.qubit", wires=2)

@qjit
@qml.qnode(dev)
def circuit(params, n):
    
    def ansatz(i, x):
        qml.RX(x[i, 0], wires=0)
        qml.RY(x[i, 1], wires=1)
        qml.CNOT(wires=[0, 1])
        return x

    catalyst.for_loop(0, n, 1)(ansatz)(params.reshape(-1, 2))

    return qml.expval(qml.PauliZ(1))

n = 3
params = jnp.ones(2 * n)

Since the second argument is not differentiable, we must pass argnums=0 when computing the grad:

>>> jax.grad(circuit, argnums=0)(params, n)
Array([-0.3262047 ,  0.42661714,  0.16441727, -0.45305815,  0.49515634,
        0.95400696], dtype=float64)

However, if circuit is also jitted with JAX, this no longer works. Internally, custom_jvp determines argnums=[0, 1], which is incorrect:

>>> jax.grad(jax.jit(circuit), argnums=0)(params, n)
TypeError: Catalyst.grad only supports differentiation on floating-point arguments, got 'int64' at position 1.

This causes issues when using Catalyst with jaxopt, since jaxopt automatically JITs any function it optimizes under the hood.

The `braket.aws.qubit` device gives an error if you provide an S3 bucket prefix

For example, consider the following code:

s3 = ("bucket", "prefix")
arn = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
dev = qml.device("braket.aws.qubit", device_arn=arn, s3_destination_folder=s3, wires=2)

@qjit
def workflow(x: float, y: float):

    @qml.qnode(dev)
    def circuit(x, y):
        qml.RX(y * x, wires=0)
        qml.RX(x * 2, wires=1)
        return qml.expval(qml.PauliY(0) @ qml.PauliZ(1))

    return catalyst.grad(circuit)(x, y)

Running this gives an obscure error:

>>> workflow(0.1, 0.2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[30], line 1
----> 1 workflow(0.1, 0.2)

File ~/anaconda3/envs/Braket/lib/python3.10/site-packages/catalyst/compilation_pipelines.py:571, in QJIT.__call__(self, *args, **kwargs)
    567         self.jaxed_qfunc = JAX_QJIT(self)
    569     return self.jaxed_qfunc(*args, **kwargs)
--> 571 return self.compiled_function(*args, **kwargs)

File ~/anaconda3/envs/Braket/lib/python3.10/site-packages/catalyst/compilation_pipelines.py:392, in CompiledFunction.__call__(self, *args, **kwargs)
    388 abi_args, _buffer = CompiledFunction.args_to_memref_descs(self.restype, args)
    390 numpy_dict = {nparr.ctypes.data: nparr for nparr in _buffer}
--> 392 result = CompiledFunction._exec(
    393     self.shared_object_file,
    394     self.func_name,
    395     self.restype,
    396     numpy_dict,
    397     *abi_args,
    398 )
    400 return result

File ~/anaconda3/envs/Braket/lib/python3.10/site-packages/catalyst/compilation_pipelines.py:229, in CompiledFunction._exec(shared_object_file, func_name, has_return, numpy_dict, *args)
    226 setup(ctypes.c_int(argc), array_of_char_ptrs)
    227 result_desc = type(args[0].contents) if has_return else None
--> 229 retval = wrapper.wrap(function, args, result_desc, mem_transfer, numpy_dict)
    230 if len(retval) == 0:
    231     retval = None

RuntimeError: [/__w/catalyst/catalyst/runtime/lib/backend/openqasm/OpenQasmRunner.hpp][Line:356][Function:Expval] Error in Catalyst Runtime: s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively.

However, if the bucket prefix is removed, everything works fine:

arn = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
dev = qml.device("braket.aws.qubit", device_arn=arn, wires=2)

@qjit
def workflow(x: float, y: float):

    @qml.qnode(dev)
    def circuit(x, y):
        qml.RX(y * x, wires=0)
        qml.RX(x * 2, wires=1)
        return qml.expval(qml.PauliY(0) @ qml.PauliZ(1))

    return catalyst.grad(circuit)(x, y)
>>> workflow(0.1, 0.2)
array(-0.18802786)

[MLIR, CompileDriver] Run entry point signature detection only if needed

Currently the compiler driver always attempts to detect the entry-point signature (see getJITFunction) and prints messages if it can't find one, which is only relevant for textual IR input. Usually we compile Python functions whose signatures we already know by design. The suggestions are:

  • Change the compiler driver API so we can explicitly ask it to run the inference, and only run it if asked
  • Maybe find a better name for this function (JIT used to be the prefix of our entry points, but we changed this)
  • Maybe also unify the verbose printing mechanisms, namely MLIR's emitDiag/emitRemark and our own CO_MSG macros. Note that we might want to keep CO_MSG because the emit* functions alone get overly verbose when stack traces are enabled. Ref #317

[Frontend] Support scalars of type complex?

The following simple program fails when the dtype is jnp.complex128 but works with float64. Should we support complex scalars as well?

from catalyst import qjit
from jax.numpy import array, complex128, float64

@qjit
def main():
    return array(0, dtype=complex128) # float64 works

main()

The error is

Traceback (most recent call last):
  File "/workspace/src/synthesis/issue1.py", line 10, in <module>
    main()
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 598, in __call__
    return self.compiled_function(*args, **kwargs)
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 444, in __call__
    result = CompiledFunction._exec(
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 332, in _exec
    retval = CompiledFunction.return_value_ptr_to_numpy(result) if result else None
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 297, in return_value_ptr_to_numpy
    jax_array = jax.numpy.asarray(numpy_array)
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2036, in asarray
    return array(a, dtype=dtype, copy=False, order=order)  # type: ignore
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1982, in array
    dtype = dtypes._lattice_result_type(*leaves)[0]
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 469, in _lattice_result_type
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 469, in <genexpr>
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 312, in _dtype_and_weaktype
    return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 464, in dtype
    raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
TypeError: Value '(0., 0.)' with dtype [('real', '<f8'), ('imag', '<f8')] is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
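The last line suggests the scalar reaches NumPy as a structured record with separate `real` and `imag` float64 fields rather than as a complex128 value. A minimal sketch of turning such a record into a Python complex, using a ctypes struct with that layout; the `ComplexF64` struct and `to_python_complex` helper are hypothetical and not Catalyst's return-value handling:

```python
import ctypes


class ComplexF64(ctypes.Structure):
    """Two consecutive C doubles: the (real, imag) layout shown in the error."""

    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


def to_python_complex(value: ComplexF64) -> complex:
    # Combine the two fields into one Python complex scalar.
    return complex(value.real, value.imag)


print(to_python_complex(ComplexF64(0.0, 0.0)))   # 0j
print(to_python_complex(ComplexF64(1.5, -2.0)))  # (1.5-2j)
```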

[JAX] Missing support for sub-array updates with `array.at[]`

At the moment there seems to be no way to lower the scatter operation from the MHLO dialect into the standard MLIR dialects we use. JAX generates this op for its "in-place" array updates, as in the following example:

from jax import numpy as jnp
from catalyst import qjit

@qjit
def test(idx: int):
    res = jnp.array([0, 0, 0])
    res = res.at[idx].add(1)
    return res

test(1)

The example fails with the following error message (raised when quantum-opt tries to parse the output from the mlir-hlo-opt tool):

/tmp/tmps5we6mfm/test.nohlo.mlir:33:24: error: operation being parsed with an unregistered dialect. If this is intended, please use -allow-unregistered-dialect with the MLIR tool used
    %8 = "mhlo.scatter"(%cst_1, %expanded, %cst) ({

This is not necessarily an issue within Catalyst, and may need to be resolved upstream in the MLIR-HLO project.

Note: The tensor dialect seems to have a scatter op.


A workaround for the above example is as follows:

from jax import numpy as jnp
from catalyst import qjit

@qjit
def test(idx: int):
    res = jnp.array([0, 0, 0])
    additions = jnp.identity(3)
    res = res + additions[idx]
    return res

test(1)

Compound expressions as conditions in `catalyst.cond` are not supported

Catalyst doesn't support compound expressions as conditions. Running the following example will raise a jax.errors.ConcretizationTypeError.

import pennylane as qml
from catalyst import cond, measure, qjit

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x):
    @cond(x > 4 or True)
    def cond_fn():
        qml.PauliX(wires=0)

    cond_fn()

    return measure(wires=0)

circuit(2) 
E       jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
E       The problem arose with the `bool` function. 
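The error comes from Python's `or`, which must call `bool()` on its left operand to decide which operand to return; a traced value has no concrete truth value. The bitwise `|` operator is overloaded by JAX arrays and stays symbolic, so writing the condition as `(x > 4) | True` may work as a workaround (untested here; the parentheses matter because `|` binds more tightly than `>`). A stdlib-only sketch using a hypothetical `FakeTracer` class to mimic that behavior:

```python
class FakeTracer:
    """Stand-in for a traced boolean: supports `|` but has no truth value."""

    def __init__(self, name):
        self.name = name

    def __bool__(self):
        # JAX raises ConcretizationTypeError here; TypeError mimics that.
        raise TypeError(f"cannot convert traced value {self.name} to bool")

    def __or__(self, other):
        # Bitwise OR builds a new symbolic value; no bool() call is needed.
        return FakeTracer(f"({self.name} | {other})")


pred = FakeTracer("(x > 4)")

or_failed = False
try:
    pred or True  # `or` calls bool(pred) to pick which operand to return
except TypeError as err:
    or_failed = True
    print("`or` failed:", err)

print((pred | True).name)  # ((x > 4) | True)
```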

[Frontend] Lack of device-specific decomposition logic

We currently treat all backend devices as if they support an identical set of quantum gates. This assumption causes decomposition rules to fail for the OpenQasm3/Braket device. To tackle this issue properly, we may need to implement device-specific decomposition logic in the frontend for every supported backend device.

Check the test_unsupported_gate_braket test unit in test_aws_braket_devices.py for an example.

[MLIR] Gradient transform only considers `CustomOp` for differentiation

Other parametrized gate operations like the MultiRZ op are currently simply ignored in the gradient computation.

Relevant place in the code:

// Insert gate parameters into the params buffer.
argMapFn.walk([&](quantum::CustomOp gate) {
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPoint(gate);
    if (!gate.getParams().empty()) {

[Frontend] Optimize JVP implementation with symbolic zero support

#96 introduced support for JAX transformations such as grad, jacobian, or jvp.

This works by implementing a custom_jvp for our QJIT functions. To avoid computing Jacobians for function arguments that do not participate in differentiation, the current implementation checks whether the tangent vector provided for an argument is a concrete value and filled with all zeros.

Recently, JAX introduced symbolic zeros that can be passed as tangent vectors instead of materialized concrete values, which should be more efficient.

Update the implementation of custom_jvp to use symbolic zeros once we upgrade our dependency chain to JAX >= 0.4.9.
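The difference can be sketched without JAX: with a symbolic marker, deciding to skip an argument is a type check, so neither the zero tangent nor that argument's Jacobian is ever materialized. `SymbolicZero` below is a hypothetical stand-in, not JAX's own class, and the helper assumes a scalar output with flattened per-argument gradients:

```python
class SymbolicZero:
    """Hypothetical marker for a tangent known to be zero without storing values."""


def jvp_scalar_output(jacobian_fns, tangents):
    """Compute sum_i <grad_i, t_i>, skipping arguments with a symbolic-zero tangent.

    jacobian_fns[i]() returns the flattened gradient of the scalar output with
    respect to argument i; it is only called when argument i contributes.
    """
    total = 0.0
    for jac_fn, tangent in zip(jacobian_fns, tangents):
        if isinstance(tangent, SymbolicZero):
            continue  # no Jacobian work and no inspection of tangent values
        grad = jac_fn()
        total += sum(g * t for g, t in zip(grad, tangent))
    return total


calls = []


def jac0():
    calls.append(0)
    return [1.0, 2.0]


def jac1():
    calls.append(1)
    return [3.0]


out = jvp_scalar_output([jac0, jac1], [[0.5, 0.25], SymbolicZero()])
print(out, calls)  # 1.0 [0]
```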

[MLIR] Dialect `complex' not found for custom op 'complex.constant'

Consider the following program running Catalyst 0.1.0 and PennyLane 0.28:

import pennylane as qml
import jax.numpy as jnp
from catalyst import qjit

@qjit
def workflow(params):
    @qml.qnode(qml.device("lightning.qubit", wires=1),)
    def circuit(phi):
        qml.PhaseShift(phi=phi, wires=[0])
        return qml.state()

    phi = params[0]
    for _ in range(10):
        state = circuit(phi)[0]
        phi = jnp.mean(state).real
        # phi = ((state[0] + state[1]) / 2).real  # In contrast, this approach works
    return state

print(workflow(jnp.array([jnp.pi], dtype=jnp.float64)))

An attempt to run the program results in the following compilation error. Other expressions involving complex/float computations seem to cause similar errors.

/tmp/tmpkjcy955a/workflow.nohlo.mlir:6:12: error: Dialect `complex' not found for custom op 'complex.constant'
    %cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
           ^
/tmp/tmpkjcy955a/workflow.nohlo.mlir:6:12: note: Registered dialects: arith, builtin, func, gradient, index, linalg, llvm, memref, quantum, scf, tensor ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management
Traceback (most recent call last):
  File "/home/sergei/complex_dialiect_issue.py", line 18, in <module>
    print(workflow(jnp.array([jnp.pi], dtype=jnp.float64)))
  File "/home/sergei/.local/lib/python3.8/site-packages/catalyst/compilation_pipelines.py", line 611, in __call__
    self.compiled_function = self.compile()
  File "/home/sergei/.local/lib/python3.8/site-packages/catalyst/compilation_pipelines.py", line 585, in compile
    shared_object, self._llvmir = compiler.compile(
  File "/home/sergei/.local/lib/python3.8/site-packages/catalyst/compiler.py", line 318, in compile
    buff = bufferize_tensors(mlir)
  File "/home/sergei/.local/lib/python3.8/site-packages/catalyst/compiler.py", line 195, in bufferize_tensors
    subprocess.run(command, stdout=file, check=True)
  File "/usr/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/home/sergei/.local/lib/python3.8/site-packages/catalyst/bin/quantum-opt', '/tmp/tmpkjcy955a/workflow.nohlo.mlir', '--lower-gradients', '--gradient-bufferize', '--scf-bufferize', '--convert-tensor-to-linalg', '--convert-elementwise-to-linalg', '--arith-bufferize', '--empty-tensor-to-alloc-tensor', '--bufferization-bufferize', '--tensor-bufferize', '--linalg-bufferize', '--tensor-bufferize', '--quantum-bufferize', '--func-bufferize', '--finalizing-bufferize', '--buffer-hoisting', '--convert-bufferization-to-memref', '--canonicalize', '--cse']' returned non-zero exit status 1.

[Tests] Test all qubit templates

All templates that work with qubit devices should be tested. At the moment most templates are tested, but a few are still missing:

  • HilbertSchmidt
  • LocalHilbertSchmidt
  • ArbitraryUnitary
  • ParticleConservingU1
  • ParticleConservingU2
  • Broadcast (Double odd)
  • Broadcast (Custom)

The tests should be placed in the test_template.py file along with the other templates.

[Frontend] Catalyst for loop does not support negative step size

Because Catalyst control flow maps directly to MLIR control flow from the SCF dialect, loop configurations such as the following are not supported:

from catalyst import for_loop, qjit

@qjit
def revc():
    @for_loop(9, -1, -1)
    def loop(i, agg):
        return agg + 1
    return loop(0)

def revi():
    agg = 0
    for y in range(9, -1, -1):
        agg += 1
    return agg

assert revc.mlir
assert revc() == revi()

See the documentation in MLIR for more details of the structured control flow operations: https://mlir.llvm.org/docs/Dialects/SCFDialect/#scffor-mlirscfforop.

However, although negative bounds and/or step sizes are not directly supported in MLIR, I think we could support them in the frontend by transforming the loop index during lowering.
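The index transformation can be sketched in plain Python: compute the trip count n, iterate k from 0 to n with a positive unit step (the form scf.for accepts), and recover the user's index as i = start + k * step. This is a hedged illustration of the idea, not Catalyst's lowering; `trip_count` and `normalized_loop` are hypothetical names:

```python
def trip_count(start, stop, step):
    """len(range(start, stop, step)) computed arithmetically, for any nonzero step."""
    if step == 0:
        raise ValueError("step must be nonzero")
    # Ceiling division of (stop - start) by step, clamped at zero.
    return max(0, -((start - stop) // step))


def normalized_loop(start, stop, step, body, init):
    """Run body(i, acc) over range(start, stop, step) using only an upward
    loop k = 0, 1, ..., n - 1 with step +1."""
    acc = init
    for k in range(0, trip_count(start, stop, step), 1):
        i = start + k * step  # recover the original, possibly decreasing, index
        acc = body(i, acc)
    return acc


# The revc/revi example from the issue: ten iterations counting down from 9.
print(normalized_loop(9, -1, -1, lambda i, agg: agg + 1, 0))     # 10
print(normalized_loop(9, -1, -3, lambda i, acc: acc + [i], []))  # [9, 6, 3, 0]
```

Because the transformed loop always has a zero lower bound and a step of +1, only the trip count and the index mapping depend on the sign of the original step.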

Gradients and calls to QNodes in a JITed context

Issue description

  • Expected behavior:
from catalyst import qjit, grad
import pennylane as qml
import numpy as np

@qml.qnode(qml.device("lightning.qubit", wires=1))
def id1(x):
    return qml.expval(qml.PauliZ(0))

@qjit()
def workflow(a):
    x = id1(a)
    g = grad(id1, method="adj")
    return g(a)


print(workflow(1.0))

should be the same as:

from catalyst import qjit, grad
import pennylane as qml
import numpy as np

@qml.qnode(qml.device("lightning.qubit", wires=1))
def id1(x):
    return qml.expval(qml.PauliZ(0))

@qjit()
def workflow(a):
    # x = id1(a) # <--- REMOVE THIS LINE
    g = grad(id1, method="adj")
    return g(a)


print(workflow(1.0))
  • Actual behavior:
RuntimeError: [/home/ali/git/xanadu/catalyst/runtime/lib/backend/LightningSimulator.cpp][Line:67][Function:StopTapeRecording] Error in Catalyst Runtime: Cannot stop an already stopped cache manager
  • System information: main branch

I originally suspected something related to deallocOp but I might be wrong.

[sc-38691]

[linux x86-64] Measuring an invalid wire does not raise an error

Issue description

Test case:

from catalyst import qjit, measure
import pennylane as qml

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x: int):
    return measure(x)
circuit(0)
circuit(1)
circuit(2)
  • Expected behavior: An assertion should trigger for circuit(2)

  • Actual behavior: No assertion is triggered.

  • Reproduces how often: Always (on x86 - linux)
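The expected behavior amounts to a bounds check on the wire index before measuring. Because `x` is a traced argument, such a check would presumably have to run at execution time (for example in the runtime) rather than while tracing; that placement is an assumption. A minimal Python sketch of the check itself, with `validate_wire` as a hypothetical helper:

```python
def validate_wire(wire: int, num_wires: int) -> int:
    """Hypothetical bounds check for a wire index on a device with num_wires qubits."""
    if not 0 <= wire < num_wires:
        raise ValueError(f"Invalid wire {wire}: device has wires 0..{num_wires - 1}")
    return wire


# Mirrors the test case: a two-wire device measured on wires 0, 1 and 2.
results = {}
for w in (0, 1, 2):
    try:
        results[w] = validate_wire(w, num_wires=2)
    except ValueError as err:
        results[w] = str(err)

print(results)  # wire 2 maps to an error message instead of a measurement
```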

[Frontend] JAX integration fails if the input parameters are not a 1D array

Consider the following code:

dev = qml.device("lightning.qubit", wires=2)

@qjit
@qml.qnode(dev)
def circuit(params, n):
    
    def ansatz(i, x):
        qml.RX(x[i, 0], wires=0)
        qml.RY(x[i, 1], wires=1)
        qml.CNOT(wires=[0, 1])
        return x

    catalyst.for_loop(0, n, 1)(ansatz)(params)

    return qml.expval(qml.PauliZ(1))

This fails, as the JAX integration does not appear to work with input parameters of dimension greater than 1:

>>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
>>> jax.grad(circuit, argnums=0)(params, 2)
TypeError: Custom JVP rule must produce primal and tangent outputs with equal shapes and dtypes, but got float64[] and float64[2,2] respectively.

Note that this seems to work well with the recently added catalyst.jvp transform:

>>> f = lambda p: circuit(p, 2)  # curry the function as I can't get argnum working
>>> @qjit
... def jvp(primals, tangents):
...     return catalyst.jvp(f, [primals], [tangents])
>>> jvp(params, jnp.ones((2, 2)))
[array(0.76127544), array(-0.73995891)]

Comparing to JAX to verify:

>>> jax.jvp(circuit.qfunc, (params, 2), (jnp.ones((2, 2)), np.zeros([], dtype=jax.float0)))
(Array(0.76127544, dtype=float64), Array(-0.73995881, dtype=float64))

Since switching the custom_jvp implementation to use catalyst.jvp is also recommended for other reasons, this is likely the best fix.
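The shape mismatch is consistent with the contraction in the compute_jvp code from the Bug 1 traceback in the argnums issue, which uses `axes=0 if tangent.ndim == 0 else 1`. Assuming that, for a scalar-output QNode, the derivative has the same shape as the input, the JVP must contract over all `tangent.ndim` axes; contracting only one axis of a (2, 2) derivative with a (2, 2) tangent leaves a (2, 2) result, matching the error. A NumPy sketch (NumPy's `tensordot` follows the same `axes` convention as `jnp.tensordot`; the gradient values are made up, and the transpose used in the original code is left out):

```python
import numpy as np

grad = np.array([[0.1, -0.2], [0.3, 0.4]])  # made-up d(output)/d(params), shape of params
tangent = np.ones((2, 2))                   # tangent for a (2, 2) parameter array

one_axis = np.tensordot(grad, tangent, axes=1)
print(one_axis.shape)  # (2, 2): not a valid tangent for a scalar primal

all_axes = np.tensordot(grad, tangent, axes=tangent.ndim)
print(all_axes.shape)  # (): matches the scalar primal output
print(np.isclose(all_axes, np.sum(grad * tangent)))  # True
```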

Infinite loops terminate if no quantum gates are used

Consider the following program containing an infinite loop.

import pennylane as qml
from catalyst import qjit, while_loop

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def main():
    @while_loop(lambda i:True)
    def whileloop1(i):
        # qml.Hadamard(0)
        return i
    whileloop1(0)
    return qml.state()

main()

The expected behavior is for the program to hang in the loop, but it actually terminates immediately and returns Array([1.+0.j, 0.+0.j], dtype=complex128).
Note that uncommenting the gate instruction restores the expected behavior.
The current Catalyst git hash is 9dfbb32.

Some assertions should be turned into exceptions

As a general rule of thumb:

  • we typically use assertions within the codebase to check runtime consistency, i.e. to ensure the logic is not in an unexpected state. These raise AssertionError exceptions and are disabled when Python runs with the -O flag.

  • we typically use exception handling for input validation, that is, to raise specific errors for input that a user may get wrong and alert them to their mistake. By providing Catalyst-specific exception subclasses, we give users good control over error handling in Catalyst.

I noticed that within Catalyst (especially the pennylane extensions), we are using assertions where we should be using exceptions:

assert method in {"fd", "ps", "adj"}, "invalid differentiation method"

In this particular case, we should be raising ValueError (or if applicable, custom subclasses).

In addition, the error messages should tell the user how to fix the error where possible. In this example, the error message should list the allowed gradient methods.
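A minimal sketch of the suggested pattern, with `check_grad_method` and `VALID_GRAD_METHODS` as hypothetical names: validation raises ValueError, which is not stripped by `python -O`, and the message lists the accepted methods.

```python
VALID_GRAD_METHODS = ("fd", "ps", "adj")


def check_grad_method(method: str) -> str:
    """Validate a user-supplied differentiation method (hypothetical helper).

    Unlike `assert`, this check still runs under `python -O`, and the
    message tells the user which values are accepted.
    """
    if method not in VALID_GRAD_METHODS:
        raise ValueError(
            f"Invalid differentiation method {method!r}. "
            f"Supported methods are: {', '.join(VALID_GRAD_METHODS)}."
        )
    return method


message = ""
try:
    check_grad_method("adjoint")
except ValueError as err:
    message = str(err)

print(message)  # Invalid differentiation method 'adjoint'. Supported methods are: fd, ps, adj.
```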

[Frontend] Review the topological sorting algorithm used in the two-staged tracing

This issue is a reminder for a future improvement. The nested tape tracer introduced by #240 relies on a topological sort that might need reviewing. The sort is the final step in merging two lists of JAX equations: one containing the classical equations and the other induced by the quantum tape. Currently, the sort builds its dependencies from JAX variables, which may lead to a correct but poorly ordered flow of equations.

Some questions worth answering:

  • Could/should we preserve the original flow of equations as close as possible?
  • Could we simplify the sorting algorithm by removing Box used for dependency tracking?
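For the first question, one option is Kahn's algorithm with a min-heap keyed on each equation's original position: among all equations whose inputs are available, the earliest one is emitted first, so the output only departs from the input order where a dependency forces it. This sketch works on plain `(name, inputs, outputs)` tuples rather than JAX equations and uses no `Box`; it is an illustration under those assumptions, not the tracer's implementation:

```python
import heapq


def stable_toposort(eqns):
    """Order equations so each variable is defined before use, breaking ties
    by original position (Kahn's algorithm with a min-heap).

    Each eqn is (name, inputs, outputs); variables not produced by any eqn
    are treated as external inputs.
    """
    producer = {}
    for idx, (_, _, outs) in enumerate(eqns):
        for var in outs:
            producer[var] = idx
    deps = [set() for _ in eqns]
    users = [[] for _ in eqns]
    for idx, (_, ins, _) in enumerate(eqns):
        for var in ins:
            src = producer.get(var)
            if src is not None and src != idx and src not in deps[idx]:
                deps[idx].add(src)
                users[src].append(idx)
    indegree = [len(d) for d in deps]
    ready = [i for i, d in enumerate(indegree) if d == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        idx = heapq.heappop(ready)  # earliest available equation first
        order.append(eqns[idx][0])
        for user in users[idx]:
            indegree[user] -= 1
            if indegree[user] == 0:
                heapq.heappush(ready, user)
    if len(order) != len(eqns):
        raise ValueError("cycle detected among equations")
    return order


# Classical equations listed first, quantum ones after; c2 uses the measurement m.
eqns = [
    ("c1", ["x"], ["a"]),
    ("c2", ["m"], ["d"]),
    ("c3", ["a"], ["e"]),
    ("q1", ["a"], ["m"]),
]
print(stable_toposort(eqns))  # ['c1', 'c3', 'q1', 'c2']
```

In this example only c2 moves, because it consumes the measurement result produced by q1.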

[MLIR || Runtime] Hamiltonian Gradient.

While converting the QAOA demo to be compatible with Catalyst, I encountered the following behaviour:

Compiling deriv with method="fd" returns results:

import catalyst
from catalyst import qjit

import pennylane as qml
from pennylane import numpy as np

import jax.numpy as jnp
import jax

from pennylane import qaoa
import networkx as nx

edges = [(0, 1), (1, 2), (2, 0), (2, 3)]
graph = nx.Graph(edges)

cost_h, mixer_h = qaoa.min_vertex_cover(graph, constrained=False)

def qaoa_layer(gamma, alpha):
    qaoa.cost_layer(gamma, cost_h)
    qaoa.mixer_layer(alpha, mixer_h)

wires = range(4)
depth = 2

def circuit(params, **kwargs):
    for w in wires:
        qml.Hadamard(wires=w)
    qml.layer(qaoa_layer, depth, params[0], params[1])

dev = qml.device("lightning.qubit", wires=wires)

@qml.qnode(dev)
def cost_function(params):
    circuit(params)
    return qml.expval(qml.Hamiltonian(np.array(cost_h.coeffs), cost_h.ops))

@qjit(keep_intermediate=True)
def deriv(theta: jax.core.ShapedArray([2, 2], float)):
    diff = catalyst.grad(cost_function, method="fd") # <--- Line of interest
    h = diff(theta)
    return h

params = jnp.array([[0.5, 0.5], [0.5, 0.5]])
print(deriv(params))

However, compiling with method="adj" produces the following runtime error:

RuntimeError: [/home/erick.ochoalopez/catalyst/runtime/lib/backend/LightningSimulator.cpp][Line:510][Function:Gradient] Error in Catalyst Runtime: Assertion: num_train_params <= gradients[obs_idx].size()

Not sure if this is something related to the Hamiltonian, the lowering of gradients or something else yet.

Getting different values with @qjit


Issue description

We are using the following Colab notebook: https://colab.research.google.com/drive/1PU-xnvdxVqYE8DKMyJdREyYIdrdVpZvl#scrollTo=OHDvAThytfjI

It seems that the "patata" function in this notebook gives different results depending on whether or not it is decorated with @qjit. The correct solution should contain only 0s or 3s, but with qjit we get other values.

[Frontend] Catalyst control flow is unsupported with `qml.ctrl` and `qml.adjoint`

Passing functions to the qml.ctrl and qml.adjoint meta-operations which contain Catalyst control flow primitives is currently unsupported. The following example demonstrates the issue for the for_loop op:

import pennylane as qml
from catalyst import qjit, for_loop

def inner():
    @for_loop(0, 10, 1)
    def loop(j):
        qml.X(0)
    loop()

@qml.qnode(device=qml.device("lightning.qubit", wires=2, shots=1))
def outer():
    qml.ctrl(inner, control=0)()
    return qml.sample()

print(qjit(outer)())

Gradients with Hamiltonians fail if the coefficients are not vanilla NumPy float arrays

Consider the following circuit:

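import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=2)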

coeffs = np.array([0.1, 0.2])
terms = [qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0)]
H = qml.Hamiltonian(coeffs, terms)

@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x[0], wires=0)
    qml.RY(x[1], wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(H)

params = jnp.array([0.3, 0.4])

The gradient works fine as long as coeffs is a vanilla NumPy float array:

>>> jax.grad(circuit)(params)
Array([-0.05443844, -0.07440512], dtype=float64)
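As an independent cross-check (not part of the original report), the expected gradient can be reproduced with a small NumPy state-vector simulation of the same circuit. After the CNOT the X(0) Z(1) term has zero expectation, so the expectation value reduces to 0.2 cos(x0) cos(x1). The helper names rx, ry, expval and central_difference_grad are illustrative only. Both the finite-difference and analytic results agree with the gradient printed above to within roughly 1e-7, which gives a reference for checking the failing cases once they are fixed.

```python
import numpy as np


# Gate matrices using PennyLane's conventions for RX and RY.
def rx(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])


def ry(theta):
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


I2 = np.eye(2)
X = np.array([[0, 1], [1, 0]])
Z = np.array([[1, 0], [0, -1]])
# CNOT with control wire 0 and target wire 1 (wire 0 is the leftmost tensor factor).
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])

# Same Hamiltonian as in the issue: 0.1 * X(0) Z(1) + 0.2 * Z(0).
H = 0.1 * np.kron(X, Z) + 0.2 * np.kron(Z, I2)


def expval(x):
    state = np.zeros(4, dtype=complex)
    state[0] = 1.0  # start in |00>
    state = np.kron(rx(x[0]), I2) @ state  # RX(x[0]) on wire 0
    state = np.kron(ry(x[1]), I2) @ state  # RY(x[1]) on wire 0
    state = CNOT @ state
    return float(np.real(np.conj(state) @ H @ state))


def central_difference_grad(f, x, h=1e-6):
    # Central finite differences, one parameter at a time.
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        grad[i] = (f(x + step) - f(x - step)) / (2 * h)
    return grad


params = np.array([0.3, 0.4])
numeric_grad = central_difference_grad(expval, params)

# Analytic gradient of <H> = 0.2 cos(x0) cos(x1).
analytic_grad = np.array([
    -0.2 * np.sin(params[0]) * np.cos(params[1]),
    -0.2 * np.cos(params[0]) * np.sin(params[1]),
])
print(numeric_grad, analytic_grad)
```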

However, if coeffs is a PennyLane NumPy or JAX NumPy array, it fails:

>>> coeffs = jnp.array([0.1, 0.2])
...
>>> jax.grad(circuit)(params)
TypeError: get(): incompatible function arguments. The following argument types are supported:
    1. (array: buffer, signless: bool = True, type: Optional[jaxlib.mlir._mlir_libs._mlir.ir.Type] = None, shape: Optional[List[int]] = None, context: mlir.ir.Context = None) -> jaxlib.mlir._mlir_libs._mlir.ir.DenseElementsAttr

Invoked with: Array([0.1, 0.2], dtype=float64)

Similarly, execution will fail if coeffs is an integer array.
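Since the report indicates that vanilla NumPy float arrays work, one possible workaround until the underlying bug is fixed is to convert coeffs to a plain float64 NumPy array before building the Hamiltonian. This is a sketch based on that observation, not a verified fix, and as_float64_coeffs is a hypothetical helper name.

```python
import numpy as np


def as_float64_coeffs(coeffs):
    """Hypothetical helper: coerce Hamiltonian coefficients to a vanilla NumPy float64 array.

    np.asarray accepts Python lists, integer arrays and other array-likes that
    implement the NumPy array protocol (concrete JAX arrays do). Unlike
    np.asanyarray, it returns a base numpy.ndarray rather than a subclass.
    """
    return np.asarray(coeffs, dtype=np.float64)


int_coeffs = as_float64_coeffs(np.array([1, 2]))  # integer input
list_coeffs = as_float64_coeffs([0.1, 0.2])       # plain list input
print(type(int_coeffs).__name__, int_coeffs.dtype)

# Hypothetical usage with the setup from the report:
# H = qml.Hamiltonian(as_float64_coeffs(coeffs), terms)
```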

[BUG] Error when taking `adjoint` of subroutines with `wires` arguments

Issue description

  • Expected behavior: A list of wires can be passed to a subroutine, and the subroutine can be executed both in its "regular" form and as its adjoint under qjit.

  • Actual behavior: Errors are thrown in the adjoint call.

  • Reproduces how often: Always

  • System information (output of import pennylane as qml; qml.about()):

I am on this branch/commit but was experiencing this previously on the main branch as well.

Name: PennyLane
Version: 0.32.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/olivia/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: pennylane-catalyst, PennyLane-Lightning, PennyLane-qiskit

Platform info:           Linux-6.2.0-35-generic-x86_64-with-glibc2.35
Python version:          3.11.0
Numpy version:           1.23.5
Scipy version:           1.10.1
Installed devices:
- default.gaussian (PennyLane-0.32.0)
- default.mixed (PennyLane-0.32.0)
- default.qubit (PennyLane-0.32.0)
- default.qubit.autograd (PennyLane-0.32.0)
- default.qubit.jax (PennyLane-0.32.0)
- default.qubit.tf (PennyLane-0.32.0)
- default.qubit.torch (PennyLane-0.32.0)
- default.qutrit (PennyLane-0.32.0)
- null.qubit (PennyLane-0.32.0)
- qiskit.aer (PennyLane-qiskit-0.32.0)
- qiskit.basicaer (PennyLane-qiskit-0.32.0)
- qiskit.ibmq (PennyLane-qiskit-0.32.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.32.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.32.0)
- qiskit.remote (PennyLane-qiskit-0.32.0)
- lightning.qubit (PennyLane-Lightning-0.32.0)

Source code and tracebacks

This version works:

import catalyst
import pennylane as qml
from catalyst import qjit, adjoint

dev = qml.device("lightning.qubit", wires=3)

def subroutine_no_wires():
    for wire in range(3):
        qml.PauliX(wire)
        
@qjit(autograph=True)
@qml.qnode(dev)
def test_function():
    adjoint(subroutine_no_wires)()
    return qml.probs()

However, if we instead do:

def subroutine(wires):
    for wire in wires:
        qml.PauliX(wire)
        
@qjit(autograph=True)
@qml.qnode(dev)
def test_function():
    adjoint(subroutine)(dev.wires)
    return qml.probs()

we obtain the following traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/jax/_src/api_util.py:581, in shaped_abstractify(x)
    580 try:
--> 581   return _shaped_abstractify_handlers[type(x)](x)
    582 except KeyError:

KeyError: <class 'pennylane.wires.Wires'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 5
      2     for wire in wires:
      3         qml.PauliX(wire)
----> 5 @qjit(autograph=True)
      6 @qml.qnode(dev)
      7 def test_function():
      8     subroutine(dev.wires)
      9     adjoint(subroutine)(dev.wires)

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:948, in qjit.<locals>.wrap_fn(fn)
    947 def wrap_fn(fn):
--> 948     return QJIT(
    949         fn, CompileOptions(verbose, logfile, target, keep_intermediate, pipelines, autograph)
    950     )

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:508, in QJIT.__init__(self, fn, compile_options)
    506 if parameter_types is not None:
    507     self.user_typed = True
--> 508     self.mlir_module = self.get_mlir(*parameter_types)
    509     if self.compile_options.target == "binary":
    510         self.compiled_function = self.compile()

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:555, in QJIT.get_mlir(self, *args)
    550 self.c_sig = CompiledFunction.get_runtime_signature(*args)
    552 with Patcher(
    553     (qml.QNode, "__call__", QFunc.__call__),
    554 ):
--> 555     mlir_module, ctx, jaxpr, self.shape = trace_to_mlir(self.user_function, *self.c_sig)
    557 inject_functions(mlir_module, ctx)
    558 self._jaxpr = jaxpr

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:276, in trace_to_mlir(func, *args, **kwargs)
    273 mlir_fn_cache.clear()
    275 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 276     jaxpr, shape = jax.make_jaxpr(func, return_shape=True)(*args, **kwargs)
    278 return jaxpr_to_mlir(func.__name__, jaxpr, shape)

    [... skipping hidden 6 frame]

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:149, in QFunc.__call__(self, *args, **kwargs)
    146     device = self.device
    148 with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION):
--> 149     jaxpr, shape = trace_quantum_function(self.func, device, args, kwargs)
    151 retval_tree = tree_structure(shape)
    153 def _eval_jaxpr(*args):

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:535, in trace_quantum_function(f, device, args, kwargs)
    532     in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
    533     with QueuingManager.stop_recording(), quantum_tape:
    534         # Quantum tape transformations happen at the end of tracing
--> 535         ans = wffa.call_wrapped(*in_classical_tracers)
    536     out_classical_tracers_or_measurements = [
    537         (trace.full_raise(t) if isinstance(t, DynamicJaxprTracer) else t) for t in ans
    538     ]
    540 # (2) - Quantum tracing

    [... skipping hidden 1 frame]

File /tmp/__autograph_generated_fileeeja8hs_.py:9, in outer_factory.<locals>.inner_factory.<locals>.test_function_1()
      7 with ag__.FunctionScope('test_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
      8     ag__.converted_call(subroutine, (dev.wires,), None, fscope)
----> 9     ag__.converted_call(ag__.converted_call(adjoint, (subroutine,), None, fscope), (dev.wires,), None, fscope)
     10     return ag__.converted_call(qml.probs, (), None, fscope)

File ~/Code/catalyst/frontend/catalyst/ag_primitives.py:441, in converted_call(fn, args, kwargs, caller_fn_scope, options)
    438     new_qnode = qml.QNode(qnode_call_wrapper, device=fn.device, diff_method=fn.diff_method)
    439     return new_qnode()
--> 441 return tf_converted_call(fn, args, kwargs, caller_fn_scope, options)

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:377, in converted_call(f, args, kwargs, caller_fn_scope, options)
    374   return _call_unconverted(f, args, kwargs, options)
    376 if not options.user_requested and conversion.is_allowlisted(f):
--> 377   return _call_unconverted(f, args, kwargs, options)
    379 # internal_convert_user_code is for example turned off when issuing a dynamic
    380 # call conversion from generated code while in nonrecursive mode. In that
    381 # case we evidently don't want to recurse, but we still have to convert
    382 # things like builtins.
    383 if not options.internal_convert_user_code:

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:1814, in adjoint.<locals>._callable(*args, **kwargs)
   1813 def _callable(*args, **kwargs):
-> 1814     return _call_handler(*args, _callee=f, **kwargs)

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:1788, in adjoint.<locals>._call_handler(_callee, *args, **kwargs)
   1786 with EvaluationContext.frame_tracing_context(ctx) as inner_trace:
   1787     in_classical_tracers, _ = tree_flatten((args, kwargs))
-> 1788     wffa, in_avals, _ = deduce_avals(_callee, args, kwargs)
   1789     arg_classical_tracers = _input_type_to_tracers(inner_trace.new_arg, in_avals)
   1790     quantum_tape = QuantumTape()

File ~/Code/catalyst/frontend/catalyst/utils/jax_extras.py:266, in deduce_avals(f, args, kwargs)
    264 flat_args, in_tree = tree_flatten((args, kwargs))
    265 wf = wrap_init(f)
--> 266 in_avals, keep_inputs = list(map(shaped_abstractify, flat_args)), [True] * len(flat_args)
    267 in_type = tuple(zip(in_avals, keep_inputs))
    268 wff, out_tree_promise = flatten_fun(wf, in_tree)

    [... skipping hidden 1 frame]

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/jax/_src/api_util.py:572, in _shaped_abstractify_slow(x)
    570   dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
    571 else:
--> 572   raise TypeError(
    573       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    574       "does not have a dtype attribute")
    575 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    576                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'pennylane.wires.Wires'> as an abstract array; it does not have a dtype attribute

Additionally, if we pass the wires as jnp.array(dev.wires), a different error occurs, this time raised while jnp.array converts the Wires object:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[18], line 5
      2     for wire in wires:
      3         qml.PauliX(wire)
----> 5 @qjit(autograph=True)
      6 @qml.qnode(dev)
      7 def test_function():
      8     #subroutine(dev.wires)
      9     adjoint(subroutine)(jnp.array(dev.wires))
     10     return qml.probs()

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:948, in qjit.<locals>.wrap_fn(fn)
    947 def wrap_fn(fn):
--> 948     return QJIT(
    949         fn, CompileOptions(verbose, logfile, target, keep_intermediate, pipelines, autograph)
    950     )

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:508, in QJIT.__init__(self, fn, compile_options)
    506 if parameter_types is not None:
    507     self.user_typed = True
--> 508     self.mlir_module = self.get_mlir(*parameter_types)
    509     if self.compile_options.target == "binary":
    510         self.compiled_function = self.compile()

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:555, in QJIT.get_mlir(self, *args)
    550 self.c_sig = CompiledFunction.get_runtime_signature(*args)
    552 with Patcher(
    553     (qml.QNode, "__call__", QFunc.__call__),
    554 ):
--> 555     mlir_module, ctx, jaxpr, self.shape = trace_to_mlir(self.user_function, *self.c_sig)
    557 inject_functions(mlir_module, ctx)
    558 self._jaxpr = jaxpr

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:276, in trace_to_mlir(func, *args, **kwargs)
    273 mlir_fn_cache.clear()
    275 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 276     jaxpr, shape = jax.make_jaxpr(func, return_shape=True)(*args, **kwargs)
    278 return jaxpr_to_mlir(func.__name__, jaxpr, shape)

    [... skipping hidden 6 frame]

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:149, in QFunc.__call__(self, *args, **kwargs)
    146     device = self.device
    148 with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION):
--> 149     jaxpr, shape = trace_quantum_function(self.func, device, args, kwargs)
    151 retval_tree = tree_structure(shape)
    153 def _eval_jaxpr(*args):

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:535, in trace_quantum_function(f, device, args, kwargs)
    532     in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
    533     with QueuingManager.stop_recording(), quantum_tape:
    534         # Quantum tape transformations happen at the end of tracing
--> 535         ans = wffa.call_wrapped(*in_classical_tracers)
    536     out_classical_tracers_or_measurements = [
    537         (trace.full_raise(t) if isinstance(t, DynamicJaxprTracer) else t) for t in ans
    538     ]
    540 # (2) - Quantum tracing

    [... skipping hidden 1 frame]

File /tmp/__autograph_generated_filezwu16kl_.py:8, in outer_factory.<locals>.inner_factory.<locals>.test_function_1()
      6 def test_function_1():
      7     with ag__.FunctionScope('test_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
----> 8         ag__.converted_call(ag__.converted_call(adjoint, (subroutine,), None, fscope), (ag__.converted_call(jnp.array, (dev.wires,), None, fscope),), None, fscope)
      9         return ag__.converted_call(qml.probs, (), None, fscope)

File ~/Code/catalyst/frontend/catalyst/ag_primitives.py:441, in converted_call(fn, args, kwargs, caller_fn_scope, options)
    438     new_qnode = qml.QNode(qnode_call_wrapper, device=fn.device, diff_method=fn.diff_method)
    439     return new_qnode()
--> 441 return tf_converted_call(fn, args, kwargs, caller_fn_scope, options)

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:377, in converted_call(f, args, kwargs, caller_fn_scope, options)
    374   return _call_unconverted(f, args, kwargs, options)
    376 if not options.user_requested and conversion.is_allowlisted(f):
--> 377   return _call_unconverted(f, args, kwargs, options)
    379 # internal_convert_user_code is for example turned off when issuing a dynamic
    380 # call conversion from generated code while in nonrecursive mode. In that
    381 # case we evidently don't want to recurse, but we still have to convert
    382 # things like builtins.
    383 if not options.internal_convert_user_code:

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2051, in array(object, dtype, copy, order, ndmin)
   2044 out: ArrayLike
   2046 if all(not isinstance(leaf, Array) for leaf in leaves):
   2047   # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   2048   # containing large integers; see discussion in
   2049   # https://github.com/google/jax/pull/6047. More correct would be to call
   2050   # coerce_to_array on each leaf, but this may have performance implications.
-> 2051   out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
   2052 elif isinstance(object, Array):
   2053   assert object.aval is not None

TypeError: Wires.__array__() takes 1 positional argument but 2 were given
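The last frames show jax.numpy.array falling back to np.array(object, dtype=dtype, ...), with NumPy then passing that dtype on to Wires.__array__, whose signature in this PennyLane version apparently accepts no arguments (an inference from the error message). The sketch below reproduces the same failure mode in plain NumPy with two hypothetical stand-in classes, WiresLike and WiresLikeFixed; it does not use PennyLane's actual Wires class.

```python
import numpy as np


class WiresLike:
    """Hypothetical stand-in whose __array__ takes no arguments besides self."""

    def __init__(self, labels):
        self.labels = list(labels)

    def __array__(self):
        return np.array(self.labels)


class WiresLikeFixed(WiresLike):
    """Same container, but __array__ accepts the optional dtype (and NumPy 2 copy) arguments."""

    def __array__(self, dtype=None, copy=None):
        return np.array(self.labels, dtype=dtype)


# Requesting a dtype makes NumPy call __array__(dtype), which WiresLike cannot accept.
try:
    np.array(WiresLike([0, 1, 2]), dtype=np.int64)
    broken_error = None
except TypeError as err:
    broken_error = err

# The signature that accepts dtype converts cleanly.
fixed = np.array(WiresLikeFixed([0, 1, 2]), dtype=np.int64)
print(broken_error)
print(fixed)
```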

The `do_queue` deprecation in PennyLane breaks Catalyst tests

It looks like the do_queue deprecation in PennyLane causes Catalyst tests to fail. The error messages with PennyLane 0.31.0-dev:

warnings.warn(do_queue_deprecation_warning, UserWarning)
jax._src.traceback_util.UnfilteredStackTrace: UserWarning: The do_queue keyword argument is deprecated. Instead of setting it to False, use qml.queuing.QueuingManager.stop_recording()
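The log shows a UserWarning surfacing as an exception in a stack trace. One plausible reading, which is an assumption not confirmed by the report, is that the test suite promotes warnings to errors (for example with pytest's -W error option or a filterwarnings = error setting). The standard-library sketch below shows that mechanism with a hypothetical deprecated_call function standing in for the PennyLane API; the fix suggested by the warning itself is to replace do_queue=False with qml.queuing.QueuingManager.stop_recording().

```python
import warnings


def deprecated_call(do_queue=None):
    # Hypothetical stand-in for an API that warns when a deprecated keyword is used.
    if do_queue is not None:
        warnings.warn("The do_queue keyword argument is deprecated.", UserWarning)
    return "ok"


# With an "always" filter the warning is only recorded and the call returns normally.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = deprecated_call(do_queue=False)

# With an "error" filter (comparable to pytest -W error) the same warning is raised as an exception.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        deprecated_call(do_queue=False)
        escalated = False
    except UserWarning:
        escalated = True

print(result, len(caught), escalated)
```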
