
treeo's People

Contributors

cgarciae, github-actions[bot], jiyuuchc, nalzok, thomaspinder

treeo's Issues

Stacking of treeo.Tree

I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.array = to.field(node=True) # I am a node field!
    age_static: jnp.array = to.field(node=False) # I am a static field! I should not be updated.
    name: str = to.field(node=False) # I am a static field!

persons = [
    Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
    Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
    Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
]

# Stack (struct of arrays instead of list of structs)
jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

However, this fails with the following exception:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 18
     11     name: str = to.field(node=False) # I am a static field!
     13 persons = [
     14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
     15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
     16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
     17 ]
---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').

Versions used:

  • JAX: 0.3.20
  • Treeo: 0.0.10

From a certain perspective this is expected: static (node=False) fields are not leaves but part of the tree structure (the "custom node data" in the error above), so persons with different static values have different structures and jax.tree_map refuses to combine them. In that sense this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still stack objects like this that have static fields.
Has anyone tried something similar and come up with a nice solution?
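One common workaround, sketched here under assumptions: stack only the node leaves and rebuild the result with the first instance's tree structure, so the static fields are taken from the first element. The `Person` class below is a hypothetical plain-JAX stand-in for the treeo version (it keeps its static fields in the treedef's aux data, as the error message above suggests treeo does), and `stack_trees` is a hypothetical helper, not part of treeo or JAX.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the treeo `Person`: `height` is a pytree child (node),
# while `age_static` and `name` live in the treedef's aux data (static).
@jax.tree_util.register_pytree_node_class
@dataclass
class Person:
    height: jnp.ndarray
    age_static: float
    name: str

    def tree_flatten(self):
        return (self.height,), (self.age_static, self.name)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], *aux)


def stack_trees(trees):
    """Stack the node leaves of similar pytrees along a new leading axis.

    Static (aux) data is taken from the first tree; the static values of the
    remaining trees are ignored.
    """
    treedef = jax.tree_util.tree_structure(trees[0])
    all_leaves = [jax.tree_util.tree_leaves(t) for t in trees]
    if len({len(leaves) for leaves in all_leaves}) != 1:
        raise ValueError("all trees must have the same number of leaves")
    stacked = [jnp.stack(xs, axis=0) for xs in zip(*all_leaves)]
    return jax.tree_util.tree_unflatten(treedef, stacked)


persons = [
    Person(height=jnp.array(1.8), age_static=25.0, name="John"),
    Person(height=jnp.array(1.7), age_static=100.0, name="Wald"),
    Person(height=jnp.array(2.1), age_static=50.0, name="Karen"),
]
batched = stack_trees(persons)  # batched.height has shape (3,)
```

Because the static values of the other instances are discarded, fields such as `name` that genuinely differ per element need to be collected separately (for example `[p.name for p in persons]`) or declared as nodes instead.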

TracedArrays treated as nodes by default

Currently, for convenience, all undeclared non-Tree fields are treated as static fields, as most such fields actually are. However, in more complex applications a traced array might be passed where a static field is usually expected.

A simple solution is to change the current node policy to treat any field containing a TracedArray as a node, which would match the current policy for Tree fields.
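The proposed policy can be sketched in plain JAX. Below, the hypothetical `Config` class decides at flatten time whether each undeclared attribute is a node or static: tracers (checked with `jax.core.Tracer`) become children, everything else goes into the aux data. This illustrates the idea only; it is not treeo's implementation.

```python
import jax
import jax.numpy as jnp


def is_node(value):
    # Proposed (hypothetical) policy: any traced array is a node.
    return isinstance(value, jax.core.Tracer)


@jax.tree_util.register_pytree_node_class
class Config:
    """Hypothetical container whose undeclared fields default to static."""

    def __init__(self, **fields):
        self.__dict__.update(fields)

    def tree_flatten(self):
        names = sorted(vars(self))
        node_names = tuple(n for n in names if is_node(getattr(self, n)))
        static = tuple((n, getattr(self, n)) for n in names if n not in node_names)
        children = tuple(getattr(self, n) for n in node_names)
        return children, (node_names, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        node_names, static = aux
        obj = cls.__new__(cls)
        obj.__dict__.update(static)
        obj.__dict__.update(zip(node_names, children))
        return obj


@jax.jit
def make_config(scale):
    # `scale * 2` is a tracer here, so it is flattened as a child instead of
    # being baked into the static treedef of the returned value.
    return Config(scale=scale * 2, name="demo")


out = make_config(jnp.float32(2.0))  # out.scale is a concrete array, out.name is "demo"
```

Outside of a transformation the same object keeps plain values static, so `Config(scale=2.0, name="demo")` has no leaves at all.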

Get all unique kinds

Hi,

Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

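from dataclasses import dataclass, field

import jax.numpy as jnp
import treeo as to

class KindOne: pass
class KindTwo: pass

@dataclass
class SubModel(to.Tree):
    parameter: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=KindOne
    )


@dataclass
class Model(to.Tree):
    parameter: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=KindTwo
    )
    # Nested Tree field (Tree-valued fields are treated as nodes)
    submodel: SubModel = field(default_factory=SubModel)

m = Model()

m.unique_kinds() # hypothetical method; desired result: [KindOne, KindTwo]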
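A minimal sketch of such a helper in plain Python: it walks a nested dataclass recursively and collects each field's kind in first-seen order. It assumes (hypothetically) that a field's kind is stored under `metadata["kind"]` on a `dataclasses.field`; treeo's own bookkeeping may differ, so this illustrates the traversal rather than a drop-in treeo method.

```python
import dataclasses
from dataclasses import dataclass, field


class KindOne: pass
class KindTwo: pass


def unique_kinds(obj):
    """Collect the distinct `kind` values declared anywhere in a nested dataclass."""
    kinds = []

    def visit(node):
        if not dataclasses.is_dataclass(node):
            return
        for f in dataclasses.fields(node):
            kind = f.metadata.get("kind")  # hypothetical metadata key
            if kind is not None and kind not in kinds:
                kinds.append(kind)
            visit(getattr(node, f.name))  # recurse into nested dataclasses

    visit(obj)
    return kinds


@dataclass
class SubModel:
    parameter: list = field(default_factory=lambda: [1.0], metadata={"kind": KindOne})


@dataclass
class Model:
    parameter: list = field(default_factory=lambda: [1.0], metadata={"kind": KindTwo})
    submodel: SubModel = field(default_factory=SubModel)


kinds = unique_kinds(Model())  # first-seen order: [KindTwo, KindOne]
```

A list is used instead of a set so the result order is deterministic (outer fields before nested ones).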

Use field kinds within tree_map

Firstly, thanks for creating Treeo - it's a fantastic package.

Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

class Parameter:
    def transform(self):
        return jnp.exp(self)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )

Is there a way that I could do something similar to the following pseudocode snippet?

m = Model()
jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
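Note that `tree_map` passes the mapped function raw leaves (arrays), not instances of the kind class, so `x.transform()` cannot work as written. One workaround, sketched here in plain JAX with hypothetical names (`Fixed`, `params`, `kinds`), is to make `transform` a static method and map over a parallel pytree holding each leaf's kind. This is a substitute technique, not a treeo API.

```python
import jax
import jax.numpy as jnp


class Parameter:
    @staticmethod
    def transform(x):
        # Map an unconstrained value to a positive one.
        return jnp.exp(x)


class Fixed:
    @staticmethod
    def transform(x):
        return x  # leave the value unchanged


# Values and their kinds share the same structure; kind classes are
# unregistered Python objects, so tree_map treats them as leaves.
params = {"lengthscale": jnp.array([1.0]), "offset": jnp.array([0.5])}
kinds = {"lengthscale": Parameter, "offset": Fixed}

transformed = jax.tree_util.tree_map(lambda x, kind: kind.transform(x), params, kinds)
```

Keeping the kinds in a separate tree leaves the value tree made of plain arrays, so it still works with `jit` and `grad`.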

Jitting twice for a class method

import jax
import jax.numpy as jnp
import treeo as to

class A(to.Tree):
    X: jnp.array = to.field(node=True)
    
    def __init__(self):
        self.X = jnp.ones((50, 50))

    @jax.jit
    def f(self, Y):
        return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)

Y = jnp.ones(2)
for i in range(5):
    print(A.f._cache_size())
    a = A()
    a.f(Y)

With jax 0.3.15 the output of the above is 0 1 2 2 2, i.e. f ends up with two cache entries. With jax 0.3.10 it works as expected and the output is 0 1 1 1 1. I have no idea what's happening. Thanks.
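One way to check whether `f` is genuinely being retraced is to count traces with a Python side effect, which runs only while JAX traces the function. The sketch below uses a hypothetical plain-JAX stand-in for `A` whose treedef carries no aux data; under that assumption `f` is traced once across all five calls. JAX retraces when an argument's tree structure (including its static aux data) or its leaf shapes and dtypes change, so if the treeo version shows extra traces, comparing `jax.tree_util.tree_structure(...)` of two `A()` instances may be a useful next step.

```python
import jax
import jax.numpy as jnp

trace_count = 0


# Hypothetical stand-in for the treeo class: X is the only child, no aux data.
@jax.tree_util.register_pytree_node_class
class A:
    def __init__(self, X):
        self.X = X

    def tree_flatten(self):
        return (self.X,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


@jax.jit
def f(a, Y):
    global trace_count
    trace_count += 1  # Python side effects run only while tracing
    return jnp.sum(Y ** 2) * jnp.sum(a.X ** 2)


Y = jnp.ones(2)
for _ in range(5):
    f(A(jnp.ones((50, 50))), Y)
```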

Relax jax/jaxlib version constraints

Now that jax 0.3.0 and jaxlib 0.3.0 have been released, the version constraints in pyproject.toml are outdated.

treeo/pyproject.toml, lines 16 to 17 at commit a402f3f:

jax = "^0.2.24"
jaxlib = "^0.1.73"

This corresponds to the version constraint jax>=0.2.24,<0.3.0 (https://python-poetry.org/docs/dependency-specification/#caret-requirements), so it excludes the newly released jax v0.3.0 (https://github.com/google/jax/releases/tag/jax-v0.3.0). The same applies to jaxlib: ^0.1.73 means >=0.1.73,<0.2.0, which excludes jaxlib v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).
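The caret arithmetic can be checked with a small sketch of Poetry's documented rule: the upper bound bumps the left-most non-zero component of the requirement. `caret_bounds` and `satisfies` are hypothetical helpers that only handle plain numeric versions, not the full Poetry or PEP 440 grammar.

```python
def caret_bounds(spec):
    """Return (lower, upper) for a Poetry caret requirement like '0.2.24'."""
    parts = [int(p) for p in spec.split(".")]
    # Index of the left-most non-zero component (or the last one if all are zero).
    idx = next((i for i, p in enumerate(parts) if p != 0), len(parts) - 1)
    upper = parts[:idx] + [parts[idx] + 1]
    upper += [0] * (3 - len(upper))
    lower = parts + [0] * (3 - len(parts))
    return ".".join(map(str, lower)), ".".join(map(str, upper))


def _as_tuple(version):
    return tuple(int(p) for p in version.split("."))


def satisfies(version, spec):
    """True if `version` lies in [lower, upper) for the caret requirement `spec`."""
    lower, upper = caret_bounds(spec)
    return _as_tuple(lower) <= _as_tuple(version) < _as_tuple(upper)


jax_range = caret_bounds("0.2.24")     # ("0.2.24", "0.3.0")
jaxlib_range = caret_bounds("0.1.73")  # ("0.1.73", "0.2.0")
```

Under this rule neither `satisfies("0.3.0", "0.2.24")` nor `satisfies("0.3.0", "0.1.73")` holds, which is why both pins need relaxing.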