This is a port of TRFL, DeepMind's TensorFlow library of reinforcement learning building blocks, to JAX.
The following TRFL functions are included:
discrete_policy_gradient
discrete_policy_gradient_loss
policy_gradient
policy_gradient_loss
scan_discounted_sum
batched_index
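As a rough illustration of what discrete_policy_gradient and discrete_policy_gradient_loss compute, here is a minimal standalone JAX sketch. It is not this package's code: the `_sketch` function names are hypothetical, and it assumes logits of shape [T, B, num_actions], integer actions of shape [T, B] and action values of shape [T, B]. Summing the per-step losses over the time axis to get the sequence loss is also an assumption. The per-step loss is the negative log-probability of the taken action, weighted by an action value that is held fixed with `stop_gradient`.

```python
import jax
import jax.numpy as jnp


def discrete_policy_gradient_sketch(policy_logits, actions, action_values):
    """Hypothetical per-step loss: -log pi(a | s) * stop_gradient(value).

    policy_logits: float array [..., num_actions]
    actions:       integer array [...]
    action_values: float array [...] (e.g. returns or advantages)
    """
    log_probs = jax.nn.log_softmax(policy_logits, axis=-1)
    # Pick the log-probability of the action actually taken.
    chosen = jnp.take_along_axis(log_probs, actions[..., None], axis=-1)[..., 0]
    # Values act as fixed weights: no gradient flows into them.
    return -chosen * jax.lax.stop_gradient(action_values)


def discrete_policy_gradient_loss_sketch(policy_logits, actions, action_values):
    """Hypothetical sequence loss: sum per-step losses over time, [T, B] -> [B]."""
    per_step = discrete_policy_gradient_sketch(policy_logits, actions, action_values)
    return jnp.sum(per_step, axis=0)


# Usage: T=2 steps, B=1 episode, 2 actions, uniform policy.
logits = jnp.zeros((2, 1, 2))
acts = jnp.zeros((2, 1), dtype=jnp.int32)
vals = jnp.ones((2, 1))
loss = discrete_policy_gradient_loss_sketch(logits, acts, vals)
grads = jax.grad(lambda l: discrete_policy_gradient_loss_sketch(l, acts, vals).sum())(logits)
```

Because the values pass through `stop_gradient`, the gradient with respect to the logits at each step is `(softmax(logits) - one_hot(action)) * value`, and the gradient with respect to the values is zero.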
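A cumulative discounted sum such as scan_discounted_sum maps naturally onto `jax.lax.scan`. The sketch below is a hypothetical stand-in, not the ported function, and its signature is an assumption. It computes result[t] = sequence[t] + decay[t] * result[t - 1] going forward, or result[t] = sequence[t] + decay[t] * result[t + 1] with `reverse=True`, using `initial_value` as the value before the first step.

```python
import jax
import jax.numpy as jnp


def scan_discounted_sum_sketch(sequence, decay, initial_value, reverse=False):
    """Hypothetical discounted cumulative sum along axis 0 (time).

    sequence, decay: arrays of shape [T, ...] with the same shape
    initial_value:   scalar or array broadcastable to sequence.shape[1:]
    """
    # The scan carry must match the shape and dtype of a single time step.
    init = jnp.broadcast_to(
        jnp.asarray(initial_value, dtype=sequence.dtype), sequence.shape[1:]
    )

    def step(carry, inputs):
        x_t, decay_t = inputs
        acc = x_t + decay_t * carry
        return acc, acc  # (new carry, output for this step)

    # With reverse=True, lax.scan walks from the last step to the first but
    # still returns outputs aligned with the original time indices.
    _, result = jax.lax.scan(step, init, (sequence, decay), reverse=reverse)
    return result


rewards = jnp.ones(3)
discounts = jnp.full(3, 0.9)
returns = scan_discounted_sum_sketch(rewards, discounts, 0.0, reverse=True)
print(returns)
```

With three rewards of 1 and a discount of 0.9, the reverse scan should give roughly [2.71, 1.9, 1.0]. Passing a bootstrap value as `initial_value` with `reverse=True` turns this into a bootstrapped discounted return.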
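batched_index selects, for every batch element, the entry at a per-element index (for example the Q-value of the action taken). A minimal sketch of that operation with `jnp.take_along_axis` follows. The name `batched_index_sketch` is hypothetical, and handling any number of leading batch axes is an assumption of this sketch rather than a statement about the ported function.

```python
import jax.numpy as jnp


def batched_index_sketch(values, indices):
    """Hypothetical: return values[..., indices[...]] element-wise.

    values:  array of shape [..., N]
    indices: integer array of shape [...]
    returns: array of shape [...]
    """
    # take_along_axis needs indices with the same rank as values, so add a
    # trailing axis of size 1 and drop it again afterwards.
    return jnp.take_along_axis(values, indices[..., None], axis=-1)[..., 0]


q_values = jnp.array([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
actions = jnp.array([2, 0])
print(batched_index_sketch(q_values, actions))  # [3. 4.]
```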
Because some of the TRFL functions expect TensorFlow Probability objects, a few classes implementing TensorFlow Probability interfaces are also included. These are:
This package also includes a few new functions that are not part of TRFL:
assert_array - check array shape and dtype
broadcast_index - a more general batched_index
PRNGSequence - an infinite iterator of PRNGKeys
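A shape and dtype check in the spirit of assert_array might look like the sketch below. Its name, signature and behaviour are assumptions, not this package's API: `None` in the expected shape matches any size, a mismatch raises `ValueError`, and the array is returned so the check can be used inline. Shapes and dtypes are static in JAX, so a check like this also works inside `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def assert_array_sketch(x, shape=None, dtype=None):
    """Hypothetical check: raise ValueError on unexpected shape or dtype.

    shape: tuple of ints or None; None entries match any size on that axis.
    dtype: anything accepted by np.dtype, e.g. jnp.float32.
    """
    if shape is not None:
        rank_ok = len(x.shape) == len(shape)
        sizes_ok = all(want is None or want == got for want, got in zip(shape, x.shape))
        if not (rank_ok and sizes_ok):
            raise ValueError(f"expected shape {shape}, got {x.shape}")
    if dtype is not None and x.dtype != np.dtype(dtype):
        raise ValueError(f"expected dtype {np.dtype(dtype)}, got {x.dtype}")
    return x


x = jnp.zeros((3, 4), dtype=jnp.float32)
assert_array_sketch(x, shape=(None, 4), dtype=jnp.float32)  # passes
# Static shape information means the check also runs under jit.
doubled = jax.jit(lambda a: assert_array_sketch(a, shape=(3, 4)) * 2)(x)
```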
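An infinite iterator of PRNG keys can be built by repeatedly splitting a stored key, which is one common way to avoid reusing keys in JAX. The class below is a hypothetical sketch of that idea, not the package's PRNGSequence; its name and constructor argument are assumptions.

```python
from itertools import islice

import jax
import jax.numpy as jnp


class PRNGSequenceSketch:
    """Hypothetical infinite iterator of fresh PRNG keys from one seed."""

    def __init__(self, seed):
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Split the stored key: keep one half as the new internal state and
        # hand out the other, so the same key is never returned twice.
        self._key, subkey = jax.random.split(self._key)
        return subkey


rng = PRNGSequenceSketch(0)
noise = jax.random.normal(next(rng), (3,))
first_three = list(islice(PRNGSequenceSketch(1), 3))
```

Because the sequence is driven only by the seed, two iterators created with the same seed yield the same keys in the same order, which keeps experiments reproducible.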