asem000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.

Home Page: https://pytreeclass.rtfd.io/en/latest

License: Apache License 2.0

Python 63.00% Jupyter Notebook 37.00%
dataclasses deep-learning jax machine-learning pytorch tensorflow pytree data pipelines

pytreeclass's People

Contributors

asem000

Forkers

dlwh

pytreeclass's Issues

Move `.at[].freeze()` / `.at[].unfreeze()`

This issue proposes moving .freeze / .unfreeze out of the .at methods.
The rationale is that .freeze works only on treeclass instances, while .at works on both treeclass instances and treeclass leaves.

For example, the following has no effect on the node because it is not a treeclass instance; moreover, it gives the user the wrong impression that the value has been frozen.

import pytreeclass as pytc
from typing import Any

@pytc.treeclass
class Test:
  a: Any = 1

t = Test()
t.at["a"].freeze()

The new design adds tree_freeze and tree_unfreeze to pytreeclass.tree_util.
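A minimal sketch of how standalone tree_freeze / tree_unfreeze functions could behave, assuming a hypothetical Frozen wrapper that hides its value from JAX by exposing no children. Only the two function names come from the proposal; the wrapper and its implementation are illustrative and are not pytreeclass's actual code.

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper: the wrapped value is aux data, so it has no leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # no children; the value travels as static aux data
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def tree_freeze(tree):
    # wrap every leaf so that JAX transformations no longer see it
    return jtu.tree_map(Frozen, tree)


def tree_unfreeze(tree):
    # stop flattening at Frozen nodes and unwrap them
    is_frozen = lambda node: isinstance(node, Frozen)
    return jtu.tree_map(
        lambda node: node.value if is_frozen(node) else node,
        tree,
        is_leaf=is_frozen,
    )


frozen = tree_freeze({"a": 1, "b": 2.0})
print(jtu.tree_leaves(frozen))  # a frozen tree exposes no leaves
print(tree_unfreeze(frozen))
```

Because these are plain functions over any pytree, they apply equally to treeclass instances and to individual leaves, which is the mismatch the issue describes for the .at-based methods.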

Add common recipes to docs

  • How to add a leaf to the instance after instantiation.
  • How to call a function that changes internal state.
  • How to interact with optax using the freeze/unfreeze scheme.
  • How to integrate with flax/equinox/(haiku?).

Add Container data structure to register list/tuple/set/dict as dataclass fields

The motivation for this class is to expose the items of a list/tuple/dict/set as dataclass fields, which enables field-specific operations on the items.
Operations include filtering by type, filtering non-differentiable variables, etc.

For example:

import pytreeclass as pytc 
from pytreeclass.src.container import  Container
from pytreeclass.src.misc import  filter_nondiff
from typing import  Any 

@pytc.treeclass
class Test:
    a: Container  

t = Test(Container({"A":1, "B":2}))

print(t.tree_diagram())
# Test
#     └── a=Container
#         ├── A=1
#         └── B=2     

print( t == int)
# Test(a=Container(A=True,B=True))

print( filter_nondiff(t))
# Test(a=Container(*A=1,*B=2))  # * means static

In contrast, using an unwrapped dict:

import pytreeclass as pytc 
from pytreeclass.src.misc import filter_nondiff
from typing import  Any 

@pytc.treeclass
class Test:
   a: Any  

t = Test(({"A":1, "B":2}))
print(t.tree_diagram())
# Test
#     └── a={'A': 1, 'B': 2}  

print( t == int)
# Test(a=False)

print( filter_nondiff(t))
# Test(a={A:1,B:2})
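For comparison, jax.tree_util already flattens a plain dict into one leaf per item, so an item-wise type check can be sketched without Container. This only illustrates the per-item behaviour the proposal is after; it does not use the treeclass `==` operator shown above.

```python
import jax.tree_util as jtu

items = {"A": 1, "B": 2}

# tree_map visits each dict value as its own leaf
is_int = jtu.tree_map(lambda leaf: isinstance(leaf, int), items)
print(is_int)
# {'A': True, 'B': True}
```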

Document `tree_viz` changes

Breaking changes

  • remove model.summary() in favor of pytc.tree_summary(model)
  • remove model.tree_diagram() in favor of pytc.tree_diagram(model)
  • remove model.tree_box()

Additions

1) tree_summary, tree_diagram, tree_str, and tree_repr are rewritten from scratch to work on any registered PyTree out of the box.

import pytreeclass as pytc 
import jax

@jax.tree_util.register_pytree_node_class
class Tree:
    def __init__(self, x):
        self.x = x
        self.y = x**2
    def tree_flatten(self):
        return (self.x,self.y), None
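    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # bypass __init__, which accepts only `x`, and restore both leaves
        obj = object.__new__(cls)
        obj.x, obj.y = children
        return obj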

tree = Tree(1)

print(pytc.tree_repr(tree))
print(pytc.tree_str(tree))
print(pytc.tree_diagram(tree))
print(pytc.tree_summary(tree))

Tree(leaf_0=1, leaf_1=1)
Tree(leaf_0=1, leaf_1=1)
Tree
    ├── leaf_0=1
    └── leaf_1=1
┌──────┬────┬─────┬──────┐
│Name  │Type│Count│Size  │
├──────┼────┼─────┼──────┤
│leaf_0│int │1    │28.00B│
├──────┼────┼─────┼──────┤
│leaf_1│int │1    │28.00B│
├──────┼────┼─────┼──────┤
│Σ     │Tree│2    │56.00B│
└──────┴────┴─────┴──────┘

Flax

import flax
import jax.numpy as jnp
import pytreeclass as pytc

class FlaxTree(flax.struct.PyTreeNode):
    a: int = 2
    b: int = jnp.ones([2,2])

    def __call__(self, x):
        return self.a * x**2

tree = FlaxTree()

print(pytc.tree_repr(tree))
print(pytc.tree_str(tree))
print(pytc.tree_diagram(tree))
print(pytc.tree_summary(tree))

FlaxTree(a=2, b=f32{2x2}∈[1.0,1.0])
FlaxTree(a=2, b=[[1. 1.] [1. 1.]])
FlaxTree
    ├── a=2
    └── b=f32{2x2}∈[1.0,1.0]
┌────┬────────┬─────┬──────┐
│Name│Type    │Count│Size  │
├────┼────────┼─────┼──────┤
│a   │int     │1    │28.00B│
├────┼────────┼─────┼──────┤
│b   │f32{2x2}│4    │16.00B│
├────┼────────┼─────┼──────┤
│Σ   │FlaxTree│5    │44.00B│
└────┴────────┴─────┴──────┘

2) Added a depth parameter to tree_summary and tree_diagram to control how deep the output goes. The default is depth=float('inf').

import pytreeclass as pytc 

tree = {"d1": {"d2_0": {"d3_0":1},"d2_1": 2}}

With inf depth

print(pytc.tree_diagram(tree))
print(pytc.tree_summary(tree))

dict
    └── ['d1']:dict
        ├── ['d2_0']:dict
        │   └── ['d3_0']=1
        └── ['d2_1']=2
┌──────────────────────┬────┬─────┬──────┐
│Name                  │Type│Count│Size  │
├──────────────────────┼────┼─────┼──────┤
│['d1']['d2_0']['d3_0']│int │1    │28.00B│
├──────────────────────┼────┼─────┼──────┤
│['d1']['d2_1']        │int │1    │28.00B│
├──────────────────────┼────┼─────┼──────┤
│Σ                     │dict│2    │56.00B│
└──────────────────────┴────┴─────┴──────┘

With depth=1

print(pytc.tree_diagram(tree, depth=1))
print(pytc.tree_summary(tree, depth=1))

dict
    └── ['d1']={d2_0:{d3_0:1}, d2_1:2}
┌────┬────┬─────┬──────┐
│Name│Type│Count│Size  │
├────┼────┼─────┼──────┤
│Σ   │dict│2    │56.00B│
└────┴────┴─────┴──────┘

With depth=2

print(pytc.tree_diagram(tree, depth=2))
print(pytc.tree_summary(tree, depth=2))

dict
    └── ['d1']:dict
        ├── ['d2_0']={d3_0:1}
        └── ['d2_1']=2
┌──────────────┬────┬─────┬──────┐
│Name          │Type│Count│Size  │
├──────────────┼────┼─────┼──────┤
│['d1']['d2_0']│dict│1    │28.00B│
├──────────────┼────┼─────┼──────┤
│['d1']['d2_1']│int │1    │28.00B│
├──────────────┼────┼─────┼──────┤
│Σ             │dict│2    │56.00B│
└──────────────┴────┴─────┴──────┘
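The effect of depth on the diagram can be sketched in plain Python for nested dicts. The helper name dict_diagram is hypothetical and the code is not the library's implementation; it only illustrates the rule that a container at or below the depth limit is printed inline instead of being expanded.

```python
import math


def dict_diagram(tree, depth=math.inf):
    """Hypothetical helper: render nested dicts, collapsing nodes below `depth`."""
    lines = [type(tree).__name__]

    def walk(node, prefix, level):
        items = list(node.items())
        for i, (key, value) in enumerate(items):
            last = i == len(items) - 1
            branch = "└── " if last else "├── "
            if isinstance(value, dict) and level < depth:
                # still above the limit: expand this container
                lines.append(f"{prefix}{branch}[{key!r}]:dict")
                walk(value, prefix + ("    " if last else "│   "), level + 1)
            else:
                # leaf, or container at the limit: print it inline
                lines.append(f"{prefix}{branch}[{key!r}]={value!r}")

    walk(tree, "    ", 1)
    return "\n".join(lines)


tree = {"d1": {"d2_0": {"d3_0": 1}, "d2_1": 2}}
print(dict_diagram(tree, depth=2))
# dict
#     └── ['d1']:dict
#         ├── ['d2_0']={'d3_0': 1}
#         └── ['d2_1']=2
```

The inline form here is Python's own dict repr, so it differs slightly from the compact {d3_0:1} notation that pytreeclass prints.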

Technical note:
The new API achieves the above functionality by fetching iter handlers from jax._src.tree_util._registry.
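jax._src is a private module. Where only leaf names are needed, newer JAX releases expose key paths through the public jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr. The following is a sketch under the assumption that the installed JAX version ships both functions; it is not how pytreeclass obtains its names.

```python
import jax.tree_util as jtu

tree = {"d1": {"d2_0": {"d3_0": 1}, "d2_1": 2}}

# each entry pairs a key path with the leaf found at that path
leaves_with_path, treedef = jtu.tree_flatten_with_path(tree)
names = [jtu.keystr(path) for path, _ in leaves_with_path]

for name, (_, leaf) in zip(names, leaves_with_path):
    print(name, type(leaf).__name__, leaf)
```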

3) Changed tree_repr array representation to include min/max stats

The change is inspired by Alexey Zaytsev's lovely-jax.

import pytreeclass as pytc 
import jax 
x = jax.numpy.array([1,2,3])
print(pytc.tree_repr(x))
# i32{3}∈[1,3]
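The format above (dtype, shape, then value range) can be reproduced with a few lines of NumPy. The function name compact_repr is hypothetical and this is only a sketch of the notation, not the tree_repr implementation.

```python
import numpy as np


def compact_repr(x):
    """Hypothetical helper: summarize an array as dtype{shape}∈[min,max]."""
    x = np.asarray(x)
    dtype = f"{x.dtype.kind}{x.dtype.itemsize * 8}"  # e.g. 'i32', 'f32'
    shape = "x".join(str(dim) for dim in x.shape)
    return f"{dtype}{{{shape}}}∈[{x.min().item()},{x.max().item()}]"


print(compact_repr(np.array([1, 2, 3], dtype=np.int32)))
# i32{3}∈[1,3]
print(compact_repr(np.ones((2, 2), dtype=np.float32)))
# f32{2x2}∈[1.0,1.0]
```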

Document `field(callbacks=[...])`

From v0.2, PyTreeClass implements a new dataclasses alternative.

The new implementation supports callbacks in the field: a sequence of functions applied to the input when the attribute is set. Callbacks are useful in several cases, for instance to ensure that an input has a certain type and lies within a valid range. See the example:

import jax 
import pytreeclass as pytc

def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


@pytc.treeclass
class Tree:
    in_features:int = pytc.field(callbacks=[positive_int_callback])


tree = Tree(1)
# no error

tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive

tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer

Other useful applications include type conversion and resolution (canonicalization).
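A minimal sketch of how a callback chain of this kind might be applied, including a canonicalization callback. The helper names apply_callbacks and to_pair are hypothetical, and the error format only mimics the messages shown above; this is not pytreeclass's field implementation.

```python
def apply_callbacks(name, value, callbacks):
    """Hypothetical helper: run `value` through each callback in order."""
    for callback in callbacks:
        try:
            value = callback(value)
        except Exception as error:
            # re-raise the same error type, prefixed with the field name
            raise type(error)(f"Error for field=`{name}`:\n{error}") from error
    return value


def to_pair(value):
    # canonicalization: accept an int or a sequence, always return a tuple
    return (value, value) if isinstance(value, int) else tuple(value)


print(apply_callbacks("kernel_size", 3, [to_pair]))
# (3, 3)
print(apply_callbacks("kernel_size", [3, 5], [to_pair]))
# (3, 5)
```

Because each callback receives the previous callback's return value, validation and conversion steps can be chained in a single list.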

Relevant resources:

Document Broadcasting mapping decorator `bcmap`

bcmap(func, is_leaf) maps a function over PyTree leaves with automatic broadcasting for scalar arguments.

bcmap is a function transformation that broadcasts a scalar to match the first argument of the function. This makes it possible to use a function like jnp.where with arbitrary tree structures without writing a specific function for each broadcasting case.

For example, let's say we want to use jnp.where to zero out all values less than 0 in an arbitrary tree structure:

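import functools as ft
import jax.numpy as jnp
import jax.tree_util as jtu
import pytreeclass as pytc

tree = ([1], {"a":1, "b":2}, (1,), -1,)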

We can use jax.tree_util.tree_map to apply jnp.where to the tree, but we need to write a specific function that broadcasts the scalar to the tree:

def map_func(leaf):
    # here we encoded the scalar `0` inside the function
    return jnp.where(leaf>0, leaf, 0)

jtu.tree_map(map_func, tree)  
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))

However, let's say we want jnp.where to take the replacement value from the matching leaf of another tree. Then a second function is needed:

def map_func2(lhs_leaf, rhs_leaf):
    # here the replacement value comes from the matching leaf of the second tree
    return jnp.where(lhs_leaf>0, lhs_leaf, rhs_leaf)

tree2 = jtu.tree_map(lambda x: 1000, tree)

jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))

Now, bcmap solves this problem by figuring out the broadcasting case.

broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)

case 1

broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(0, dtype=int32, weak_type=True))

case 2

broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
#  {'a': Array(1, dtype=int32, weak_type=True),
#   'b': Array(2, dtype=int32, weak_type=True)},
#  (Array(1, dtype=int32, weak_type=True),),
#  Array(1000, dtype=int32, weak_type=True))

Let's take this a step further and eliminate mask from the equation
by using pytreeclass with leafwise=True.

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    tree : tuple = ([1], {"a":1, "b":2}, (1,), -1,)

tree = Tree()  
# Tree(tree=([1], {a:1, b:2}, (1), -1))

case 1: broadcast scalar to tree

print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))

case 2: broadcast tree to tree
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))

bcmap also works with all kinds of arguments in the wrapped function, for example keyword arguments:

print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))

In conclusion, bcmap is a function transformation that makes functions work with
arbitrary tree structures without the need to write a specific function for each
broadcasting case.

Moreover, bcmap is more powerful when used with pytreeclass, where it
facilitates applying arbitrary functions to PyTree objects
without the need to use tree_map.
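How such a transformation might work internally can be sketched with jax.tree_util alone. The name bcmap_sketch is hypothetical; the sketch handles positional arguments only and treats any argument whose structure differs from the first argument as a scalar to repeat. The real bcmap may differ, for example in how it handles keyword arguments and is_leaf.

```python
import jax.tree_util as jtu


def bcmap_sketch(func):
    """Hypothetical sketch: map `func` over leaves, broadcasting non-matching args."""

    def wrapper(first, *rest):
        leaves, treedef = jtu.tree_flatten(first)

        def expand(arg):
            arg_leaves, arg_def = jtu.tree_flatten(arg)
            if arg_def == treedef:
                return arg_leaves  # same structure: use leaf by leaf
            return [arg] * len(leaves)  # otherwise: repeat for every leaf

        columns = [leaves] + [expand(arg) for arg in rest]
        return treedef.unflatten([func(*args) for args in zip(*columns)])

    return wrapper


def select(condition, x, y):
    # plain-Python stand-in for jnp.where on scalar leaves
    return x if condition else y


tree = ([1], {"a": 1, "b": 2}, (1,), -1)
mask = jtu.tree_map(lambda leaf: leaf > 0, tree)
tree2 = jtu.tree_map(lambda leaf: 1000, tree)

print(bcmap_sketch(select)(mask, tree, 0))      # scalar broadcast to every leaf
print(bcmap_sketch(select)(mask, tree, tree2))  # tree matched leaf by leaf
```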

Move model related viz to `serket`

  • Move summary, tree_box to serket with the aim of simplifying the package structure
  • Also this will enable adding more functionality to tree_summary appropriate to deep learning models (e.g. compact=True)

Move to immutable approach

Currently

.at[].{set,apply,reduce}, .freeze(), and .unfreeze() return a new instance without affecting the original treeclass instance.
Since pytreeclass relies mostly on the functional .at[] methods, it makes more sense to move entirely to .at[] methods to handle the rest of the operations.
Moreover, internal state updates can go undetected in some scenarios, which might cause problems.

This issue proposes a couple of additions for a fully immutable experience.

  • .at["method_name"](*args, **kwargs) should call the method on a new tree and return both new_tree and the call output. However, the __call__ method should remain usable without raising an error as long as it causes no side effects.
  • .at["attribute name"].set() should set the attribute on a new tree instance and return the updated tree.
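The two proposed behaviours can be sketched with a standard dataclass. The helpers call_on_copy and set_on_copy are hypothetical stand-ins for the proposed .at[...] calls, and the return order (new tree first, then output) follows the wording of the first bullet; the final API may differ.

```python
import copy
import dataclasses


@dataclasses.dataclass
class Counter:
    count: int = 0

    def increment(self, step):
        # mutates internal state and returns the new count
        self.count += step
        return self.count


def call_on_copy(tree, method_name, *args, **kwargs):
    """Hypothetical stand-in for tree.at["method_name"](*args, **kwargs)."""
    new_tree = copy.deepcopy(tree)
    output = getattr(new_tree, method_name)(*args, **kwargs)
    return new_tree, output


def set_on_copy(tree, **changes):
    """Hypothetical stand-in for tree.at["attribute name"].set(value)."""
    return dataclasses.replace(tree, **changes)


counter = Counter()
new_counter, output = call_on_copy(counter, "increment", 2)
print(counter, new_counter, output)
print(set_on_copy(counter, count=10))
```

In both helpers the original instance is left untouched, so state changes are always visible in the returned value instead of happening silently.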

Add freeze/unfreeze scheme in docs

Specifically, show why it can be helpful compared to the tree-specific flatten rules found in flax.struct.PyTreeNode / equinox.Module and other PyTree-based libraries.

Add tree_mermaid to tree_viz.py

  • Modify tree_viz.py::tree_indent to translate model to mermaid* diagrams

Use case:

  • Mermaid diagrams render in GitHub READMEs; a possible use case is coupling a model with its Mermaid figure in the README or a notebook.
  • The function should print a link to the Mermaid diagram so that it can be viewed and shared online, similar to jax.profiler with Perfetto**.

*https://mermaid-js.github.io/mermaid/#/
**https://jax.readthedocs.io/en/latest/profiling.html

  • Implement tree_mermaid , save_viz
  • Generate a link for the created HTML
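A minimal sketch of the translation step, assuming the model is represented as a nested dict. The function name dict_to_mermaid is hypothetical and link generation is left out; the output uses Mermaid's flowchart syntax.

```python
def dict_to_mermaid(tree, root="root"):
    """Hypothetical helper: emit Mermaid flowchart text for a nested dict."""
    lines = ["flowchart LR", f'    n0["{root}"]']
    counter = 0

    def walk(node, parent_id):
        nonlocal counter
        for key, value in node.items():
            counter += 1
            node_id = f"n{counter}"
            if isinstance(value, dict):
                # container: add an edge, then descend
                lines.append(f'    {parent_id} --> {node_id}["{key}"]')
                walk(value, node_id)
            else:
                # leaf: show the value in the node label
                lines.append(f'    {parent_id} --> {node_id}["{key}={value!r}"]')

    walk(tree, "n0")
    return "\n".join(lines)


print(dict_to_mermaid({"linear": {"in_dim": 1, "out_dim": 2}}, root="MLP"))
# flowchart LR
#     n0["MLP"]
#     n0 --> n1["linear"]
#     n1 --> n2["in_dim=1"]
#     n1 --> n3["out_dim=2"]
```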

More documentation

Document the structure of pytreeclass
Update getting started
Document public api

add short hand notation for node addition without dataclass field declaration

Consider adding functionality to enable users to add nodes without declaring them as dataclass fields.
Consider the following example for motivation.

import jax
import jax.numpy as jnp
from dataclasses import field
from pytreeclass import treeclass

@treeclass
class Linear :
    weight : jnp.ndarray
    bias   : jnp.ndarray

    def __init__(self,key,in_dim,out_dim):
        self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
        self.bias = jnp.ones((1,out_dim))

    def __call__(self,x):
        return x @ self.weight + self.bias


@treeclass 
class MLP:

    def __init__(self,layers:list[int],key=jax.random.PRNGKey(0)):
        
        keys = jax.random.split(key,len(layers))
        
        for i,(key,in_dim,out_dim) in enumerate(zip(keys,layers[:-1],layers[1:])):
            layer_name = f"linear_{i}"
            layer_val = Linear(key,in_dim,out_dim)
            
            # pytreeclass considers only data in dataclass_fields as nodes.
            # we could store all layers as one node with no problem.
            # however, for tree_viz to work properly, each layer needs to be its own node.
            
            # create a dataclass field entry
            self.__dataclass_fields__ = {**self.__dataclass_fields__ ,**{f"linear_{i}":field()}}

            # set the name of the field
            setattr(self.__dataclass_fields__[layer_name],"name", layer_name)

            # set the value in __dict__
            self.__dict__[layer_name] = layer_val
            
    def __call__(self,x):
        layers = [v for k,v in self.__dict__.items() if "linear" in k ]
        for layer in layers[:-1]:
            x = layer(x)
        return layers[-1](x)

Comparison with `equinox`

First, sorry to put it in issues, but there is no discussion tab. Second, I don't want this to sound negative or critical, rather I am genuinely interested in the reasoning (and it looks like your library has some nice unique features e.g. visualising PyTrees).

There are already a number of mature JAX libraries such as Equinox that handle the idea of constructing classes as PyTrees and layering on top convenient methods to manipulate them (plus I notice you have written a NN library which then builds further). I was wondering why another set of libraries? What are the advantages of PyTreeClass and Serket over something like Equinox?

Using pytreeclass with jax and pytorch without specifying backend as environment variable

Hi @ASEM000 ,

I really like your library compared to some other automated pytree alternatives and would love to see more people using it.
I was interested in using pytreeclass in CoLA, a numerical linear algebra library that I have been involved in developing. One of the design constraints is that we need to be able to support usage in both jax and pytorch, whether jax is installed, pytorch is installed, or both. This decision depends on the LinearOperator objects that the user creates, and there can even be scenarios where both jax and pytorch objects exist simultaneously.

We were hoping to use pytreeclass as the base pytree for the LinearOperator objects, but have run into some issues with this cross-platform support. We know that pytreeclass was designed with support for both jax and pytorch in mind, but I couldn't find details on this topic in the docs.

Having had a look at pytreeclass/_src/backend/__init__.py: is the backend specified using the environment variable?
Is there any way that pytreeclass can function whether or not jax or pytorch is installed, based on whether the imports succeed or fail? Also, do you have any thoughts on whether it would be possible to have jax and pytorch pytrees existing at the same time?

Cheers,
Marc
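The "does the import succeed" check asked about above can be sketched with the standard library. The function name available_backends and the candidate list are hypothetical; this illustrates the detection idea only and says nothing about how pytreeclass actually selects its backend.

```python
import importlib.util


def available_backends(candidates=("jax", "torch", "numpy")):
    """Hypothetical helper: list candidate backends that can be imported."""
    # find_spec returns None for a top-level package that is not installed,
    # without importing the package itself
    return [name for name in candidates if importlib.util.find_spec(name) is not None]


print(available_backends())
```

A library could pick a default from this list and still let an explicit setting override it.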

Frozen / Static leaves

Reading the documentation, I understand that you can freeze variables by using a mask based on name or type.

Is it possible to set a variable to "frozen" within the class definition, i.e. in the way Equinox has the static_field option?

While I understand the concept behind being able to mask out a set of variables contained in a PyTree (or PyTree of PyTrees), there are lots of situations where you know when creating a new class, that certain variables will only ever be constant. Furthermore, as models become much more complicated (or if others may utilise elements of your model) it becomes more cumbersome to have to mask these out / others have to know to do this.
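What declaring a value as static at class-definition time means in plain JAX can be sketched with a manually registered pytree: static values go into aux_data instead of the children. The class Scale is hypothetical and uses only jax.tree_util, not the pytreeclass or Equinox field APIs.

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Scale:
    """Hypothetical layer: `weight` is a leaf, `name` is always static."""

    def __init__(self, weight, name):
        self.weight = weight
        self.name = name

    def tree_flatten(self):
        # children are visible to transformations; aux data is not
        return (self.weight,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)


layer = Scale(2.0, "scale_0")
print(jtu.tree_leaves(layer))  # only the weight is a leaf

doubled = jtu.tree_map(lambda leaf: leaf * 2, layer)
print(doubled.weight, doubled.name)  # the name passes through unchanged
```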

dealing with non differentiable tree values under `jax.{grad,value_and_grad,...}`.

In JAX, certain transformations (e.g. jax.grad) require all values to be inexact. In this issue, I propose dealing with non-differentiable (non-inexact) tree values using two function transformations.

First, the existing approach is to use pytc.static_field in the class definition on each non-inexact field; however, this requires the user to know the type of the data in advance.
I propose adding a function filter_nondiff to mark non-differentiable nodes as static, and unfilter_nondiff to undo the marking of these fields. In essence, unfilter_nondiff(filter_nondiff(x)) == x.

Let's demonstrate these two transformations with an example:

@pytc.treeclass
class Test:
    a:int = 0
    b:float = 1. 
    c:jnp.ndarray = jnp.array([1,2,3])
    d:jnp.ndarray = jnp.array([1.,2.,3.])

t = Test()
print(t)
#Test(a=0,b=1.0,c=[1 2 3],d=[1. 2. 3.])

# * marks static field
print(filter_nondiff(t))
# Test(*a=0,b=1.0,*c=[1 2 3],d=[1. 2. 3.])

print(unfilter_nondiff(filter_nondiff(t)))
# Test(a=0,b=1.0,c=[1 2 3],d=[1. 2. 3.])
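The criterion behind such a filter can be sketched as a dtype check on each leaf. The helper is_differentiable is hypothetical and a plain dict stands in for the treeclass; it only computes the mask that a function like filter_nondiff would need, under the assumption that "differentiable" means an inexact (floating or complex) dtype.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def is_differentiable(leaf):
    """Hypothetical helper: True when the leaf has an inexact dtype."""
    return bool(jnp.issubdtype(jnp.result_type(leaf), jnp.inexact))


tree = {
    "a": 0,
    "b": 1.0,
    "c": jnp.array([1, 2, 3]),
    "d": jnp.array([1.0, 2.0, 3.0]),
}

mask = jtu.tree_map(is_differentiable, tree)
print(mask)
# {'a': False, 'b': True, 'c': False, 'd': True}
```

The False entries correspond to the fields marked with * in the output above.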
