asem000 / pytreeclass
Visualize, create, and operate on pytrees in the most intuitive way possible.
Home Page: https://pytreeclass.rtfd.io/en/latest
License: Apache License 2.0
This issue proposes moving .freeze / .unfreeze out of the .at methods.
The rationale is that .freeze works only on treeclass instances, while .at works on both treeclass instances and treeclass leaves.
For example, the following has no effect on the node because it is not a treeclass instance; moreover, it gives the user the wrong impression that the value has been frozen.
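from typing import Any
import pytreeclass as pytc

@pytc.treeclass
class Test:
    a: Any = 1

t = Test()
t.at["a"].freeze()  # no effect: `a` is a plain leaf, not a treeclass instance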
The new design is to add tree_freeze and tree_unfreeze in pytreeclass.tree_util.
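One way the proposed free functions could behave is sketched below with plain jax.tree_util rather than pytreeclass internals. Only the names tree_freeze and tree_unfreeze come from the proposal; the Frozen wrapper and both function bodies are assumptions made for illustration.

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper: hides a value from JAX by exposing no leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children; the wrapped value travels as static auxiliary data.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def _is_frozen(x):
    return isinstance(x, Frozen)


def tree_freeze(tree):
    """Wrap every leaf in Frozen; values that are already frozen are kept."""
    return jtu.tree_map(
        lambda x: x if _is_frozen(x) else Frozen(x), tree, is_leaf=_is_frozen
    )


def tree_unfreeze(tree):
    """Undo tree_freeze by unwrapping every Frozen value."""
    return jtu.tree_map(
        lambda x: x.value if _is_frozen(x) else x, tree, is_leaf=_is_frozen
    )


tree = {"a": 1, "b": [2.0, 3.0]}
frozen = tree_freeze(tree)
print(jtu.tree_leaves(frozen))   # frozen values contribute no leaves
print(tree_unfreeze(frozen))     # round-trips to the original values
```

Because these functions operate on any pytree, they apply equally to a treeclass instance and to a plain leaf, which is the mismatch the issue describes for .at[...].freeze().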
The current design registers only nodes declared as dataclass fields. However, this increases the boilerplate code significantly.
The new design should register any value that is wrapped by treeclass.
The motivation for this class definition is to expose the items of a list/tuple/dict/set as dataclass fields to enable field-specific operations on the items.
Operations include filtering by type, filtering non-differentiable variables, etc.
For example:
import pytreeclass as pytc
from pytreeclass.src.container import Container
from pytreeclass.src.misc import filter_nondiff
from typing import Any

@pytc.treeclass
class Test:
    a: Container

t = Test(Container({"A": 1, "B": 2}))
print(t.tree_diagram())
# Test
# └── a=Container
#     ├── A=1
#     └── B=2
print(t == int)
# Test(a=Container(A=True,B=True))
print(filter_nondiff(t))
# Test(a=Container(*A=1,*B=2))  # * means static
In contrast, using an unwrapped dict:
import pytreeclass as pytc
from pytreeclass.src.misc import filter_nondiff
from typing import Any

@pytc.treeclass
class Test:
    a: Any

t = Test({"A": 1, "B": 2})
print(t.tree_diagram())
# Test
# └── a={'A': 1, 'B': 2}
print(t == int)
# Test(a=False)
print(filter_nondiff(t))
# Test(a={A:1,B:2})
Deprecate the usage of dataclasses from v0.2, preferably by implementing a minimal version of dataclasses (i.e. only init generation for our case).
See: google/jax#14295
model.summary() in favor of pytc.tree_summary(model)
model.tree_diagram() in favor of pytc.tree_diagram(model)
model.tree_box()
tree_summary, tree_diagram, tree_str, and tree_repr are rewritten from scratch to work on any registered PyTree out of the box.
import pytreeclass as pytc
import jax
@jax.tree_util.register_pytree_node_class
class Tree:
    def __init__(self, x):
        self.x = x
        self.y = x**2

    def tree_flatten(self):
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # rebuild without calling __init__, which accepts only `x`
        tree = object.__new__(cls)
        tree.x, tree.y = children
        return tree

tree = Tree(1)
print(pytc.tree_repr(tree))
print(pytc.tree_str(tree))
print(pytc.tree_diagram(tree))
print(pytc.tree_summary(tree))
Tree(leaf_0=1, leaf_1=1)
Tree(leaf_0=1, leaf_1=1)
Tree
├── leaf_0=1
└── leaf_1=1
┌──────┬────┬─────┬──────┐
│Name  │Type│Count│Size  │
├──────┼────┼─────┼──────┤
│leaf_0│int │1    │28.00B│
├──────┼────┼─────┼──────┤
│leaf_1│int │1    │28.00B│
├──────┼────┼─────┼──────┤
│Σ     │Tree│2    │56.00B│
└──────┴────┴─────┴──────┘
import flax
import jax.numpy as jnp
import pytreeclass as pytc

class FlaxTree(flax.struct.PyTreeNode):
    a: int = 2
    b: jnp.ndarray = jnp.ones([2, 2])

    def __call__(self, x):
        return self.a * x**2

tree = FlaxTree()
print(pytc.tree_repr(tree))
print(pytc.tree_str(tree))
print(pytc.tree_diagram(tree))
print(pytc.tree_summary(tree))
FlaxTree(a=2, b=f32{2x2}∈[1.0,1.0])
FlaxTree(a=2, b=[[1. 1.] [1. 1.]])
FlaxTree
├── a=2
└── b=f32{2x2}∈[1.0,1.0]
┌────┬────────┬─────┬──────┐
│Name│Type    │Count│Size  │
├────┼────────┼─────┼──────┤
│a   │int     │1    │28.00B│
├────┼────────┼─────┼──────┤
│b   │f32{2x2}│4    │16.00B│
├────┼────────┼─────┼──────┤
│Σ   │FlaxTree│5    │44.00B│
└────┴────────┴─────┴──────┘
A depth parameter in tree_summary and tree_diagram controls how many levels are printed. The default is depth=float('inf').
import pytreeclass as pytc
tree = {"d1": {"d2_0": {"d3_0":1},"d2_1": 2}}
With depth=float('inf') (the default)
print(pytc.tree_diagram(tree))
print(pytc.tree_summary(tree))
dict
└── ['d1']:dict
    ├── ['d2_0']:dict
    │   └── ['d3_0']=1
    └── ['d2_1']=2
┌──────────────────────┬────┬─────┬──────┐
│Name                  │Type│Count│Size  │
├──────────────────────┼────┼─────┼──────┤
│['d1']['d2_0']['d3_0']│int │1    │28.00B│
├──────────────────────┼────┼─────┼──────┤
│['d1']['d2_1']        │int │1    │28.00B│
├──────────────────────┼────┼─────┼──────┤
│Σ                     │dict│2    │56.00B│
└──────────────────────┴────┴─────┴──────┘
With depth=1
print(pytc.tree_diagram(tree, depth=1))
print(pytc.tree_summary(tree, depth=1))
dict
└── ['d1']={d2_0:{d3_0:1}, d2_1:2}
┌────┬────┬─────┬──────┐
│Name│Type│Count│Size  │
├────┼────┼─────┼──────┤
│Σ   │dict│2    │56.00B│
└────┴────┴─────┴──────┘
With depth=2
print(pytc.tree_diagram(tree, depth=2))
print(pytc.tree_summary(tree, depth=2))
dict
└── ['d1']:dict
    ├── ['d2_0']={d3_0:1}
    └── ['d2_1']=2
┌──────────────┬────┬─────┬──────┐
│Name          │Type│Count│Size  │
├──────────────┼────┼─────┼──────┤
│['d1']['d2_0']│dict│1    │28.00B│
├──────────────┼────┼─────┼──────┤
│['d1']['d2_1']│int │1    │28.00B│
├──────────────┼────┼─────┼──────┤
│Σ             │dict│2    │56.00B│
└──────────────┴────┴─────┴──────┘
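The depth cut-off shown above can be illustrated with a small standard-library sketch. It handles nested dicts only and uses Python's own repr for collapsed values, so it is not the library's renderer; the function name diagram is hypothetical.

```python
def diagram(tree, depth=float("inf")):
    """Render nested dicts as a tree, collapsing everything below `depth`."""
    lines = [type(tree).__name__]

    def walk(node, prefix, level):
        items = list(node.items())
        for i, (key, value) in enumerate(items):
            last = i == len(items) - 1
            branch = "└── " if last else "├── "
            if isinstance(value, dict) and level < depth:
                # Still above the cut-off: show the node and recurse into it.
                lines.append(f"{prefix}{branch}[{key!r}]:dict")
                walk(value, prefix + ("    " if last else "│   "), level + 1)
            else:
                # Leaf, or a subtree at the cut-off: print it on one line.
                lines.append(f"{prefix}{branch}[{key!r}]={value!r}")

    walk(tree, "", 1)
    return "\n".join(lines)


tree = {"d1": {"d2_0": {"d3_0": 1}, "d2_1": 2}}
print(diagram(tree))
print(diagram(tree, depth=1))
print(diagram(tree, depth=2))
```

With depth=1 the whole of d1 collapses onto one line, and with depth=2 only d2_0 collapses, mirroring the behaviour in the outputs above.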
Technical note:
The new API achieves the above functionality by fetching iter handlers from jax._src.tree_util._registry.
The tree_repr array representation now includes min/max stats. The change is inspired by Alexey Zaytsev's lovely-jax.
import pytreeclass as pytc
import jax
x = jax.numpy.array([1,2,3])
print(pytc.tree_repr(x))
i32{3}∈[1,3]
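For illustration, a compact array summary in the same dtype{shape}∈[min,max] style can be assembled as follows. This is a sketch written against NumPy for non-empty numeric arrays, and array_summary is a hypothetical helper, not the library's formatter.

```python
import numpy as np


def array_summary(x):
    """Return a dtype{shape}∈[min,max] string for a non-empty numeric array."""
    x = np.asarray(x)
    # dtype kind letter plus bit width, e.g. int32 -> "i32", float32 -> "f32"
    dtype = f"{x.dtype.kind}{x.dtype.itemsize * 8}"
    shape = "x".join(str(dim) for dim in x.shape)
    # .item() converts NumPy scalars to plain Python numbers before formatting
    return f"{dtype}{{{shape}}}∈[{x.min().item()},{x.max().item()}]"


print(array_summary(np.array([1, 2, 3], dtype=np.int32)))   # i32{3}∈[1,3]
print(array_summary(np.ones((2, 2), dtype=np.float32)))     # f32{2x2}∈[1.0,1.0]
```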
From v0.2, PyTreeClass implements a new dataclasses alternative.
The new implementation includes callbacks in the field that apply a sequence of functions to the input when the attribute is set. Callbacks are useful in several cases, for instance to ensure that an input has a certain type and lies within a valid range. See the example:
import jax
import pytreeclass as pytc

def positive_int_callback(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value

@pytc.treeclass
class Tree:
    in_features: int = pytc.field(callbacks=[positive_int_callback])

tree = Tree(1)
# no error
tree = Tree(0)
# ValueError: Error for field=`in_features`:
# Value must be positive
tree = Tree(1.0)
# TypeError: Error for field=`in_features`:
# Value must be an integer
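The mechanism can be approximated with a plain Python descriptor. The sketch below is an assumption about how set-time callbacks might be wired up, not the library's implementation; CallbackField and Layer are hypothetical names.

```python
class CallbackField:
    """Hypothetical descriptor that runs callbacks each time the attribute is set."""

    def __init__(self, callbacks=()):
        self.callbacks = tuple(callbacks)

    def __set_name__(self, owner, name):
        # Remember the attribute name so errors can mention the field.
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        # Apply the callbacks in order; each receives the previous result.
        for callback in self.callbacks:
            try:
                value = callback(value)
            except Exception as error:
                # Re-raise the same exception type, prefixed with the field name.
                raise type(error)(f"Error for field=`{self.name}`:\n{error}") from error
        instance.__dict__[self.name] = value


def positive_int(value):
    if not isinstance(value, int):
        raise TypeError("Value must be an integer")
    if value <= 0:
        raise ValueError("Value must be positive")
    return value


class Layer:
    in_features = CallbackField(callbacks=[positive_int])

    def __init__(self, in_features):
        self.in_features = in_features


layer = Layer(3)          # passes validation
print(layer.in_features)
```

Because each callback returns the value it was given (possibly transformed), the same chain can also perform type conversion, for example by placing float or a canonicalization function in the list.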
Other useful applications are type conversion or resolution (canonicalization) etc.
Relevant resources:
bcmap(func, is_leaf) maps a function over PyTree leaves with automatic broadcasting for scalar arguments.
bcmap is a function transformation that broadcasts a scalar to match the first argument of the function. This lets us convert a function like jnp.where to work with arbitrary tree structures without writing a specific function for each broadcasting case.
For example, let's say we want to use jnp.where to zero out all values in an arbitrary tree structure that are less than 0:
import jax.numpy as jnp
import jax.tree_util as jtu

tree = ([1], {"a": 1, "b": 2}, (1,), -1)

We can use jax.tree_util.tree_map to apply jnp.where to the tree, but we need to write a specific function that broadcasts the scalar to the tree:

def map_func(leaf):
    # here we encode the scalar `0` inside the function
    return jnp.where(leaf > 0, leaf, 0)

jtu.tree_map(map_func, tree)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(0, dtype=int32, weak_type=True))
However, suppose we want to use jnp.where to replace a value with the matching leaf from another tree. The mapped function then looks like this:

def map_func2(lhs_leaf, rhs_leaf):
    # here the replacement comes from the matching leaf of the second tree
    return jnp.where(lhs_leaf > 0, lhs_leaf, rhs_leaf)

tree2 = jtu.tree_map(lambda x: 1000, tree)
jtu.tree_map(map_func2, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(1000, dtype=int32, weak_type=True))
Now, bcmap solves this problem by figuring out the broadcasting case.
broadcastable_where = pytc.bcmap(jnp.where)
mask = jtu.tree_map(lambda x: x>0, tree)
Case 1: broadcast a scalar to the tree
broadcastable_where(mask, tree, 0)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(0, dtype=int32, weak_type=True))
Case 2: broadcast a tree to the tree
broadcastable_where(mask, tree, tree2)
# ([Array(1, dtype=int32, weak_type=True)],
# {'a': Array(1, dtype=int32, weak_type=True),
# 'b': Array(2, dtype=int32, weak_type=True)},
# (Array(1, dtype=int32, weak_type=True),),
# Array(1000, dtype=int32, weak_type=True))
Let's take this a step further and eliminate mask from the equation by using pytreeclass with leafwise=True:

import functools as ft

@ft.partial(pytc.treeclass, leafwise=True)
class Tree:
    tree: tuple = ([1], {"a": 1, "b": 2}, (1,), -1)

tree = Tree()
# Tree(tree=([1], {a:1, b:2}, (1), -1))
Case 1: broadcast scalar to tree
print(broadcastable_where(tree>0, tree, 0))
# Tree(tree=([1], {a:1, b:2}, (1), 0))
Case 2: broadcast tree to tree
```python
print(broadcastable_where(tree>0, tree, tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
```
bcmap also works with all kinds of arguments (positional and keyword) in the wrapped function:
print(broadcastable_where(tree>0, x=tree, y=tree+100))
# Tree(tree=([1], {a:1, b:2}, (1), 99))
In conclusion, bcmap is a function transformation that can be used to make functions work with arbitrary tree structures without the need to write a specific function for each broadcasting case.
Moreover, bcmap can be more powerful when used with pytreeclass to facilitate operation of arbitrary functions on PyTree objects without the need to use tree_map.
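To make the broadcasting rule concrete, here is a minimal sketch of such a transformation built on jax.tree_util. It assumes that any argument whose tree structure differs from that of the first argument is repeated for every leaf; bcmap_sketch and select are hypothetical names, and the real bcmap may differ in details.

```python
import functools

import jax.tree_util as jtu


def bcmap_sketch(func, is_leaf=None):
    """Map `func` leaf-wise, broadcasting arguments whose structure differs
    from that of the first argument (illustrative re-implementation)."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The first argument defines the reference tree structure.
        reference = args[0] if args else next(iter(kwargs.values()))
        _, treedef = jtu.tree_flatten(reference, is_leaf=is_leaf)

        def leaves_like_reference(arg):
            leaves, arg_treedef = jtu.tree_flatten(arg, is_leaf=is_leaf)
            if arg_treedef == treedef:
                return leaves
            # Different structure (e.g. a scalar): repeat it for every leaf.
            return [arg] * treedef.num_leaves

        arg_leaves = [leaves_like_reference(a) for a in args]
        kwarg_leaves = {k: leaves_like_reference(v) for k, v in kwargs.items()}
        results = [
            func(
                *(leaves[i] for leaves in arg_leaves),
                **{k: leaves[i] for k, leaves in kwarg_leaves.items()},
            )
            for i in range(treedef.num_leaves)
        ]
        return jtu.tree_unflatten(treedef, results)

    return wrapper


def select(condition, x, y):
    # Plain-Python stand-in for jnp.where on scalar leaves.
    return x if condition else y


tree = ([1], {"a": 1, "b": 2}, (1,), -1)
other = jtu.tree_map(lambda leaf: 1000, tree)
mask = jtu.tree_map(lambda leaf: leaf > 0, tree)

broadcast_select = bcmap_sketch(select)
print(broadcast_select(mask, tree, 0))        # scalar broadcast to every leaf
print(broadcast_select(mask, tree, y=other))  # matching tree passed by keyword
```

Comparing tree definitions rather than checking for scalars keeps the rule simple: anything that does not match the reference structure is treated as a single value to broadcast.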
For less boilerplate code and clarity, consider adding @compact
References
[1] https://flax.readthedocs.io/en/latest/flax.linen.html?highlight=compact#flax.linen.compact
summary, tree_box to serket with the aim of simplifying the package structure
tree_summary appropriate to deep learning models (e.g. compact=True)
Currently, .at[].{set,apply,reduce}, .freeze(), and .unfreeze() return a new instance without affecting the original treeclass instance.
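The out-of-place behaviour described here can be illustrated with the standard library alone. The helpers at_set and at_apply below are hypothetical stand-ins for .at[...].set and .at[...].apply, sketched on an ordinary dataclass rather than a treeclass.

```python
import copy
from dataclasses import dataclass


def at_set(tree, name, value):
    """Return a shallow copy of `tree` with attribute `name` replaced."""
    new_tree = copy.copy(tree)
    setattr(new_tree, name, value)
    return new_tree


def at_apply(tree, name, func):
    """Return a copy of `tree` with `func` applied to attribute `name`."""
    return at_set(tree, name, func(getattr(tree, name)))


@dataclass
class Point:
    x: float = 0.0
    y: float = 0.0


p = Point()
q = at_set(p, "x", 1.0)
r = at_apply(q, "x", lambda value: value + 10)
print(p, q, r)  # three distinct instances; `p` and `q` are unchanged
```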
Since pytreeclass relies mostly on the functional .at[] methods, it makes more sense to move entirely to .at[] methods to handle the rest of the operations.
Moreover, internal state updates can go undetected in some scenarios, which might cause problems.
This issue proposes a couple of new things for a fully immutable experience.
.at["method_name"](*args, **kwargs) should be called on a new tree. The method should return new_tree and the call output. However, the __call__ method should be usable without raising an error if no side effect is caused.
.at["attribute name"].set() should set the attribute on a new tree instance and return the updated tree.
Specifically, show why it can be helpful compared to tree-specific flatten rules found in flax.struct.PyTreeNode / equinox.Module and other PyTree-based libraries.
Use case:
*https://mermaid-js.github.io/mermaid/#/
**https://jax.readthedocs.io/en/latest/profiling.html
Document the structure of pytreeclass
Update getting started
Document public api
to make reasoning about field declaration simpler.
Consider adding functionality that lets users add nodes without declaring them as dataclass fields.
Consider the following example for motivation.
from dataclasses import field

import jax
import jax.numpy as jnp
from pytreeclass import treeclass

@treeclass
class Linear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@treeclass
class MLP:
    def __init__(self, layers: list[int], key=jax.random.PRNGKey(0)):
        keys = jax.random.split(key, len(layers))
        for i, (key, in_dim, out_dim) in enumerate(zip(keys, layers[:-1], layers[1:])):
            layer_name = f"linear_{i}"
            layer_val = Linear(key, in_dim, out_dim)
            # pytreeclass considers only data in __dataclass_fields__ as nodes.
            # We could store all layers in a single node with no problem;
            # however, for tree_viz to work properly, each layer must be its own node.
            # create a dataclass field entry
            self.__dataclass_fields__ = {**self.__dataclass_fields__, layer_name: field()}
            # set the name of the field
            setattr(self.__dataclass_fields__[layer_name], "name", layer_name)
            # set the value in __dict__
            self.__dict__[layer_name] = layer_val

    def __call__(self, x):
        layers = [v for k, v in self.__dict__.items() if "linear" in k]
        for layer in layers[:-1]:
            x = layer(x)
        return layers[-1](x)
First, sorry to put it in issues, but there is no discussion tab. Second, I don't want this to sound negative or critical, rather I am genuinely interested in the reasoning (and it looks like your library has some nice unique features e.g. visualising PyTrees).
There are already a number of mature JAX libraries such as Equinox that handle the idea of constructing classes as PyTrees and layering on top convenient methods to manipulate them (plus I notice you have written a NN library which then builds further). I was wondering why another set of libraries? What are the advantages of PyTreeClass and Serket over something like Equinox?
Hi @ASEM000 ,
I really like your library compared to some other automated pytree alternatives and would love to see more people using it.
I was interested in using pytreeclass in CoLA, a numerical linear algebra library that I have been involved in developing. One of the design constraints is that we need to be able to support usage in both jax and pytorch, whether jax is installed, pytorch is installed, or both. This decision depends on the LinearOperator objects that the user creates, and there can even be scenarios where both jax and pytorch objects exist simultaneously.
We were hoping to use pytreeclass as the base pytree for the LinearOperator objects, but have run into some issues with this cross-platform support. We know that pytreeclass was designed with support for both jax and pytorch in mind, but I couldn't find details on this topic in the docs.
Having a look in pytreeclass/_src/backend/__init__.py, is this specified using the environment variable?
Is there any way that pytreeclass can function whether or not jax or pytorch is installed, based on whether the imports succeed or fail? Also, do you have any thoughts on whether it would be possible to have jax and pytorch pytrees existing at the same time?
Cheers,
Marc
Reading the documentation, I understand that you can freeze variables by using a mask based upon name or type.
Is it possible to set a variable to "frozen" within the class definition, i.e. in the way Equinox has the static_field option?
While I understand the concept behind being able to mask out a set of variables contained in a PyTree (or PyTree of PyTrees), there are lots of situations where you know when creating a new class, that certain variables will only ever be constant. Furthermore, as models become much more complicated (or if others may utilise elements of your model) it becomes more cumbersome to have to mask these out / others have to know to do this.
In JAX, certain transformations (e.g. jax.grad) require all values to be inexact. In this issue, I propose handling non-differentiable (non-inexact) tree values with two function transformations.
First, the existing approach is to use pytc.static_field in the class definition on each non-inexact field; however, this requires the user to know the type of the data in advance.
I propose adding a function filter_nondiff that marks non-differentiable nodes as static, and unfilter_nondiff that undoes the marking of these fields. In essence, unfilter_nondiff(filter_nondiff(x)) == x.
Let's demonstrate these two transformations with an example:
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Test:
    a: int = 0
    b: float = 1.
    c: jnp.ndarray = jnp.array([1, 2, 3])
    d: jnp.ndarray = jnp.array([1., 2., 3.])

t = Test()
print(t)
# Test(a=0,b=1.0,c=[1 2 3],d=[1. 2. 3.])

# * marks static field
print(filter_nondiff(t))
# Test(*a=0,b=1.0,*c=[1 2 3],d=[1. 2. 3.])
print(unfilter_nondiff(filter_nondiff(t)))
# Test(a=0,b=1.0,c=[1 2 3],d=[1. 2. 3.])
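The same goal, keeping non-inexact values away from jax.grad and restoring them afterwards, can be sketched on a plain dict with jax.tree_util. Note that this uses a partition/combine split (non-differentiable leaves are replaced by None in one tree and kept in another) rather than marking fields static as proposed; partition, combine, and is_inexact are hypothetical helper names.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def is_inexact(x):
    """True for floating-point or complex leaves (Python scalars or arrays)."""
    return jnp.issubdtype(jnp.result_type(x), jnp.inexact)


def _is_none(x):
    return x is None


def partition(tree):
    """Split into (differentiable, static) trees; missing entries become None."""
    diff = jtu.tree_map(lambda x: x if is_inexact(x) else None, tree)
    static = jtu.tree_map(lambda x: None if is_inexact(x) else x, tree)
    return diff, static


def combine(diff, static):
    """Inverse of partition: fill the None slots of `diff` from `static`."""
    return jtu.tree_map(
        lambda d, s: s if d is None else d, diff, static, is_leaf=_is_none
    )


params = {
    "a": 0,
    "b": 1.0,
    "c": jnp.array([1, 2, 3]),
    "d": jnp.array([1.0, 2.0, 3.0]),
}
diff, static = partition(params)


def loss(diff):
    p = combine(diff, static)
    return p["b"] ** 2 + jnp.sum(p["d"] ** 2) + p["a"]


# None is an empty pytree node, so jax.grad only sees the inexact leaves.
grads = jax.grad(loss)(diff)
print(grads)
```

The round trip combine(*partition(x)) plays the role of unfilter_nondiff(filter_nondiff(x)) == x in the proposal.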