
Differentiation wrt system parameters (brax, open issue)

google commented on May 21, 2024
Differentiation wrt system parameters


Comments (6)

cdfreeman-google commented on May 21, 2024

Yes, this is indeed quite a bit simpler now! Please take a look at our new introductory colabs. The system data is now all represented within a pytree (not a proto), so you can differentiate with respect to it out of the box. Some fields are a little harder to track down within the pytree than others, so do let us know if you can't figure out how to do what you want!
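As a rough illustration of what "differentiable out of the box" means once parameters live in a pytree, here is a minimal sketch in plain JAX. It does not use Brax itself: `ToySystem`, `final_velocity` and the toy dynamics are hypothetical stand-ins for a pytree-based system and its simulation step.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
  """Hypothetical pytree of system parameters (not Brax's System class).

  JAX treats NamedTuples as pytrees, so each array field is a leaf that
  jax.grad can differentiate with respect to.
  """
  mass: jnp.ndarray      # per-body masses, shape (num_bodies,)
  friction: jnp.ndarray  # per-body linear friction coefficients


def final_velocity(sys: ToySystem, force: float = 1.0, dt: float = 0.01,
                   steps: int = 100) -> jnp.ndarray:
  """Toy rollout: constant force on each body, damped by linear friction."""

  def step(vel, _):
    acc = (force - sys.friction * vel) / sys.mass
    return vel + dt * acc, None

  vel, _ = jax.lax.scan(step, jnp.zeros_like(sys.mass), None, length=steps)
  return vel


def loss(sys: ToySystem) -> jnp.ndarray:
  return jnp.sum(final_velocity(sys) ** 2)


toy_sys = ToySystem(mass=jnp.array([1.0, 2.0]), friction=jnp.array([0.5, 0.5]))
grads = jax.grad(loss)(toy_sys)  # a ToySystem holding d(loss)/d(each field)
```

The gradient comes back with the same structure as the input, one entry per mass and friction value. Both gradients are negative in this toy, since a heavier or more strongly damped body ends up slower.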


erikfrey commented on May 21, 2024

FYI, making this less irritating is on our roadmap of improvements.


RolandZhu commented on May 21, 2024

Hello, and thanks for the detailed explanation provided earlier. I would like to know whether differentiation with respect to physical parameters has been integrated into v2. I want to identify system parameters using gradient descent, which requires gradients with respect to quantities such as mass or friction. Are there resources that show how to use such a feature, or is it still in development?
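For the system-identification use case, the general idea can be sketched in plain JAX as below. Everything here is a hypothetical toy (a 1-D block with a constant push and linear friction, with made-up "true" parameters), not Brax's API; the point is only that gradient descent on mass and friction works once they are ordinary JAX inputs to a differentiable rollout.

```python
import jax
import jax.numpy as jnp


def simulate(log_params, force=1.0, dt=0.01, steps=200):
  """Hypothetical 1-D block pushed by a constant force, with linear friction.

  Returns the velocity trajectory. Parameters are passed as logs so that
  mass and friction stay positive during optimization.
  """
  mass, friction = jnp.exp(log_params)

  def step(vel, _):
    vel = vel + dt * (force - friction * vel) / mass
    return vel, vel

  _, traj = jax.lax.scan(step, jnp.zeros(()), None, length=steps)
  return traj


# Synthetic "observations" from made-up ground-truth parameters.
true_log_params = jnp.log(jnp.array([2.0, 0.8]))  # mass = 2.0, friction = 0.8
observed = simulate(true_log_params)


def loss(log_params):
  return jnp.mean((simulate(log_params) - observed) ** 2)


@jax.jit
def update(log_params, lr=0.2):
  # One step of plain gradient descent on the log-parameters.
  value, grad = jax.value_and_grad(loss)(log_params)
  return log_params - lr * grad, value


log_params = jnp.log(jnp.array([1.0, 0.3]))  # initial guess
initial_loss = loss(log_params)
for _ in range(300):
  log_params, _ = update(log_params)

print("estimated mass, friction:", jnp.exp(log_params))
print("loss:", initial_loss, "->", loss(log_params))
```

Optimizing log-parameters is just one way to keep mass and friction positive; an optimizer library such as Optax could replace the hand-written update.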


cdfreeman-google commented on May 21, 2024

This is currently somewhat irritating to do. Brax ingests all of this data from the config protobuf and fills in fields of internal data structures during system initialization, starting around here: https://github.com/google/brax/blob/main/brax/physics/system.py#L38

For things to behave properly, you'd essentially have to overwrite that mass value in all of the internal data structures at the end of system initialization, something like this:

  def __init__(self, config: config_pb2.Config, differentiable_mass_scale=1.0):
    self.config = validate_config(config)

    self.num_bodies = len(config.bodies)
    self.body_idx = {b.name: i for i, b in enumerate(config.bodies)}

    self.active_pos = 1. * jnp.logical_not(
        jnp.array([vec_to_np(b.frozen.position) for b in config.bodies]))
    self.active_rot = 1. * jnp.logical_not(
        jnp.array([vec_to_np(b.frozen.rotation) for b in config.bodies]))

    self.box_plane = colliders.BoxPlane(config)
    self.capsule_plane = colliders.CapsulePlane(config)
    self.capsule_capsule = colliders.CapsuleCapsule(config)

    self.num_joints = len(config.joints)
    self.joint_revolute = joints.Revolute.from_config(config)
    self.joint_universal = joints.Universal.from_config(config)
    self.joint_spherical = joints.Spherical.from_config(config)

    self.num_actuators = len(config.actuators)
    self.num_joint_dof = sum(len(j.angle_limit) for j in config.joints)

    self.angle_1d = actuators.Angle.from_config(config, self.joint_revolute)
    self.angle_2d = actuators.Angle.from_config(config, self.joint_universal)
    self.angle_3d = actuators.Angle.from_config(config, self.joint_spherical)
    self.torque_1d = actuators.Torque.from_config(config, self.joint_revolute)
    self.torque_2d = actuators.Torque.from_config(config, self.joint_universal)
    self.torque_3d = actuators.Torque.from_config(config, self.joint_spherical)

    # Re-init with the data we want to differentiate with respect to.
    # Schematic: `mass` stands for the original mass values read from the
    # config; it is not defined in this snippet.
    self.box_plane.box = self.box_plane.box.replace(
        mass=mass * differentiable_mass_scale)
    self.box_plane.plane = self.box_plane.plane.replace(
        mass=mass * differentiable_mass_scale)
    # ... and so on, everywhere mass is used

This is, of course, extremely silly and defeats the purpose of having a simple initialization scheme if it means you have to go through and replumb all of the data. I'll noodle a bit on how to simplify this. It shouldn't be too horrible to make this "just work" because all the proto is really doing is assigning fields, which should be easily traceable by Jax, but it's a little bit indirect.
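To illustrate why plain field assignment is traceable, here is a toy sketch in plain JAX. `CONFIG_BODIES`, `Collider`, `ToySys` and `init_system` are hypothetical. Unlike the snippet in the comment, the sketch applies `mass_scale` once, where the masses are first read, so every structure built afterwards inherits it and nothing has to be re-plumbed by hand. This is a design choice for the illustration, not how Brax v1 is implemented.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Hypothetical stand-in for config.bodies: plain Python values read once at init.
CONFIG_BODIES = [{"name": "box", "mass": 1.0}, {"name": "ball", "mass": 3.0}]


class Collider(NamedTuple):
  """Hypothetical internal datastructure that keeps its own copy of mass."""
  mass: jnp.ndarray


class ToySys(NamedTuple):
  box_plane: Collider
  ball_plane: Collider


def init_system(config_bodies, mass_scale=1.0):
  # Apply the (possibly traced) scale once, where masses are first read, so
  # every structure built from `masses` below inherits it.
  masses = jnp.array([b["mass"] for b in config_bodies]) * mass_scale
  return ToySys(box_plane=Collider(mass=masses[0:1]),
                ball_plane=Collider(mass=masses[1:2]))


def total_momentum(mass_scale, velocity=2.0):
  sys = init_system(CONFIG_BODIES, mass_scale)
  return velocity * (jnp.sum(sys.box_plane.mass) + jnp.sum(sys.ball_plane.mass))


# d/d(mass_scale) = 2.0 * (1.0 + 3.0) = 8.0
print(jax.grad(total_momentum)(1.0))
```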


bayerj commented on May 21, 2024

Ok, I can try this. I see how this is maybe not ideal, but that's ok for now.

I am just wondering how I can get something like a list of all the Body instances, since that is where the masses are stored.


cdfreeman-google commented on May 21, 2024

So, you can definitely get a list of this data via:

all_masses = [b.mass for b in some_brax_env.sys.config.bodies]

But, as I said above, this data is unpacked and repacked into Brax-internal data structures at system initialization in a way that JAX can't (currently) trace. So while this will tell you what the masses were at init, modifying this data won't actually change the masses Brax knows about (unless you jump through the hoops I mentioned for the initialization).
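A small, hypothetical sketch of the pitfall described above: a toy class copies masses out of a config-like list into an internal array at init, so editing the config afterwards changes what you read back from the config but not what the "simulator" uses. `ToySystem` and its fields are made up for illustration and are not Brax classes.

```python
import jax.numpy as jnp


class ToySystem:
  """Hypothetical system that unpacks config values into arrays at init."""

  def __init__(self, config_bodies):
    self.config_bodies = config_bodies
    # The masses are copied into a JAX array here; later edits to
    # config_bodies do not propagate into self.mass.
    self.mass = jnp.array([b["mass"] for b in config_bodies])


toy = ToySystem([{"name": "torso", "mass": 10.0}, {"name": "leg", "mass": 1.0}])
all_masses = [b["mass"] for b in toy.config_bodies]  # masses as read at init

toy.config_bodies[0]["mass"] = 20.0  # edit the config after init
print(toy.config_bodies[0]["mass"], float(toy.mass[0]))  # 20.0 10.0
```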
