cgarciae / treeo
A small library for creating and manipulating custom JAX Pytree classes
Home Page: https://cgarciae.github.io/treeo
License: MIT License
I'm running into some issues when trying to stack a list of treeo.Tree objects into a single object. I've made a short example:
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.array = to.field(node=True)  # I am a node field!
    age_static: jnp.array = to.field(node=False)  # I am a static field!, I should not be updated.
    name: str = to.field(node=False)  # I am a static field!

persons = [
    Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
    Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
    Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
]

# Stack (struct of arrays instead of list of structs)
jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
However, this fails with the following exception:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[1], line 18
11 name: str = to.field(node=False) # I am a static field!
13 persons = [
14 Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
15 Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
16 Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
17 ]
---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
166 """Maps a multi-input function over pytree args to produce a new pytree.
167
168 Args:
(...)
196 [[5, 7, 9], [6, 1, 2]]
197 """
198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
166 """Maps a multi-input function over pytree args to produce a new pytree.
167
168 Args:
(...)
196 [[5, 7, 9], [6, 1, 2]]
197 """
198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
Versions used:
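From a certain perspective this is expected, because jax.tree_map does not map over static (node=False) fields: they are stored in each tree's structure (the custom node data shown in the error) rather than as leaves, and since the static values differ between instances (e.g. name="John" vs. name="Wald"), the structures don't match. So in this sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still be able to stack objects like this with static fields.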
Has anyone tried something similar and come up with a nice solution?
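One possible workaround, sketched here under assumptions: since the static fields are what make the structures differ, stack only the leaves, reuse the structure (and therefore the static values) of the first instance, and keep the per-instance static values in a separate list. The helper `stack_nodes` is hypothetical, and to stay self-contained the sketch replaces the treeo class with a plain JAX pytree `Person` registered via `jax.tree_util.register_pytree_node_class`, mirroring the layout visible in the error message: `height` is the only leaf and the static fields live in the auxiliary data.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Stand-in for the treeo Person: `height` is a leaf, the other fields are
# static and are stored in the auxiliary data (part of the treedef).
@jax.tree_util.register_pytree_node_class
@dataclass
class Person:
    height: jax.Array
    age_static: float
    name: str

    def tree_flatten(self):
        return (self.height,), (self.age_static, self.name)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


def stack_nodes(trees):
    """Hypothetical helper: stack the leaves of pytrees whose static data
    differs. The static data of the result is taken from trees[0]."""
    leaves0, treedef = jax.tree_util.tree_flatten(trees[0])
    all_leaves = [jax.tree_util.tree_leaves(t) for t in trees]
    if any(len(leaves) != len(leaves0) for leaves in all_leaves):
        raise ValueError("all trees must have the same number of leaves")
    stacked = [jnp.stack(xs, axis=0) for xs in zip(*all_leaves)]
    return jax.tree_util.tree_unflatten(treedef, stacked)


persons = [
    Person(height=jnp.array(1.8), age_static=25.0, name="John"),
    Person(height=jnp.array(1.7), age_static=100.0, name="Wald"),
    Person(height=jnp.array(2.1), age_static=50.0, name="Karen"),
]

batched = stack_nodes(persons)                      # batched.height has shape (3,)
names = [p.name for p in persons]                   # static values kept separately
ages = jnp.array([p.age_static for p in persons])
```

The trade-off is that `batched.name` is simply `"John"`: static values that matter per instance either have to travel alongside the batch (as `names` and `ages` do here) or be declared as node fields so that they become leaves, which only works for array-like values.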
Currently, for convenience, all non-Tree fields that are not explicitly declared are treated as static fields, since most such fields actually are static. However, in more complex applications a traced array might be passed where a static field is usually expected.
A simple solution is to change the current node policy to treat any field containing a TracedArray as a node, which would be the same as the current policy for Tree fields.
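A rough sketch of the proposed policy, using a plain dataclass registered as a JAX pytree rather than treeo's actual node-policy code: undeclared fields stay static unless they currently hold a tracer, in which case flattening moves them into the leaves. The class `Config`, its field `scale`, and the function `is_node_value` are hypothetical; the tracer check relies on `jax.core.Tracer`.

```python
from dataclasses import dataclass, fields

import jax
import jax.numpy as jnp


def is_node_value(value) -> bool:
    # Hypothetical policy: a field holding a tracer is always treated as a node
    return isinstance(value, jax.core.Tracer)


@jax.tree_util.register_pytree_node_class
@dataclass
class Config:
    scale: float = 1.0  # undeclared field: static unless it holds a tracer

    def tree_flatten(self):
        names = [f.name for f in fields(self)]
        # Decide per instance which fields become leaves
        nodes = tuple(n for n in names if is_node_value(getattr(self, n)))
        children = tuple(getattr(self, n) for n in nodes)
        static = tuple((n, getattr(self, n)) for n in names if n not in nodes)
        return children, (nodes, static)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        nodes, static = aux_data
        return cls(**dict(zip(nodes, children)), **dict(static))


print(jax.tree_util.tree_leaves(Config(scale=2.0)))  # [] : scale is static here


@jax.jit
def g(x):
    inner = Config(scale=x)  # scale now holds a tracer, so it becomes a leaf
    return jax.tree_util.tree_leaves(inner)[0] * 2
```

Because the split into leaves and static data is recorded in the auxiliary data, unflattening restores each field to the right place whichever way it was classified. Without the tracer check, `tree_leaves(inner)` inside `g` would be empty and the tracer would end up in the static data instead.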
Hi,
Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:
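from dataclasses import dataclass, field
import jax.numpy as jnp
import treeo as to

class KindOne: pass
class KindTwo: pass

@dataclass
class SubModel(to.Tree):
    parameter: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=KindOne
    )

@dataclass
class Model(to.Tree):
    parameter: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=KindTwo
    )
    submodel: SubModel = field(default_factory=SubModel)  # nested Tree field

m = Model()
m.unique_kinds()  # desired result: [KindOne, KindTwo]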
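A self-contained sketch of the idea that does not touch treeo's internal field metadata: here the kind is stored under a hypothetical `"kind"` key in standard `dataclasses.field` metadata, and a recursive walk over nested dataclasses collects each kind once, in depth-first order. `kind_field` and `unique_kinds` are hypothetical helpers, not treeo API.

```python
from dataclasses import dataclass, field, fields, is_dataclass


class KindOne: pass
class KindTwo: pass


def kind_field(default, kind):
    # Hypothetical helper: record the kind in the standard dataclass metadata
    return field(default=default, metadata={"kind": kind})


@dataclass
class SubModel:
    parameter: float = kind_field(1.0, KindOne)


@dataclass
class Model:
    parameter: float = kind_field(1.0, KindTwo)
    submodel: SubModel = field(default_factory=SubModel)


def unique_kinds(obj) -> list:
    """Collect the kinds of all fields, depth-first, without duplicates."""
    seen = []

    def visit(node):
        for f in fields(node):
            kind = f.metadata.get("kind")
            if kind is not None and kind not in seen:
                seen.append(kind)
            value = getattr(node, f.name)
            if is_dataclass(value):  # recurse into nested dataclasses
                visit(value)

    visit(obj)
    return seen


kinds = unique_kinds(Model())
```

Order aside, this yields the two kinds from the example; because `Model.parameter` is visited before the nested `SubModel`, the list comes out as `[KindTwo, KindOne]`.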
Firstly, thanks for creating Treeo - it's a fantastic package.
Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE:
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import treeo as to

class Parameter:
    def transform(self):
        return jnp.exp(self)

@dataclass
class Model(to.Tree):
    lengthscale: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )
Is there a way I could do something similar to the following pseudocode snippet?
m = Model()
jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
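The pseudocode can't work as written because the values passed to the mapped function are the leaf arrays, not `Parameter` instances. The kind is a class attached to the field, so a method on it has to receive the leaf as an argument, for example as a `staticmethod`. With treeo, the analogous call would presumably be `jax.tree_map(Parameter.transform, to.filter(m, Parameter))`, but that has not been checked here. The sketch below avoids treeo and stores the kind under a hypothetical `"kind"` key in standard dataclass metadata; `apply_kind_method` and the `offset` field are hypothetical.

```python
import dataclasses
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


class Parameter:
    # The kind is a class, not the leaf, so its method receives the leaf value
    @staticmethod
    def transform(x):
        return jnp.exp(x)


@dataclass
class Model:
    lengthscale: jax.Array = field(
        default_factory=lambda: jnp.array([1.0]), metadata={"kind": Parameter}
    )
    offset: jax.Array = field(default_factory=lambda: jnp.array([0.5]))  # no kind


def apply_kind_method(obj, method_name):
    """Hypothetical helper: call kind.<method_name>(value) on every field whose
    kind defines that method; all other fields are left unchanged."""
    updates = {}
    for f in dataclasses.fields(obj):
        kind = f.metadata.get("kind")
        method = getattr(kind, method_name, None)
        if method is not None:
            updates[f.name] = method(getattr(obj, f.name))
    return dataclasses.replace(obj, **updates)


m = apply_kind_method(Model(), "transform")  # lengthscale -> exp(1.0), offset unchanged
```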
import jax
import jax.numpy as jnp
import treeo as to

class A(to.Tree):
    X: jnp.array = to.field(node=True)

    def __init__(self):
        self.X = jnp.ones((50, 50))

    @jax.jit
    def f(self, Y):
        return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)

Y = jnp.ones(2)
for i in range(5):
    print(A.f._cache_size())
    a = A()
    a.f(Y)
With jax 0.3.15 the output of the above is 0 1 2 2 2, i.e. the function is compiled a second time on the second call. With jax 0.3.10 it works fine and the output is 0 1 1 1 1. I have no idea what's happening. Thanks.
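One mechanism that produces a growing cache like this, offered as a possible explanation rather than a diagnosis of the 0.3.15 change: `jax.jit` keys its cache on the pytree structure of the arguments, including any static (auxiliary) data, so an argument whose static data differs, or does not compare equal, between calls forces a retrace. The sketch uses a plain pytree class with a hypothetical static `tag` field and counts traces with a Python side effect, which only runs while tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.tree_util.register_pytree_node_class
class A:
    def __init__(self, X, tag):
        self.X = X      # node: becomes a traced leaf
        self.tag = tag  # static: part of the treedef, hence of the jit cache key

    def tree_flatten(self):
        return (self.X,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(children[0], tag)


@jax.jit
def f(a, Y):
    global trace_count
    trace_count += 1  # Python side effect: only runs while tracing
    return jnp.sum(Y ** 2) * jnp.sum(a.X ** 2)


Y = jnp.ones(2)
for tag in ["same", "same", "other", "other"]:
    f(A(jnp.ones((50, 50)), tag), Y)
print(trace_count)  # 2: one trace per distinct static value
```

If the static data of each new instance compared unequal to the previous one (for example, objects that only compare by identity), every call would retrace, so checking what a class puts into its auxiliary data is a reasonable first step when the cache keeps growing.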
Now that jax 0.3.0 and jaxlib 0.3.0 have been released, the version constraints in pyproject.toml are outdated (lines 16 to 17 in a402f3f).
Those lines correspond to the version constraint jax<0.3.0,>=0.2.18 (https://python-poetry.org/docs/dependency-specification/#caret-requirements). Now that jax v0.3.0 has been released (https://github.com/google/jax/releases/tag/jax-v0.3.0), treeo cannot be installed together with the latest version. I think the same applies to jaxlib, since it was also upgraded to v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).
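Poetry's caret rule, as described in the linked documentation, allows updates that do not change the left-most non-zero version component, so the range jax<0.3.0,>=0.2.18 is exactly what a caret constraint on 0.2.18 yields. The hypothetical helpers below are a simplified model of that rule for purely numeric versions; they are not Poetry's implementation and ignore pre-releases.

```python
def caret_range(version):
    """Return (lower, upper) bounds, as int tuples, for the caret constraint ^version."""
    parts = [int(p) for p in version.split(".")]
    # Bump the left-most non-zero component; if every component is zero,
    # bump the last one that was given (^0.0 -> <0.1, ^0 -> <1).
    idx = next((i for i, p in enumerate(parts) if p != 0), len(parts) - 1)
    upper = parts[:idx] + [parts[idx] + 1] + [0] * (len(parts) - idx - 1)
    return tuple(parts), tuple(upper)


def _pad(parts, n=3):
    # Compare versions of different lengths as if missing components were 0
    parts = tuple(parts)
    return parts + (0,) * (n - len(parts))


def satisfies_caret(version, constraint):
    """True if `version` lies in the half-open range [lower, upper) of ^constraint."""
    lower, upper = caret_range(constraint)
    v = _pad(int(p) for p in version.split("."))
    return _pad(lower) <= v < _pad(upper)


print(caret_range("0.2.18"))               # ((0, 2, 18), (0, 3, 0))
print(satisfies_caret("0.3.0", "0.2.18"))  # False: 0.3.0 is outside the range
```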