flaport / sax
S + Autograd + XLA :: S-parameter based frequency domain circuit simulations and optimizations using JAX.
Home Page: https://flaport.github.io/sax
License: Apache License 2.0
Would it be possible to include models of active components? I want to be able to reference components within the circuit, change an applied voltage and see the effect the phase shift has on the system.
Best,
-Alex
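As a rough illustration of what such an active model could look like, here is a minimal sketch in plain Python (no SAX import) of a phase shifter whose phase depends on an applied voltage. SAX models are ordinary functions that return an S-dictionary keyed by port pairs, so a voltage can in principle be just another keyword argument. The name `voltage_phase_shifter`, the parameter `v_pi`, and the linear voltage-to-phase relation are assumptions made for illustration; they are not part of SAX.

```python
import cmath
import math


def voltage_phase_shifter(voltage=0.0, v_pi=1.0, loss_db=0.0):
    """Hypothetical active model: phase grows linearly with voltage.

    Assumption: a voltage of v_pi produces a phase shift of pi radians.
    Returns a plain dict keyed by (input port, output port).
    """
    phase = math.pi * voltage / v_pi
    amplitude = 10 ** (-loss_db / 20)  # power loss in dB -> field amplitude
    transmission = amplitude * cmath.exp(1j * phase)
    # Reciprocal two-port device: same transmission in both directions.
    return {
        ("in0", "out0"): transmission,
        ("out0", "in0"): transmission,
    }


# Sweep the voltage and inspect the resulting phase.
for v in (0.0, 0.5, 1.0):
    s = voltage_phase_shifter(voltage=v)
    print(v, cmath.phase(s[("in0", "out0")]))
```

With a model of this shape, the voltage would presumably be passed like any other instance setting when the circuit function is called.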
I'm trying to build a hierarchical circuit that uses a subcircuit as a model.
Something like:
top_circuit:
  instances:
    - model
    - sub_circuit
If I naively try to add a circuit as a model I get the following error message: AttributeError: 'function' object has no attribute 'items'
Here is a minimum (non-)working example:
import sax
import numpy as np


def ring(length=200.0, coupling=0.1, prop_loss=0.1, neff=1.4, ng=1.4, wl0=1.5, wl=1.5):
    ring_sax, _ = sax.circuit(
        netlist={
            "instances": {
                "dc": {"component": "coupler", "settings": {"coupling": coupling}},
                "top": {
                    "component": "straight",
                    "settings": {
                        "length": length,
                        "loss": prop_loss,
                        "neff": neff,
                        "ng": ng,
                        "wl0": wl0,
                        "wl": wl,
                    },
                },
            },
            "connections": {
                "dc,out1": "top,in0",
                "top,out0": "dc,in1",
            },
            "ports": {
                "in0": "dc,in0",
                "out0": "dc,out0",
            },
        },
        models={
            "coupler": sax.models.coupler,
            "straight": sax.models.straight,
        },
    )
    return ring_sax


def ring_test(ring_length=10.0, ring_prop_loss=1.0, ring_coupling=0.001, neff=1.4, ng=1.4, wl0=1.3, wl=1.3):
    r_model, _ = sax.circuit(
        netlist={
            "instances": {
                "wg": {
                    "component": "straight",
                    "settings": {
                        "length": 100.0,
                        "loss": 0.1,
                        "neff": neff,
                        "ng": ng,
                        "wl0": wl0,
                        "wl": wl,
                    },
                },
                "ring_ps": {
                    "component": "ring",
                    "settings": {
                        "length": ring_length,
                        "coupling": ring_coupling,
                        "prop_loss": ring_prop_loss,
                        "neff": neff,
                        "ng": ng,
                        "wl0": wl0,
                        "wl": wl,
                    },
                },
            },
            "connections": {
                "wg,out0": "ring_ps,in0",
            },
            "ports": {
                "in0": "wg,in0",
                "out0": "ring_ps,out0",
            },
        },
        models={
            "straight": sax.models.straight,
            "ring": ring,
        },
    )
    return r_model


wl = np.linspace(1.265, 1.285, 1000)
ring_circuit = ring_test(wl=wl)
ring_circuit()
If, in the models dictionary, I use "ring": ring(), then there is no error, but the results are not correct (the circuit behaves as if there were no ring).
Am I missing something?
Thank you!
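The error message suggests that whatever the model returned was treated like a dictionary (something with `.items()`), while `ring` as written returns the circuit function produced by `sax.circuit`. The sketch below, in plain Python without SAX, only illustrates that difference. `ring_factory` and `ring_model` are hypothetical names, the transmission value is a placeholder rather than a physical ring response, and whether this resolves the issue in a given SAX version is an assumption.

```python
def ring_factory(coupling=0.1):
    """Returns a *function*, like the circuit function returned by sax.circuit."""

    def ring_circuit():
        # Placeholder value, not a physical ring response.
        return {("in0", "out0"): complex(1 - coupling)}

    return ring_circuit


def ring_model(coupling=0.1):
    """Returns the evaluated S-dict, i.e. something that has .items()."""
    return ring_factory(coupling=coupling)()


print(hasattr(ring_factory(), "items"))  # the factory result is a function
print(hasattr(ring_model(), "items"))    # the model result is a dict
```

Under that reading, a sub-circuit used as a model would need to evaluate its inner circuit function and return the result, rather than return the function itself.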
See issue 1067 in gdsfactory
How easy would it be for sax.circuit to also account for interfaces between components?
I am thinking a flag that, if True,
(1) requires the user to provide models for component pairs that show up connected in the graph
(2) adds a component evaluating these models at every "connection" of the netlist
The obvious use case is to model reflections in direct transitions between waveguides with different widths or bend radii. Currently gdsfactory+SAX throws a critical warning for the former, and evaluates the latter without mode-mismatch reflections.
Maybe this should be added as an option in gdsfactory.get_netlist()?
Example of desired behaviour for bend --> coupler --> straight :
Input:
circuit, _ = sax.circuit(c.get_netlist(), models=models, interfaces=True)
Output:
"Given Models": [],
"Required Models": ["bend_circular", "coupler_full", "straight"],
"Required Interface Models": [("bend_circular:o2", "coupler_full:o1"), ("coupler_full:o3", "straight:o1")],
If there were open-source EME solvers around, the required reflection matrices would be very easy to calculate :)
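The proposal above can be sketched as a netlist preprocessing step in plain Python. This is only an illustration under stated assumptions: `add_interfaces`, the `interface_{i}` instance names, the `in0`/`out0` interface ports, and the `"a->b"` component key are all hypothetical, and `sax.circuit` has no `interfaces` flag in the discussion above.

```python
def component_of(instance):
    """An instance is either a component name or a dict with a 'component' key."""
    return instance["component"] if isinstance(instance, dict) else instance


def add_interfaces(netlist):
    """Insert a hypothetical two-port interface instance at every connection."""
    original = netlist["instances"]
    instances = dict(original)
    connections = {}
    required = []
    for i, (src, dst) in enumerate(netlist["connections"].items()):
        src_inst, src_port = src.split(",")
        dst_inst, dst_port = dst.split(",")
        # The component/port pair identifies which interface model is needed.
        pair = (
            f"{component_of(original[src_inst])}:{src_port}",
            f"{component_of(original[dst_inst])}:{dst_port}",
        )
        if pair not in required:
            required.append(pair)
        name = f"interface_{i}"
        instances[name] = "->".join(pair)
        # Route the original connection through the new interface instance.
        connections[src] = f"{name},in0"
        connections[f"{name},out0"] = dst
    new_netlist = {**netlist, "instances": instances, "connections": connections}
    return new_netlist, required


netlist = {
    "instances": {"bend": "bend_circular", "cp": "coupler_full", "wg": "straight"},
    "connections": {"bend,o2": "cp,o1", "cp,o3": "wg,o1"},
    "ports": {"in": "bend,o1", "out": "wg,o2"},
}
new_netlist, required = add_interfaces(netlist)
print(required)
```

The returned `required` list plays the role of the "Required Interface Models" in the desired output; the user would still have to supply a model for each pair.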
With jax update to 0.4.24, the installation of sax is broken:
$ pip install sax
...
$ python -c 'import sax'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File ".../python3.11/site-packages/sax/__init__.py", line 19, in <module>
from . import backends as backends
File ".../python3.11/site-packages/sax/backends/__init__.py", line 40, in <module>
from .klu import analyze_circuit_klu, analyze_instances_klu, evaluate_circuit_klu
File ".../python3.11/site-packages/sax/backends/klu.py", line 9, in <module>
import klujax
File ".../python3.11/site-packages/klujax.py", line 164, in <module>
@xla_register_cpu(coo_mul_vec_c128, klujax_cpp.coo_mul_vec_c128)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../python3.11/site-packages/klujax.py", line 64, in decorator
xla.backend_specific_translations["cpu"][primitive] = partial(fun, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'jax.interpreters.xla' has no attribute 'backend_specific_translations'
Not sure if this should be fixed at the sax level (by pinning jax and jaxlib to 0.4.23 as here) or if it should be fixed in klujax.
I can confirm that using pip install sax jax==0.4.23 jaxlib==0.4.23 works fine.
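To check whether an environment is on a jax version at or below the one reported to work, a small standard-library sketch such as the following could be used. The helper names `parse_version` and `jax_is_known_good` are hypothetical, and the threshold 0.4.23 is taken from the report above rather than from any official compatibility statement.

```python
from importlib import metadata


def parse_version(text):
    """Turn '0.4.23' into (0, 4, 23); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def jax_is_known_good(last_good="0.4.23"):
    """Return True/False for the installed jax, or None if jax is missing."""
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None
    return parse_version(installed) <= parse_version(last_good)


print(jax_is_known_good())
```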
How can we simulate hierarchical components?
See, for example, two MZIs connected to each other.
I am trying to create a simple MZI circuit where I can give complex parameters to each coupler. For example:
import jax.numpy as jnp
import sax


def coupler(S31=1 / jnp.sqrt(2), S41=1j / jnp.sqrt(2)) -> sax.SDict:
    # in0 -> out0 and in1 -> out1 use S31; the two cross paths use S41.
    coupler_dict = sax.reciprocal(
        {
            ("in0", "out0"): S31,
            ("in0", "out1"): S41,
            ("in1", "out0"): S41,
            ("in1", "out1"): S31,
        }
    )
    return coupler_dict


def MZI_arms(phase_top=0.0, phase_bottom=0) -> sax.SDict:
    # Two uncoupled arms, each applying its own phase.
    _sdict = sax.reciprocal(
        {
            ("in0", "out0"): jnp.exp(1j * phase_top),
            ("in1", "out1"): jnp.exp(1j * phase_bottom),
        }
    )
    return _sdict


mzi, info = sax.circuit(
    netlist={
        "instances": {
            "BS1": "coupler",
            "PS1": "phase_shifter",
            "BS2": "coupler",
            "PS2": "phase_shifter",
        },
        "connections": {
            "BS1,out0": "PS1,in0",
            "BS1,out1": "PS1,in1",
            "PS1,out0": "BS2,in0",
            "PS1,out1": "BS2,in1",
            "BS2,out0": "PS2,in0",
            "BS2,out1": "PS2,in1",
        },
        "ports": {
            "in0": "BS1,in0",
            "in1": "BS1,in1",
            "out0": "PS2,out0",
            "out1": "PS2,out1",
        },
    },
    models={
        "coupler": coupler,
        "phase_shifter": MZI_arms,
    },
)
This returns a warning for me as follows:
[.../lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2089]: ComplexWarning: Casting complex values to real discards the imaginary part
out_array: Array = lax_internal._convert_element_type(
Is that an issue? I am not sure I should trust the final result...
How could we replace sipann with Jax based neural network models?
How can we install sax on Windows?
I think the SAX installation instructions for Windows no longer work:
set PIP_FIND_LINKS="https://whls.blob.core.windows.net/unstable/index.html"
pip install sax[jax]
https://github.com/gdsfactory/gdsfactory/actions/runs/5208126572/jobs/9396371280
Sax looks like a very interesting and useful package!
I was wondering where the code for klujax is located, since it doesn't seem to be part of sax.
Additionally, as I was reading the docs on the KLU backend, I was surprised to learn that the backend does not support gradients. If I understand the approach correctly, the calculation is a combination of matrix multiplications and a linear solve (using klujax). The gradient for the linear solve operation should just involve another linear solve, meaning that you should be able to define the gradient rule for JAX to make another call to klujax. Am I missing something?
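The claim that the gradient of a linear solve is itself a linear solve can be checked on a tiny example. The sketch below uses plain Python and a real 2x2 system (SAX works with complex matrices, and klujax is not used here), so it only illustrates the adjoint identity: for x = A⁻¹b and a scalar loss L = w·x, solving Aᵀλ = w gives ∂L/∂b = λ and ∂L/∂A = -λxᵀ. All function names are hypothetical.

```python
def solve2(a, b):
    """Solve the 2x2 system a @ x = b with Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    x0 = (b[0] * a[1][1] - a[0][1] * b[1]) / det
    x1 = (a[0][0] * b[1] - b[0] * a[1][0]) / det
    return [x0, x1]


def transpose2(a):
    return [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]


def loss(a, b, w):
    """Scalar loss L = w . x with x = solve(a, b)."""
    x = solve2(a, b)
    return w[0] * x[0] + w[1] * x[1]


def loss_gradients(a, b, w):
    """Gradients of the loss via one extra (adjoint) linear solve."""
    x = solve2(a, b)                # forward solve
    lam = solve2(transpose2(a), w)  # adjoint solve: A^T lam = w
    grad_b = lam
    grad_a = [[-lam[i] * x[j] for j in range(2)] for i in range(2)]
    return grad_a, grad_b


a = [[4.0, 1.0], [2.0, 3.0]]
b = [1.0, 2.0]
w = [1.0, -1.0]
grad_a, grad_b = loss_gradients(a, b, w)

# Compare against a finite-difference estimate for one entry of b.
eps = 1e-6
fd_b0 = (loss(a, [b[0] + eps, b[1]], w) - loss(a, b, w)) / eps
print(grad_b[0], fd_b0)
```

In JAX, a rule of this form would typically be registered as a custom derivative for the solve primitive, so that the backward pass calls the same solver on the transposed system.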
This notebook has some issues with the latest release of sax:
https://github.com/gdsfactory/gplugins/blob/main/notebooks/sax_01_sax.ipynb
Can we add a changelog to SAX?
We are now getting some strange issues:
https://github.com/gdsfactory/gplugins/actions/runs/6230788778/job/16911327047
Hi Floris,
What do you think of having jax as an optional dependency?
That way Windows users can still run circuit simulations.
Maybe we could have a mode where sax tries to import jax and, if that fails, some functionality (such as circuit simulations) is still available.
Let me know what you think.
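The "try to import jax, degrade gracefully if it fails" idea is a common Python pattern and could look roughly like this. It is a generic sketch, not SAX code: `optional_import`, `HAS_JAX`, and `require_jax` are hypothetical names.

```python
import importlib


def optional_import(name):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


jax = optional_import("jax")
HAS_JAX = jax is not None


def require_jax(feature):
    """Raise a clear error for features that cannot work without jax."""
    if not HAS_JAX:
        raise ImportError(
            f"{feature} requires jax; install it to enable this feature."
        )


print("jax available:", HAS_JAX)
```

Functions that need autograd or JIT compilation would call `require_jax(...)` first, while everything else could fall back to another array library.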
How could we add time domain simulations to SAX?
similar to photontorch
Missing f prefix in the f-string at line 200 of netlist.py:
~/mambaforge/lib/python3.10/site-packages/sax/netlist.py in _model_operations(models, ops, default_models)
198
199 if "model" not in model:
--> 200 raise ValueError(
201 "Invalid model dict for '{component}'. Key 'model' not found."
202 )
ValueError: Invalid model dict for '{component}'. Key 'model' not found.
This should be:
raise ValueError(
    f"Invalid model dict for '{component}'. Key 'model' not found."
)
Just a heads up: with jax==0.4.26 and flax==0.7.*, I was getting errors when importing sax (installed recently with pip install --upgrade sax):
>>> import sax
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/opt/homebrew/lib/python3.11/site-packages/sax/__init__.py", line 15, in <module>
from flax.core.frozen_dict import FrozenDict as FrozenDict
File "/opt/homebrew/lib/python3.11/site-packages/flax/__init__.py", line 19, in <module>
from .configurations import (
File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 93, in <module>
flax_filter_frames = define_bool_state(
^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 42, in define_bool_state
return jax_config.define_bool_state('flax_' + name, default, help)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Config' object has no attribute 'define_bool_state'
It was fixed when pip install --upgrade flax installed 0.8.*.
I'm not sure whether you need to know this for setting requirements, but I thought I'd let you know and create a paper trail in case others see this. Feel free to close.
Hi @flaport
I'm curious if you have a suggestion for how to compute S-matrices for the circuit components in parallel. For example, let's say the coupler and waveguide functions involve running some simulations and I'd like to kick those off at the same time. Is there a way to handle this in the current state of sax, or would I need to make a fork and change the internals to do some multi-threading? Just curious if you have any ideas about this, thanks!
mzi, _ = sax.circuit(
    netlist={
        "instances": {
            "lft": coupler,
            "top": waveguide,
            "rgt": coupler,
        },
        "connections": {
            "lft,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    }
)
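One way to approach this without touching SAX internals, sketched here with the standard library only, is to run the expensive model functions concurrently beforehand and then hand cheap wrappers around the precomputed results to the circuit. `slow_coupler`, `slow_waveguide`, and `evaluate_in_parallel` are hypothetical stand-ins, and the `time.sleep` calls merely imitate a long-running simulation.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_coupler():
    time.sleep(0.1)  # stand-in for a long-running simulation
    return {("in0", "out0"): 0.5 ** 0.5, ("in0", "out1"): 1j * 0.5 ** 0.5}


def slow_waveguide():
    time.sleep(0.1)  # stand-in for a long-running simulation
    return {("in0", "out0"): 1.0 + 0.0j}


def evaluate_in_parallel(model_fns):
    """Run each model function in its own thread and collect the S-dicts."""
    with ThreadPoolExecutor(max_workers=len(model_fns)) as pool:
        futures = {name: pool.submit(fn) for name, fn in model_fns.items()}
        return {name: future.result() for name, future in futures.items()}


sdicts = evaluate_in_parallel(
    {"coupler": slow_coupler, "waveguide": slow_waveguide}
)
# Wrap each precomputed result in a cheap function usable as a model.
models = {name: (lambda s=s: s) for name, s in sdicts.items()}
print(sorted(models))
```

Threads only help when the simulations release the GIL (for example, external solvers or I/O); for CPU-bound pure-Python work, `ProcessPoolExecutor` from the same module would be the more suitable choice.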
SAX should better handle 'invalid' netlist names. At the very least a warning should be issued.
Also see gdsfactory/gdsfactory#667 (comment)
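A warning of this kind could be sketched as a small pre-check on the netlist. Which names count as 'invalid' is not spelled out above, so the two rules used here are assumptions: names containing a comma (which would clash with the "instance,port" connection syntax) and names that are not valid Python identifiers (which could not be used as keyword arguments). `check_instance_names` is a hypothetical helper.

```python
import warnings


def check_instance_names(netlist):
    """Warn about instance names that are likely to cause trouble."""
    bad = []
    for name in netlist["instances"]:
        # Assumed rules: no commas, and usable as a Python keyword argument.
        if "," in name or not name.isidentifier():
            warnings.warn(f"Instance name {name!r} may be invalid in a netlist.")
            bad.append(name)
    return bad


netlist = {
    "instances": {"wg": "straight", "bend,1": "bend", "my-ring": "ring"},
}
print(check_instance_names(netlist))
```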
Hi Floris, we have a notebook demonstrating sax. It seems to error at cell [18] for some users and not for others. There seem to be only minor differences in the dependencies and we can't figure out what is causing this discrepancy.
Do you have any suggestions for things to look into here? We're stumped after testing several different dependencies.
This is the stack trace
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[17], line 2
1 # how to pass specific parmaeters to each of the sub-functions for the instances
----> 2 s = circuit_fn(splitter={"params": params0}, combiner={"params": 0 * params0}, beta=3, phase_sifter=dict(phi=2.0))
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/saxtypes.py:267, in sdict.<locals>.wrapper(**kwargs)
265 @functools.wraps(model)
266 def wrapper(**kwargs):
--> 267 return sdict(model(**kwargs))
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/circuit.py:226, in _flat_circuit.<locals>._circuit(**settings)
223 for inst_name, model in inst2model.items():
224 instances[inst_name] = model(**full_settings.get(inst_name, {}))
--> 226 S = evaluate_fn(analyzed, instances)
227 return S
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/sax/backends/klu.py:122, in evaluate_circuit_klu(analyzed, instances)
117 idx += len(ports_map)
119 Sx = jnp.concatenate(
120 [jnp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1
121 )
--> 122 CSx = Sx[..., mask]
123 Ix = jnp.ones((*batch_shape, n_col))
124 I_CSx = jnp.concatenate([-CSx, Ix], -1)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/array.py:319, in ArrayImpl.__getitem__(self, idx)
317 return lax_numpy._rewriting_take(self, idx)
318 else:
--> 319 return lax_numpy._rewriting_take(self, idx)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4152, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
4146 if (isinstance(aval, core.DShapedArray) and aval.shape == () and
4147 dtypes.issubdtype(aval.dtype, np.integer) and
4148 not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
4149 isinstance(arr.shape[0], int)):
4150 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
4153 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
4154 unique_indices, mode, fill_value)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4230, in _split_index_for_jit(idx, shape)
4226 raise TypeError(f"JAX does not support string indexing; got {idx=}")
4228 # Expand any (concrete) boolean indices. We can then use advanced integer
4229 # indexing logic to handle them.
-> 4230 idx = _expand_bool_indices(idx, shape)
4232 leaves, treedef = tree_flatten(idx)
4233 dynamic = [None] * len(leaves)
File ~/miniconda3/envs/flex/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4552, in _expand_bool_indices(idx, shape)
4550 expected_shape = shape[start: start + _ndim(i)]
4551 if i_shape != expected_shape:
-> 4552 raise IndexError("boolean index did not match shape of indexed array in index "
4553 f"{dim_number}: got {i_shape}, expected {expected_shape}")
4554 out.extend(np.where(i))
4555 else:
IndexError: boolean index did not match shape of indexed array in index 1: got (18,), expected (14,)
And here is the pip freeze output for the erroring case (Python 3.11 on Ubuntu):
Package Version Editable project location
----------------------------- --------------- ---------------------------------------------------------------
absl-py 2.1.0
accessible-pygments 0.0.4
alabaster 0.7.16
annotated-types 0.6.0
anyio 4.2.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
astroid 3.0.3
asttokens 2.4.1
async-lru 2.0.4
attrs 23.2.0
autograd 1.6.2
Babel 2.14.0
beautifulsoup4 4.12.3
black 23.12.1
bleach 6.1.0
boto3 1.23.1
botocore 1.26.10
cached-property 1.5.2
cachetools 5.3.2
certifi 2024.2.2
cffi 1.16.0
cfgv 3.4.0
chardet 5.2.0
charset-normalizer 3.3.2
chex 0.1.82
click 8.0.3
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
commonmark 0.9.1
contourpy 1.2.0
coverage 7.4.1
cycler 0.12.1
dask 2023.10.1
dataclasses-json 0.6.4
debugpy 1.8.0
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.8
distlib 0.3.8
docutils 0.20.1
entrypoints 0.4
etils 1.6.0
exceptiongroup 1.2.0
executing 2.0.1
fastcore 1.5.29
fastjsonschema 2.19.1
filelock 3.13.1
flax 0.7.4
flow360scripts 0.0.1
fonttools 4.47.2
fqdn 1.5.1
fsspec 2024.2.0
future 0.18.3
gdspy 1.6.13
gdstk 0.9.50
gitdb 4.0.11
GitPython 3.1.41
gltflib 1.0.13
grcwa 0.1.2
h11 0.14.0
h2 4.1.0
h5netcdf 1.0.2
h5py 3.10.0
hpack 4.0.0
httpcore 1.0.3
httpx 0.26.0
hyperframe 6.0.1
identify 2.5.33
idna 3.6
imagesize 1.4.1
importlib-metadata 6.11.0
importlib-resources 6.1.1
iniconfig 2.0.0
ipykernel 6.28.0
ipython 8.21.0
ipywidgets 8.1.1
isoduration 20.11.0
isort 5.13.2
jax 0.4.14
jaxlib 0.4.14
jaxtyping 0.2.25
jedi 0.19.1
Jinja2 3.1.3
jmespath 1.0.1
json5 0.9.14
jsonpointer 2.4
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter_client 8.6.0
jupyter-console 6.6.3
jupyter_core 5.7.1
jupyter-events 0.9.0
jupyter-lsp 2.2.2
jupyter_server 2.12.5
jupyter-server-mathjax 0.2.6
jupyter_server_terminals 0.5.2
jupyterlab 4.0.12
jupyterlab_pygments 0.3.0
jupyterlab_server 2.25.2
jupyterlab-widgets 3.0.9
kiwisolver 1.4.5
klujax 0.2.4
locket 1.0.0
markdown-it-py 3.0.0
MarkupSafe 2.1.3
marshmallow 3.20.2
matplotlib 3.8.2
matplotlib-inline 0.1.6
mccabe 0.7.0
mdit-py-plugins 0.4.0
mdurl 0.1.2
memory-profiler 0.61.0
mistune 3.0.2
ml-dtypes 0.3.2
mpmath 1.3.0
msgpack 1.0.7
multiprocess 0.70.16
mypy-extensions 1.0.0
myst-parser 2.0.0
natsort 8.4.0
nbclient 0.8.0
nbconvert 7.16.1
nbdime 4.0.1
nbformat 5.9.2
nbsphinx 0.9.3
nest_asyncio 1.6.0
networkx 2.8.8
nodeenv 1.8.0
notebook 7.0.8
notebook_shim 0.2.3
numpy 1.26.3
opt-einsum 3.3.0
optax 0.1.9
orbax-checkpoint 0.5.3
orjson 3.9.13
overrides 7.7.0
packaging 23.2
pandas 2.2.0
pandocfilters 1.5.0
parso 0.8.3
partd 1.4.1
pathspec 0.12.1
pexpect 4.9.0
pillow 10.2.0
pip 23.3.1
pkgutil_resolve_name 1.3.10
platformdirs 4.2.0
pluggy 1.4.0
ply 3.11
pre-commit 3.6.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 4.25.2
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
pybind11 2.11.1
pycparser 2.21
pydantic 2.6.3
pydantic_core 2.16.3
pydata-sphinx-theme 0.15.2
Pygments 2.17.2
PyJWT 2.8.0
pylint 3.0.3
PyMieScatt 1.8.1.1
pyparsing 3.1.1
pyproject-api 1.6.1
pyroots 0.5.0
pyrsistent 0.20.0
pyswarms 1.3.0
pytest 8.0.0
pytest-timeout 2.2.0
python-dateutil 2.8.2
python-json-logger 2.0.7
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
qtconsole 5.5.1
QtPy 2.4.1
referencing 0.33.0
requests 2.28.2
responses 0.24.1
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 12.5.1
rpds-py 0.17.1
Rtree 1.0.1
ruff 0.2.1
s3transfer 0.5.2
sax 0.12.1
scipy 1.12.0
Send2Trash 1.8.2
setuptools 68.2.2
shapely 2.0.2
signac 2.1.0
six 1.16.0
smmap 5.0.1
sniffio 1.3.0
snowballstemmer 2.2.0
soupsieve 2.5
Sphinx 7.2.6
sphinx-book-theme 1.1.0
sphinx-copybutton 0.5.2
sphinx-sitemap 2.5.1
sphinx-tabs 3.4.5
sphinxcontrib-applehelp 1.0.8
sphinxcontrib-devhelp 1.0.6
sphinxcontrib-htmlhelp 2.0.5
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.7
sphinxcontrib-serializinghtml 1.1.10
sphinxemoji 0.3.1
stack-data 0.6.2
sympy 1.12
synced-collections 1.0.0
tensorstore 0.1.53
terminado 0.18.0
tidy3d 2.6.0rc1 /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_frontend
tidy3d-beta 1.9.0
tidy3d-denormalizer 0.1.0 /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d-denormalizer
tidy3d_pipeline 2.6.0rc1 /home/momchil/Drive/flexcompute/tidy3d-core/tidy3d_pipeline
tinycss2 1.2.1
tmm 0.1.8
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.3
toolz 0.12.1
tornado 6.4
tox 4.12.1
tqdm 4.66.1
traitlets 5.14.1
trimesh 3.20.0
typeguard 2.13.3
types-python-dateutil 2.8.19.20240106
typing_extensions 4.9.0
typing-inspect 0.9.0
typing-utils 0.1.0
tzdata 2023.4
uri-template 1.3.0
urllib3 1.26.18
virtualenv 20.25.0
vtk 9.2.6
wcwidth 0.2.13
webcolors 1.13
webencodings 0.5.1
websocket-client 1.7.0
wheel 0.41.2
widgetsnbextension 4.0.9
xarray 2023.12.0
zipp 3.17.0