
flaport / sax


S + Autograd + XLA :: S-parameter based frequency domain circuit simulations and optimizations using JAX.

Home Page: https://flaport.github.io/sax

License: Apache License 2.0

Languages: Python 49.27%, Jupyter Notebook 50.72%, Dockerfile 0.01%
Topics: s-parameters, autograd, xla, jax, circuit, simulation, simulations, optimization, photonics, photonic-circuit

sax's Introduction

SAX

S + Autograd + XLA


Autograd and XLA for S-parameters: a scattering-parameter circuit simulator and optimizer for the frequency domain, based on JAX.

The simulator was developed for simulating Photonic Integrated Circuits but can in fact perform any S-parameter based circuit simulation. The goal of SAX is to be a thin wrapper around JAX with some basic tools for S-parameter based circuit simulation and optimization. Therefore, SAX does not define any special data structures and tries to stay as close as possible to the functional nature of JAX. This makes it very easy to get started with SAX, as you only need functions and standard Python dictionaries. Let's dive in...

Quick Start

Full Quick Start page - Documentation.

Let's first import the SAX library, along with JAX and the JAX-version of numpy:

import sax
import jax
import jax.numpy as jnp

Define a model function for your component. A SAX model is just a function that returns an 'S-dictionary'. For example a directional coupler:

def coupler(coupling=0.5):
    kappa = coupling**0.5
    tau = (1-coupling)**0.5
    sdict = sax.reciprocal({
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j*kappa,
        ("in1", "out0"): 1j*kappa,
        ("in1", "out1"): tau,
    })
    return sdict

coupler(coupling=0.3)
{('in0', 'out0'): 0.8366600265340756,
 ('in0', 'out1'): 0.5477225575051661j,
 ('in1', 'out0'): 0.5477225575051661j,
 ('in1', 'out1'): 0.8366600265340756,
 ('out0', 'in0'): 0.8366600265340756,
 ('out1', 'in0'): 0.5477225575051661j,
 ('out0', 'in1'): 0.5477225575051661j,
 ('out1', 'in1'): 0.8366600265340756}
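The output above suggests that reciprocity simply mirrors every (p1, p2) entry into a (p2, p1) entry with the same value. Here is a dependency-free sketch of that idea in plain Python (no JAX); `reciprocal_sketch` is a hypothetical helper, not part of SAX. It also checks that the coupler conserves power, since τ² + κ² = (1 − c) + c = 1.

```python
import math

def reciprocal_sketch(sdict):
    """Return a copy of sdict with every (p1, p2) entry mirrored as (p2, p1)."""
    mirrored = {(p2, p1): value for (p1, p2), value in sdict.items()}
    return {**sdict, **mirrored}

def coupler(coupling=0.5):
    kappa = coupling ** 0.5      # field cross-coupling coefficient
    tau = (1 - coupling) ** 0.5  # field bar-transmission coefficient
    return reciprocal_sketch({
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j * kappa,
        ("in1", "out0"): 1j * kappa,
        ("in1", "out1"): tau,
    })

s = coupler(coupling=0.3)
print(len(s))  # 8: the 4 given entries plus their 4 mirrored counterparts
# Power entering in0 is split over out0 and out1 without loss.
print(math.isclose(abs(s["in0", "out0"]) ** 2 + abs(s["in0", "out1"]) ** 2, 1.0))
```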

Or a waveguide:

def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0
    neff = neff - dwl * dneff_dwl
    phase = 2 * jnp.pi * neff * length / wl
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    transmission =  amplitude * jnp.exp(1j * phase)
    sdict = sax.reciprocal({("in0", "out0"): transmission})
    return sdict

waveguide(length=100.0)
{('in0', 'out0'): 0.97953-0.2013j, ('out0', 'in0'): 0.97953-0.2013j}
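The waveguide model uses a first-order expansion of the effective index around `wl0`. With the group index n_g = n_eff − λ·dn_eff/dλ, the slope is dn_eff/dλ = (n_eff − n_g)/λ₀, which is the `- dwl * dneff_dwl` term above. As a cross-check of the printed value, this sketch re-implements the same formula with Python's standard `cmath` module instead of JAX; `waveguide_transmission` is an illustrative name only.

```python
import cmath
import math

def waveguide_transmission(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
    """Plain-Python version of the waveguide model above, for a single wavelength."""
    dwl = wl - wl0
    # First-order dispersion: n_eff(wl) = n_eff(wl0) + (wl - wl0) * (n_eff - n_g) / wl0
    neff_wl = neff - dwl * (ng - neff) / wl0
    phase = 2 * math.pi * neff_wl * length / wl
    amplitude = 10 ** (-loss * length / 20)  # loss is in dB per unit length
    return amplitude * cmath.exp(1j * phase)

t = waveguide_transmission(length=100.0)
print(round(t.real, 5), round(t.imag, 5))
```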

These component models can then be combined into a circuit:

mzi, _ = sax.circuit(
    netlist={
        "instances": {
            "lft": coupler,
            "top": waveguide,
            "rgt": coupler,
        },
        "connections": {
            "lft,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    }
)

type(mzi)
function

As you can see, the mzi we just created is just another component model function! To simulate it, call the mzi function with the (possibly nested) settings of its subcomponents. Global settings can be added to the 'root' of the circuit call and will be distributed over all subcomponents which have a parameter with the same name (e.g. 'wl'):

import matplotlib.pyplot as plt

wl = jnp.linspace(1.53, 1.57, 1000)
result = mzi(wl=wl, lft={'coupling': 0.3}, top={'length': 200.0}, rgt={'coupling': 0.8})

plt.plot(1e3*wl, jnp.abs(result['in0', 'out0'])**2, label="in0->out0")
plt.plot(1e3*wl, jnp.abs(result['in0', 'out1'])**2, label="in0->out1", ls="--")
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.figlegend(ncol=2, loc="upper center")
plt.show()
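The way global settings such as `wl` reach the subcomponents can be approximated in plain Python with `inspect.signature`: a root-level keyword is copied into every instance whose model declares a parameter of that name, and per-instance dicts override it. This is a simplified sketch of the behaviour described above, not SAX's actual implementation; `distribute_settings` and the stub models are hypothetical.

```python
import inspect

def distribute_settings(instances, **settings):
    """Split root-level settings into per-instance keyword dicts.

    instances: mapping of instance name -> model function.
    settings: nested dicts keyed by instance name, or global values.
    """
    nested = {k: v for k, v in settings.items() if k in instances}
    global_settings = {k: v for k, v in settings.items() if k not in instances}
    per_instance = {}
    for name, model in instances.items():
        params = inspect.signature(model).parameters
        # Globals only go to models that declare a parameter with that name.
        kwargs = {k: v for k, v in global_settings.items() if k in params}
        # Instance-specific settings take precedence over globals.
        kwargs.update(nested.get(name, {}))
        per_instance[name] = kwargs
    return per_instance

def coupler(coupling=0.5):  # stub model: only the signature matters here
    return {}

def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):  # stub
    return {}

instances = {"lft": coupler, "top": waveguide, "rgt": coupler}
print(distribute_settings(instances, wl=1.55, lft={"coupling": 0.3},
                          top={"length": 200.0}, rgt={"coupling": 0.8}))
# {'lft': {'coupling': 0.3}, 'top': {'wl': 1.55, 'length': 200.0}, 'rgt': {'coupling': 0.8}}
```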

[Figure: transmission T versus wavelength λ [nm] for in0->out0 (solid) and in0->out1 (dashed)]
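Because this MZI has no reflections, its transmission can also be written down by hand as a sum over the two forward paths, which is a useful sanity check on the simulated curves. The sketch below does this in plain Python for single wavelengths. It reuses the coupler convention from above (τ = √(1−c), κ = √c, cross term 1jκ) and assumes a lossless waveguide with the same default parameters; the helper names are hypothetical.

```python
import cmath
import math

def waveguide_t(wl, wl0=1.55, neff=2.34, ng=3.4, length=10.0):
    # Lossless version of the waveguide model: unit amplitude, dispersive phase.
    neff_wl = neff - (wl - wl0) * (ng - neff) / wl0
    return cmath.exp(2j * math.pi * neff_wl * length / wl)

def mzi_by_hand(wl, c_lft=0.3, c_rgt=0.8, length=200.0):
    tau_l, kappa_l = math.sqrt(1 - c_lft), math.sqrt(c_lft)
    tau_r, kappa_r = math.sqrt(1 - c_rgt), math.sqrt(c_rgt)
    t = waveguide_t(wl, length=length)
    # Path 1: lft out0 -> rgt in0 directly (the bottom arm has no waveguide).
    # Path 2: lft out1 -> top waveguide -> rgt in1.
    s_in0_out0 = tau_l * tau_r + (1j * kappa_l) * t * (1j * kappa_r)
    s_in0_out1 = tau_l * (1j * kappa_r) + (1j * kappa_l) * t * tau_r
    return s_in0_out0, s_in0_out1

for wl in (1.53, 1.55, 1.57):
    s00, s01 = mzi_by_hand(wl)
    print(wl, abs(s00) ** 2, abs(s01) ** 2)  # the two powers should sum to 1
```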

Those are the basics. For more info, check out the full SAX Quick Start page or the rest of the Documentation.

Installation

You can install SAX with pip:

pip install sax

If you want to be able to run all the example notebooks, you'll need Python >= 3.10 and should install the development version of SAX:

pip install 'sax[dev]'

License

Copyright © 2023, Floris Laporte, Apache-2.0 License

sax's People

Contributors

daquintero, flaport, jan-david-fischbach, joamatab, simbilod


sax's Issues

f-string needs fix

The `f` prefix is missing from the f-string on line 200 of netlist.py, so the component name is never interpolated into the error message:

~/mambaforge/lib/python3.10/site-packages/sax/netlist.py in _model_operations(models, ops, default_models)
    198 
    199         if "model" not in model:
--> 200             raise ValueError(
    201                 "Invalid model dict for '{component}'. Key 'model' not found."
    202             )

ValueError: Invalid model dict for '{component}'. Key 'model' not found.

This should be:

raise ValueError(
    f"Invalid model dict for '{component}'. Key 'model' not found."
)
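The difference is easy to see in isolation: without the `f` prefix, Python keeps the braces literally instead of interpolating the variable. A minimal reproduction, where the component name `"coupler_1"` is just an example value:

```python
component = "coupler_1"

plain = "Invalid model dict for '{component}'. Key 'model' not found."
fixed = f"Invalid model dict for '{component}'. Key 'model' not found."

print(plain)  # Invalid model dict for '{component}'. Key 'model' not found.
print(fixed)  # Invalid model dict for 'coupler_1'. Key 'model' not found.
```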

computing circuit components in parallel

Hi @flaport

I'm curious whether you have a suggestion for computing the S-matrices of the circuit components in parallel. For example, say the coupler and waveguide functions involve running some simulations and I'd like to kick those off at the same time. Is there a way to handle this in the current state of sax, or would I need to fork it and change the internals to do some multi-threading? Just curious if you have any ideas about this, thanks!

mzi, _ = sax.circuit(
    netlist={
        "instances": {
            "lft": coupler,
            "top": waveguide,
            "rgt": coupler,
        },
        "connections": {
            "lft,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    }
)
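One workaround that needs no changes to SAX internals is sketched below. It assumes the expensive part is each model's own simulation, for example a remote solver call that waits on I/O or releases the GIL. The idea is to run those simulations concurrently up front with `concurrent.futures`, then give the circuit cheap models that just return the precomputed S-dictionaries. `run_simulation`, `precompute` and `constant_model` are hypothetical placeholders, not SAX API. Whether a zero-argument model is accepted everywhere in a netlist is also an assumption worth checking.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def run_simulation(name, delay=0.2):
    """Hypothetical stand-in for an expensive solver call for one component."""
    time.sleep(delay)  # simulates waiting on an external solver
    return {("in0", "out0"): 1.0 + 0j, ("out0", "in0"): 1.0 + 0j}

def precompute(component_names, max_workers=4):
    """Run all component simulations concurrently and collect their S-dicts."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {name: pool.submit(run_simulation, name) for name in component_names}
        return {name: fut.result() for name, fut in futures.items()}

def constant_model(sdict):
    """Wrap a precomputed S-dict as a model function with no parameters."""
    def model():
        return sdict
    return model

start = time.perf_counter()
sdicts = precompute(["coupler", "waveguide"])
print(f"elapsed: {time.perf_counter() - start:.2f} s")  # about one delay, not two

models = {name: constant_model(s) for name, s in sdicts.items()}
```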

`IndexError`

Hi Floris, we have a notebook demonstrating sax. It seems to error at cell [18] for some users and not for others. There seem to be only minor differences in the dependencies and we can't figure out what is causing this discrepancy.

Do you have any suggestions for things to look into here? We're stumped after testing several different dependencies.

This is the stack trace

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[17], line 2
      1 # how to pass specific parmaeters to each of the sub-functions for the instances
----> 2 s = circuit_fn(splitter={"params": params0}, combiner={"params": 0 * params0}, beta=3, phase_sifter=dict(phi=2.0))

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/saxtypes.py:267, in sdict.<locals>.wrapper(**kwargs)
    265 @functools.wraps(model)
    266 def wrapper(**kwargs):
--> 267     return sdict(model(**kwargs))

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/circuit.py:226, in _flat_circuit.<locals>._circuit(**settings)
    223 for inst_name, model in inst2model.items():
    224     instances[inst_name] = model(**full_settings.get(inst_name, {}))
--> 226 S = evaluate_fn(analyzed, instances)
    227 return S

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/backends/klu.py:122, in evaluate_circuit_klu(analyzed, instances)
    117     idx += len(ports_map)
    119 Sx = jnp.concatenate(
    120     [jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1
    121 )
--> 122 CSx = Sx[..., mask]
    123 Ix = jnp.ones((*batch_shape, n_col))
    124 I_CSx = jnp.concatenate([-CSx, Ix], -1)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/array.py:319, in ArrayImpl.__getitem__(self, idx)
    317   return lax_numpy._rewriting_take(self, idx)
    318 else:
--> 319   return lax_numpy._rewriting_take(self, idx)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4152, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4146     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4147         dtypes.issubdtype(aval.dtype, np.integer) and
   4148         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4149         isinstance(arr.shape[0], int)):
   4150       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4153 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4154                unique_indices, mode, fill_value)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4230, in _split_index_for_jit(idx, shape)
   4226   raise TypeError(f"JAX does not support string indexing; got {idx=}")
   4228 # Expand any (concrete) boolean indices. We can then use advanced integer
   4229 # indexing logic to handle them.
-> 4230 idx = _expand_bool_indices(idx, shape)
   4232 leaves, treedef = tree_flatten(idx)
   4233 dynamic = [None] * len(leaves)

File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4552, in _expand_bool_indices(idx, shape)
   4550     expected_shape = shape[start: start + _ndim(i)]
   4551     if i_shape != expected_shape:
-> 4552       raise IndexError("boolean index did not match shape of indexed array in index "
   4553                        f"{dim_number}: got {i_shape}, expected {expected_shape}")
   4554     out.extend(np.where(i))
   4555 else:

IndexError: boolean index did not match shape of indexed array in index 1: got (18,), expected (14,)

And here is the pip freeze output for the failing environment (Python 3.11 on Ubuntu):

Package                       Version         Editable project location
----------------------------- --------------- ---------------------------------------------------------------
absl-py                       2.1.0
accessible-pygments           0.0.4
alabaster                     0.7.16
annotated-types               0.6.0
anyio                         4.2.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
astroid                       3.0.3
asttokens                     2.4.1
async-lru                     2.0.4
attrs                         23.2.0
autograd                      1.6.2
Babel                         2.14.0
beautifulsoup4                4.12.3
black                         23.12.1
bleach                        6.1.0
boto3                         1.23.1
botocore                      1.26.10
cached-property               1.5.2
cachetools                    5.3.2
certifi                       2024.2.2
cffi                          1.16.0
cfgv                          3.4.0
chardet                       5.2.0
charset-normalizer            3.3.2
chex                          0.1.82
click                         8.0.3
cloudpickle                   3.0.0
colorama                      0.4.6
comm                          0.2.1
commonmark                    0.9.1
contourpy                     1.2.0
coverage                      7.4.1
cycler                        0.12.1
dask                          2023.10.1
dataclasses-json              0.6.4
debugpy                       1.8.0
decorator                     5.1.1
defusedxml                    0.7.1
dill                          0.3.8
distlib                       0.3.8
docutils                      0.20.1
entrypoints                   0.4
etils                         1.6.0
exceptiongroup                1.2.0
executing                     2.0.1
fastcore                      1.5.29
fastjsonschema                2.19.1
filelock                      3.13.1
flax                          0.7.4
flow360scripts                0.0.1
fonttools                     4.47.2
fqdn                          1.5.1
fsspec                        2024.2.0
future                        0.18.3
gdspy                         1.6.13
gdstk                         0.9.50
gitdb                         4.0.11
GitPython                     3.1.41
gltflib                       1.0.13
grcwa                         0.1.2
h11                           0.14.0
h2                            4.1.0
h5netcdf                      1.0.2
h5py                          3.10.0
hpack                         4.0.0
httpcore                      1.0.3
httpx                         0.26.0
hyperframe                    6.0.1
identify                      2.5.33
idna                          3.6
imagesize                     1.4.1
importlib-metadata            6.11.0
importlib-resources           6.1.1
iniconfig                     2.0.0
ipykernel                     6.28.0
ipython                       8.21.0
ipywidgets                    8.1.1
isoduration                   20.11.0
isort                         5.13.2
jax                           0.4.14
jaxlib                        0.4.14
jaxtyping                     0.2.25
jedi                          0.19.1
Jinja2                        3.1.3
jmespath                      1.0.1
json5                         0.9.14
jsonpointer                   2.4
jsonschema                    4.21.1
jsonschema-specifications     2023.12.1
jupyter                       1.0.0
jupyter_client                8.6.0
jupyter-console               6.6.3
jupyter_core                  5.7.1
jupyter-events                0.9.0
jupyter-lsp                   2.2.2
jupyter_server                2.12.5
jupyter-server-mathjax        0.2.6
jupyter_server_terminals      0.5.2
jupyterlab                    4.0.12
jupyterlab_pygments           0.3.0
jupyterlab_server             2.25.2
jupyterlab-widgets            3.0.9
kiwisolver                    1.4.5
klujax                        0.2.4
locket                        1.0.0
markdown-it-py                3.0.0
MarkupSafe                    2.1.3
marshmallow                   3.20.2
matplotlib                    3.8.2
matplotlib-inline             0.1.6
mccabe                        0.7.0
mdit-py-plugins               0.4.0
mdurl                         0.1.2
memory-profiler               0.61.0
mistune                       3.0.2
ml-dtypes                     0.3.2
mpmath                        1.3.0
msgpack                       1.0.7
multiprocess                  0.70.16
mypy-extensions               1.0.0
myst-parser                   2.0.0
natsort                       8.4.0
nbclient                      0.8.0
nbconvert                     7.16.1
nbdime                        4.0.1
nbformat                      5.9.2
nbsphinx                      0.9.3
nest_asyncio                  1.6.0
networkx                      2.8.8
nodeenv                       1.8.0
notebook                      7.0.8
notebook_shim                 0.2.3
numpy                         1.26.3
opt-einsum                    3.3.0
optax                         0.1.9
orbax-checkpoint              0.5.3
orjson                        3.9.13
overrides                     7.7.0
packaging                     23.2
pandas                        2.2.0
pandocfilters                 1.5.0
parso                         0.8.3
partd                         1.4.1
pathspec                      0.12.1
pexpect                       4.9.0
pillow                        10.2.0
pip                           23.3.1
pkgutil_resolve_name          1.3.10
platformdirs                  4.2.0
pluggy                        1.4.0
ply                           3.11
pre-commit                    3.6.0
prometheus_client             0.20.0
prompt-toolkit                3.0.43
protobuf                      4.25.2
psutil                        5.9.0
ptyprocess                    0.7.0
pure-eval                     0.2.2
pybind11                      2.11.1
pycparser                     2.21
pydantic                      2.6.3
pydantic_core                 2.16.3
pydata-sphinx-theme           0.15.2
Pygments                      2.17.2
PyJWT                         2.8.0
pylint                        3.0.3
PyMieScatt                    1.8.1.1
pyparsing                     3.1.1
pyproject-api                 1.6.1
pyroots                       0.5.0
pyrsistent                    0.20.0
pyswarms                      1.3.0
pytest                        8.0.0
pytest-timeout                2.2.0
python-dateutil               2.8.2
python-json-logger            2.0.7
pytz                          2024.1
PyYAML                        6.0.1
pyzmq                         25.1.2
qtconsole                     5.5.1
QtPy                          2.4.1
referencing                   0.33.0
requests                      2.28.2
responses                     0.24.1
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          12.5.1
rpds-py                       0.17.1
Rtree                         1.0.1
ruff                          0.2.1
s3transfer                    0.5.2
sax                           0.12.1
scipy                         1.12.0
Send2Trash                    1.8.2
setuptools                    68.2.2
shapely                       2.0.2
signac                        2.1.0
six                           1.16.0
smmap                         5.0.1
sniffio                       1.3.0
snowballstemmer               2.2.0
soupsieve                     2.5
Sphinx                        7.2.6
sphinx-book-theme             1.1.0
sphinx-copybutton             0.5.2
sphinx-sitemap                2.5.1
sphinx-tabs                   3.4.5
sphinxcontrib-applehelp       1.0.8
sphinxcontrib-devhelp         1.0.6
sphinxcontrib-htmlhelp        2.0.5
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.7
sphinxcontrib-serializinghtml 1.1.10
sphinxemoji                   0.3.1
stack-data                    0.6.2
sympy                         1.12
synced-collections            1.0.0
tensorstore                   0.1.53
terminado                     0.18.0
tidy3d                        2.6.0rc1        /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_frontend
tidy3d-beta                   1.9.0
tidy3d-denormalizer           0.1.0           /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d-denormalizer
tidy3d_pipeline               2.6.0rc1        /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_pipeline
tinycss2                      1.2.1
tmm                           0.1.8
toml                          0.10.2
tomli                         2.0.1
tomlkit                       0.12.3
toolz                         0.12.1
tornado                       6.4
tox                           4.12.1
tqdm                          4.66.1
traitlets                     5.14.1
trimesh                       3.20.0
typeguard                     2.13.3
types-python-dateutil         2.8.19.20240106
typing_extensions             4.9.0
typing-inspect                0.9.0
typing-utils                  0.1.0
tzdata                        2023.4
uri-template                  1.3.0
urllib3                       1.26.18
virtualenv                    20.25.0
vtk                           9.2.6
wcwidth                       0.2.13
webcolors                     1.13
webencodings                  0.5.1
websocket-client              1.7.0
wheel                         0.41.2
widgetsnbextension            4.0.9
xarray                        2023.12.0
zipp                          3.17.0

@momchil-flex
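The message says the boolean mask built during circuit analysis has 18 entries, while the concatenated instance S-parameter values have only 14. One thing worth checking is whether every instance model returns the same ports and the same number of S-parameter entries in both environments. A small hypothetical helper like the one below compares per-instance S-dicts collected in each environment. It only inspects plain dictionaries, so models returning another S-matrix format would need converting first, for example with the `sdict` function from `sax/saxtypes.py` that appears in the traceback.

```python
def summarize_sdict(sdict):
    """Return (sorted port names, number of S-parameter entries) for one S-dict."""
    ports = sorted({port for pair in sdict for port in pair})
    return ports, len(sdict)

def compare_instances(instances_a, instances_b):
    """Return names of instances whose port summary differs between two runs."""
    mismatched = []
    for name in sorted(set(instances_a) | set(instances_b)):
        a = summarize_sdict(instances_a.get(name, {}))
        b = summarize_sdict(instances_b.get(name, {}))
        if a != b:
            print(f"{name}: {a} vs {b}")
            mismatched.append(name)
    return mismatched

# Toy data: the splitter exposes one port fewer in the second environment.
env_ok = {"splitter": {("in0", "out0"): 1, ("in0", "out1"): 1}}
env_bad = {"splitter": {("in0", "out0"): 1}}
print(compare_instances(env_ok, env_bad))
```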

Interfacial components

See issue 1067 in gdsfactory

How easy would it be for sax.circuit to also account for interfaces between components?

I am thinking of a flag that, if True,
(1) requires the user to provide models for component pairs that show up connected in the graph
(2) adds a component evaluating these models at every "connection" of the netlist

The obvious use case is modeling reflections at direct transitions between waveguides with different widths or bend radii. Currently, gdsfactory+SAX throws a critical warning for the former and evaluates the latter without mode-mismatch reflections.

Maybe this should be added as an option in gdsfactory.get_netlist()?

Example of the desired behaviour for bend --> coupler --> straight:

Input:
circuit, _ = sax.circuit(c.get_netlist(), models=models, interfaces=True)
Output:

"Given Models": [],
"Required Models": ["bend_circular", "coupler_full", "straight"],
"Required Interface Models": [("bend_circular:o2", "coupler_full:o1"), ("coupler_full:o3", "straight:o1")],

If there were open-source EMEs around the required reflection matrices would be very easy to calculate :)
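Until something like this exists in SAX, the netlist transformation behind point (2) can be prototyped in plain Python. Every connection `a,p -> b,q` is split into `a,p -> iface,in0` and `iface,out0 -> b,q`, with a new interface instance in between. The function `insert_interfaces`, the `iface_<n>` naming and the `in0`/`out0` port names are assumptions for illustration. A model for each interface would still have to be supplied, for example from an EME.

```python
def insert_interfaces(netlist, interface_model="interface"):
    """Return a copy of a SAX-style netlist with an interface instance on every connection."""
    instances = dict(netlist["instances"])
    connections = {}
    for k, (src, dst) in enumerate(netlist["connections"].items()):
        iface = f"iface_{k}"
        instances[iface] = interface_model
        connections[src] = f"{iface},in0"   # left component -> interface
        connections[f"{iface},out0"] = dst  # interface -> right component
    return {**netlist, "instances": instances, "connections": connections}

netlist = {
    "instances": {"b": "bend_circular", "c": "coupler_full", "s": "straight"},
    "connections": {"b,o2": "c,o1", "c,o3": "s,o1"},
    "ports": {"in": "b,o1", "out": "s,o2"},
}
new = insert_interfaces(netlist)
print(new["connections"])
# {'b,o2': 'iface_0,in0', 'iface_0,out0': 'c,o1', 'c,o3': 'iface_1,in0', 'iface_1,out0': 's,o1'}
```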

Location of klujax and grad-able solves with KLU

Sax looks like a very interesting and useful package!

I was wondering where the code for klujax is located, since it doesn't seem to be part of sax.

Additionally, as I was reading the docs on the KLU backend, I was surprised to learn that the backend does not support gradients. If I understand the approach correctly, the calculation is a combination of matrix multiplications and a linear solve (using klujax). The gradient for the linear solve operation should just involve another linear solve, meaning that you should be able to define the gradient rule for JAX to make another call to klujax. Am I missing something?
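The reasoning in this issue matches the standard reverse-mode rule for a linear solve. The rule below is written for real matrices; the complex case is analogous. Let x̄ = ∂L/∂x be the incoming cotangent. The backward pass then needs one extra solve with the transposed matrix, which in principle could reuse the same sparse LU factorization. For a sparse A, only the entries of Ā on its sparsity pattern are required:

```latex
\begin{aligned}
A x &= b, \qquad \bar{x} = \frac{\partial L}{\partial x} \\
\bar{b} &= \frac{\partial L}{\partial b} = A^{-\mathsf{T}} \bar{x}
  \quad \text{(i.e. solve } A^{\mathsf{T}} \bar{b} = \bar{x}\text{)} \\
\bar{A}_{ij} &= \frac{\partial L}{\partial A_{ij}} = -\,\bar{b}_i \, x_j
  \quad \text{for each stored nonzero } (i, j) \text{ of } A
\end{aligned}
```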

Models of active components

Would it be possible to include models of active components? I want to be able to reference components within the circuit, change an applied voltage and see the effect the phase shift has on the system.

Best,
-Alex
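Because SAX models are plain functions, an active element can be sketched as a model whose parameter is the applied voltage. The example assumes a linear electro-optic phase shifter, φ = π·V/V_π, with a hypothetical `v_pi` parameter. Real thermal or carrier-based devices follow other voltage-to-phase relations, so the formula is only a placeholder. It uses `cmath` so it runs without JAX; a SAX model would use `jnp.exp` and `sax.reciprocal` as in the quick start.

```python
import cmath
import math

def phase_shifter(voltage=0.0, v_pi=2.0, insertion_loss_db=0.0):
    """Hypothetical voltage-controlled phase shifter returning an S-dict."""
    phase = math.pi * voltage / v_pi  # assumed linear voltage -> phase relation
    amplitude = 10 ** (-insertion_loss_db / 20)
    t = amplitude * cmath.exp(1j * phase)
    return {("in0", "out0"): t, ("out0", "in0"): t}

for v in (0.0, 1.0, 2.0):
    print(v, phase_shifter(voltage=v)["in0", "out0"])
```

If such a model replaced the `top` waveguide in the quick-start netlist, the voltage would be set per instance at call time, in the same nested-dict style as `top={'length': 200.0}`.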

jax as optional dependency

Hi Floris,

What do you think of making jax an optional dependency?

That way, Windows users could still run circuit simulations.

Maybe we could have a mode where sax tries to import jax and, if that fails, still provides some functionality (such as circuit simulations).

Let me know what you think.
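One way to prototype the suggested fallback is to probe for available array libraries at import time with `importlib.util.find_spec`, preferring JAX and falling back to NumPy. This is only a sketch of the idea in the issue, and `pick_array_backend` is a hypothetical helper. Which SAX features could work without JAX is left open; autograd and `jit` clearly could not.

```python
import importlib
import importlib.util

def pick_array_backend(preferred=("jax.numpy", "numpy")):
    """Return (module_name, module) for the first importable backend, or (None, None)."""
    for name in preferred:
        top_level = name.split(".")[0]
        # find_spec returns None when the top-level package is not installed.
        if importlib.util.find_spec(top_level) is None:
            continue
        try:
            return name, importlib.import_module(name)
        except Exception:  # installed but fails to import (e.g. binary/version mismatch)
            continue
    return None, None

backend_name, xp = pick_array_backend()
print("array backend:", backend_name)
```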

Issues with Models with complex arguments

I'm trying to create a simple MZI circuit where I can pass complex parameters to each coupler. For example:

import jax.numpy as jnp
import sax

def coupler(S31=1/jnp.sqrt(2), S41=1j/jnp.sqrt(2)) -> sax.SDict:
    coupler_dict = sax.reciprocal(
        {
            ("in0", "out0"): S31,
            ("in0", "out1"): S41,
            ("in1", "out0"): S41,
            ("in1", "out1"): S31,
        }
    )
    return coupler_dict
def MZI_arms(phase_top = 0., phase_bottom = 0) -> sax.SDict:
    _sdict = sax.reciprocal(
        {
            ("in0", "out0"): jnp.exp(1j*phase_top),
            ("in1", "out1"): jnp.exp(1j*phase_bottom),
        }
    )
    return _sdict
mzi, info = sax.circuit(
    netlist={
        "instances": {
            "BS1": "coupler",
            "PS1": "phase_shifter",
            "BS2": "coupler",
            "PS2": "phase_shifter",
        },
        "connections": {
            "BS1,out0": "PS1,in0",
            "BS1,out1": "PS1,in1",
            "PS1,out0": "BS2,in0",
            "PS1,out1": "BS2,in1",
            "BS2,out0": "PS2,in0",
            "BS2,out1": "PS2,in1",
        },
        "ports": {
            "in0": "BS1,in0",
            "in1": "BS1,in1",
            "out0": "PS2,out0",
            "out1": "PS2,out1",
        },
    },
    models={
        "coupler": coupler,
        "phase_shifter": MZI_arms,
    }
)

This returns a warning for me as follows:

[.../lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2089]: ComplexWarning: Casting complex values to real discards the imaginary part
  out_array: Array = lax_internal._convert_element_type(

Is that an issue? I am not sure I should trust the final result...
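Whatever causes the cast inside SAX, a quick way to judge whether the result can be trusted is a power-conservation check. For this lossless coupler, |S31|² + |S41|² must equal 1. That fails if imaginary parts are dropped, because S41 = 1j/√2 would become 0. The sketch shows the check on the coupler values themselves in plain Python. The natural next step is to apply the same check to the simulated `mzi` output by summing |S|² over `out0` and `out1` for input `in0`.

```python
import math

S31 = 1 / math.sqrt(2)
S41 = 1j / math.sqrt(2)

def total_power(*amplitudes):
    """Sum of |S|^2 over the given output amplitudes."""
    return sum(abs(a) ** 2 for a in amplitudes)

print(total_power(S31, S41))       # complex values kept: should be 1
print(total_power(S31, S41.real))  # imaginary part discarded by a real cast: 0.5
```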

Using circuits as models for other circuits

I'm trying to build a hierarchical circuit that uses a subcircuit as a model.

Something like:

top_circuit:
  instances:
    - model
    - sub_circuit

If I naively try to add a circuit as a model, I get the following error message: AttributeError: 'function' object has no attribute 'items'

Here is a minimal (non-)working example:


import sax
import numpy as np

def ring(length=200.0, coupling=0.1, prop_loss=0.1, neff=1.4, ng=1.4, wl0=1.5, wl=1.5):

    ring_sax, _ = sax.circuit(
        netlist={
            "instances": {
                "dc": {"component": "coupler", "settings": {"coupling": coupling}},
                "top": {
                    "component": "straight",
                    "settings": {
                        "length": length,
                        "loss": prop_loss, 
                        "neff": neff,
                        "ng": ng,
                        "wl0": wl0,
                        "wl": wl,
                    },
                },
            },
            "connections": {
                "dc,out1": "top,in0",
                "top,out0": "dc,in1",
            },
            "ports": {
                "in0": "dc,in0",
                "out0": "dc,out0",
            },
        },
        models={
            "coupler": sax.models.coupler,
            "straight": sax.models.straight,
        },
    )

    return ring_sax


def ring_test(ring_length=10.0, ring_prop_loss=1.0, ring_coupling=0.001, neff=1.4, ng=1.4, wl0=1.3, wl=1.3):

    r_model, _ = sax.circuit(
        netlist={
            "instances": {
                "wg": {
                    "component": "straight",
                    "settings": {
                        "length": 100.0,
                        "loss": 0.1,
                        "neff": neff,
                        "ng": ng,
                        "wl0": wl0,
                        "wl": wl,}
                },
                "ring_ps": {
                    "component": "ring",
                    "settings": {
                        "length": ring_length,
                        "coupling": ring_coupling,
                        "prop_loss": ring_prop_loss,
                        "neff": neff,
                        "ng": ng,
                        "wl0": wl0,
                        "wl": wl,}
                },
            },
            "connections": {
                "wg,out0":"ring_ps,in0",
            },
            "ports": {
                "in0":"wg,in0",
                "out0":"ring_ps,out0",
            },
        },
        models={
            "straight": sax.models.straight,
            "ring": ring,
        },

    )

    return r_model


wl = np.linspace(1.265, 1.285, 1000)
ring_circuit = ring_test(wl=wl) 
ring_circuit()

If I use "ring": ring() in the models dictionary instead, there's no error, but the results are not correct (it acts as if there were no ring).

Am I missing something?

Thank you!
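The error message fits the difference between a model, a function that returns an S-dictionary when called, and a factory like `ring` above. When called, `ring` returns the circuit function built by `sax.circuit` rather than an S-dictionary. The plain-Python sketch below reproduces the same `AttributeError` with hypothetical `straight`/`straight_factory` functions. Per the quick start, a circuit returned by `sax.circuit` is itself a model function. So one option is to build the ring circuit once and pass that function as the model, supplying settings at call time. How nested settings are named for a sub-circuit is best checked against the SAX documentation.

```python
def straight(length=10.0):
    """A model: calling it returns an S-dict."""
    t = 1.0 + 0j  # placeholder transmission
    return {("in0", "out0"): t, ("out0", "in0"): t}

def straight_factory(length=10.0):
    """A factory: calling it returns another function, not an S-dict."""
    def model():
        return straight(length=length)
    return model

def is_sdict(obj):
    """True if obj looks like an S-dict (a dict keyed by port-name pairs)."""
    return isinstance(obj, dict) and all(isinstance(k, tuple) and len(k) == 2 for k in obj)

print(is_sdict(straight(length=5.0)))          # True
print(is_sdict(straight_factory(length=5.0)))  # False: it returned a function
try:
    straight_factory(length=5.0).items()
except AttributeError as err:
    print(err)  # 'function' object has no attribute 'items'
```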

jax update breaks sax installation

With the jax update to 0.4.24, the sax installation is broken:

$ pip install sax
   ...
$ python -c 'import sax' 
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File ".../python3.11/site-packages/sax/__init__.py", line 19, in <module>
    from . import backends as backends
  File ".../python3.11/site-packages/sax/backends/__init__.py", line 40, in <module>
    from .klu import analyze_circuit_klu, analyze_instances_klu, evaluate_circuit_klu
  File ".../python3.11/site-packages/sax/backends/klu.py", line 9, in <module>
    import klujax
  File ".../python3.11/site-packages/klujax.py", line 164, in <module>
    @xla_register_cpu(coo_mul_vec_c128, klujax_cpp.coo_mul_vec_c128)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.11/site-packages/klujax.py", line 64, in decorator
    xla.backend_specific_translations["cpu"][primitive] = partial(fun, name)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'jax.interpreters.xla' has no attribute 'backend_specific_translations'

Not sure if this should be fixed at the sax level (by pinning jax and jaxlib to 0.4.23, as here) or in klujax.

I can confirm that using pip install sax jax==0.4.23 jaxlib==0.4.23 works fine.

error with `flax` < 0.8

Just a heads-up: with jax==0.4.26 and flax==0.7.*, I got errors when importing sax after a recent pip install --upgrade sax:

>>> import sax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/lib/python3.11/site-packages/sax/__init__.py", line 15, in <module>
    from flax.core.frozen_dict import FrozenDict as FrozenDict
  File "/opt/homebrew/lib/python3.11/site-packages/flax/__init__.py", line 19, in <module>
    from .configurations import (
  File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
                         ^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Config' object has no attribute 'define_bool_state'

It was fixed when pip install --upgrade flax installed 0.8.*.

Not sure if this is something you need to know for setting requirements, but I thought I'd let you know and create a paper trail in case others run into this. Feel free to close.
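The two incompatibilities reported in these issues can be captured in a small pre-import check. The sketch encodes only what the reports state. Importing sax (via klujax) failed with jax 0.4.24, and pinning to 0.4.23 worked. Separately, flax 0.7.* failed with jax 0.4.26, while 0.8.* worked. It is not an official compatibility matrix, it may be outdated by later releases, and the helper names are hypothetical. It also only handles plain `X.Y.Z` release strings.

```python
def parse_version(version):
    """Turn a plain release string like '0.4.24' into a comparable tuple (0, 4, 24)."""
    return tuple(int(part) for part in version.split(".")[:3])

def known_problems(jax_version, flax_version=None):
    """List the problems reported in the issues above for the given versions."""
    jax_v = parse_version(jax_version)
    problems = []
    if jax_v >= (0, 4, 24):
        problems.append("klujax at the time failed to import with jax>=0.4.24; "
                        "pinning jax/jaxlib==0.4.23 was reported to work")
    if flax_version is not None and jax_v >= (0, 4, 26) and parse_version(flax_version) < (0, 8):
        problems.append("flax<0.8 was reported to fail with jax 0.4.26; upgrading flax fixed it")
    return problems

print(known_problems("0.4.26", "0.7.4"))
```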
