Use field kinds within tree_map (treeo issue, closed)

cgarciae commented on August 20, 2024
Use field kinds within tree_map

Comments (10)

cgarciae commented on August 20, 2024

Hey @thomaspinder! Thanks for the kind words.

Let me first rule out the easy solution. Is this enough?

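from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

class Parameter:
    pass

def transform(x):
    return jnp.exp(x)

@dataclass
class Model(to.Tree):
    # jnp.ndarray is the array type; jnp.array is the constructor function.
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )

m = Model()
# Apply transform only to the leaves whose field kind is Parameter.
jax.tree_map(transform, to.filter(m, Parameter))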

Note that kinds are just types that serve as metadata attached to a field; they are not expected to be instantiated.

thomaspinder commented on August 20, 2024

Hey @cgarciae, thanks, but perhaps my MWE was an oversimplification. The reason for defining the transform as a method of the kind class is that there can be numerous kind classes, e.g.:

class PositiveParameter():
    def transform(self):
        return jnp.abs(self)

class NegativeParameter():
    def transform(self):
        return jnp.array(-1.) * self 

and so on...

This makes the solution you've proposed a little trickier, as it would need an awkward mapping from each kind to its transform function.
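
For illustration, the kind of mapping this would require might look like the sketch below. It uses plain floats instead of jnp arrays so that it runs without JAX, and TRANSFORMS and apply_transform are hypothetical names, not part of treeo.

```python
# Hypothetical sketch: kinds are empty marker classes, so each transform
# has to live in a separate lookup table keyed by the kind.
class PositiveParameter:
    pass


class NegativeParameter:
    pass


# Every new kind needs a matching entry here, which is the awkward part.
TRANSFORMS = {
    PositiveParameter: lambda x: abs(x),
    NegativeParameter: lambda x: -1.0 * x,
}


def apply_transform(kind: type, value: float) -> float:
    """Look up the transform registered for `kind` and apply it to `value`."""
    return TRANSFORMS[kind](value)


print(apply_transform(PositiveParameter, -2.0))  # 2.0
print(apply_transform(NegativeParameter, 3.0))   # -3.0
```

The table and the kind classes have to be kept in sync by hand, which is what attaching the transform to the kind itself would avoid.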

thomaspinder commented on August 20, 2024

Based on the tidy solution you've provided in #3, one possible solution to this problem could be the following. Do you see any issues with this?

from dataclasses import dataclass
from typing import Set

import jax
import jax.numpy as jnp
import treeo as to


class KindOne:
    def transform(self):
        def transform_fn(x):
            return jnp.abs(x)
        return transform_fn

class KindTwo: 
    def transform(self):
        def transform_fn(x):
            return jnp.array(-1.) * x
        return transform_fn


@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindOne)

@dataclass
class Model(to.Tree):
    submodel: SubModel
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindTwo)

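def unique_kinds(tree: to.Tree) -> Set[type]:
    """Collect every kind attached to a field anywhere in the tree."""
    kinds = set()

    def add_subtree_kinds(subtree: to.Tree):
        for field in subtree.field_metadata.values():
            # Skip fields that were declared without an explicit kind.
            if field.kind is not type(None):
                kinds.add(field.kind)

    # Visit the tree and its sub-trees, collecting kinds as a side effect.
    to.apply(add_subtree_kinds, tree)

    # Return the set itself so the result matches the Set[type] annotation.
    return kinds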


sub_m = SubModel()
m = Model(submodel=sub_m)


for kind in unique_kinds(m):
    transform = kind().transform()
    m = to.map(transform, m, kind)

cgarciae commented on August 20, 2024

@thomaspinder I was guessing you were trying to do this 😅

Here is the solution:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    @staticmethod
    def transform(x):
        return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()

with to.add_field_info():
    m2 = jax.tree_map(lambda field: field.kind.transform(field.value), to.filter(m, Parameter))

print(m2)

The add_field_info function probably needs a section in the User Guide. When a Tree is flattened inside this context manager, its leaves are all of a type called FieldInfo, which among other things carries the kind and value attributes that you can use to achieve what you want. Note that I've changed transform to be a staticmethod.
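
To make the idea concrete, here is a toy sketch in plain Python. ToyFieldInfo is a hypothetical stand-in that only mimics the kind and value attributes mentioned above; it is not treeo's FieldInfo class, and math.exp on floats stands in for jnp.exp so that the snippet runs without JAX or treeo.

```python
import math
from dataclasses import dataclass
from typing import Any


class Parameter:
    @staticmethod
    def transform(x):
        return math.exp(x)


@dataclass
class ToyFieldInfo:
    """Hypothetical stand-in for a leaf that carries field metadata."""

    kind: type
    value: Any


# Pretend these are the leaves produced by flattening a tree with field info enabled.
leaves = [ToyFieldInfo(Parameter, 0.0), ToyFieldInfo(Parameter, 1.0)]

# Each leaf knows its own kind, so the callback can dispatch without a lookup table.
transformed = [leaf.kind.transform(leaf.value) for leaf in leaves]
print(transformed)
```

Because the kind travels with each leaf, the same lambda works for any number of kinds, as long as each one defines a transform staticmethod.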

My thinking is that if this pattern becomes more widespread, it would be convenient to add an add_field_info: bool argument to to.map so you could write something like this:

m2 = to.map(
    lambda field: field.kind.transform(field.value), 
    to.filter(m, Parameter), 
    add_field_info=True,
)

cgarciae commented on August 20, 2024

BTW: Not sure if this is relevant to you but if you are doing something like this:

params = jax.tree_map(some_function, to.filter(m, Parameter))
m = to.merge(m, params)

You can simply use:

m = to.map(some_function, m, Parameter)
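
As a rough mental model of why the two forms are equivalent, here is a toy version with plain dicts. The helpers toy_filter, toy_merge, and toy_map are hypothetical and only imitate the behaviour described above; they are not treeo's implementation.

```python
class Parameter:
    pass


class State:
    pass


# A toy "tree": field name -> (kind, value).
model = {"lengthscale": (Parameter, 1.0), "count": (State, 3)}


def toy_filter(tree, kind):
    """Keep only the fields whose kind matches."""
    return {name: (k, v) for name, (k, v) in tree.items() if k is kind}


def toy_merge(tree, update):
    """Overlay the fields in `update` on top of `tree`."""
    return {**tree, **update}


def toy_map(f, tree, kind):
    """Apply f to the fields of the given kind and leave the rest untouched."""
    return {name: (k, f(v) if k is kind else v) for name, (k, v) in tree.items()}


def double(x):
    return 2 * x


# Long form: filter, map over what is left, then merge back.
params = {name: (k, double(v)) for name, (k, v) in toy_filter(model, Parameter).items()}
long_form = toy_merge(model, params)

# Short form: one call.
short_form = toy_map(double, model, Parameter)

print(long_form == short_form)  # True
```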

cgarciae commented on August 20, 2024

@thomaspinder sure! That solution based on #3 works. For ergonomics you can convert transform to be a staticmethod so you don't have to instantiate the kind.
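
A small sketch of the difference, using plain floats so it runs without JAX (KindInstance and KindStatic are hypothetical names used only here):

```python
class KindInstance:
    # Instance-method version: the kind must be instantiated to reach the function.
    def transform(self):
        def transform_fn(x):
            return abs(x)

        return transform_fn


class KindStatic:
    # Static-method version: the function is reachable from the class itself.
    @staticmethod
    def transform(x):
        return abs(x)


via_instance = KindInstance().transform()(-1.5)
via_class = KindStatic.transform(-1.5)
print(via_instance, via_class)  # 1.5 1.5
```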

thomaspinder commented on August 20, 2024

Thanks so much. The solution you gave using with to.add_field_info() is perfect for my problem. Adding the additional argument to to.map() would be really great. If you ever want a hand with this, e.g. writing tests or documentation, I'd be happy to help.

cgarciae commented on August 20, 2024

@thomaspinder happy to guide you if you want to contribute 🌝
This issue looks very self-contained, so it could be a good starting point.

Ping me if you need anything.

thomaspinder commented on August 20, 2024

Sure! I'd be happy to contribute. Are you able to outline the main steps that I should be mindful of when doing this?

cgarciae commented on August 20, 2024

I think adding a field_info: bool argument to map and then conditionally using the add_field_info context manager over this line should be enough:

https://github.com/cgarciae/treeo/blob/master/treeo/api.py#L197

Also try to add a test 😃. Sadly we don't have a contributing document yet but to start developing do the following:

  1. Install poetry
  2. Run poetry install to install dependencies
  3. Run poetry shell to activate environment.
  4. Run pre-commit install to install pre-commit hooks.
  5. Run pytest to run tests.
