
howuhh / faster-trajectory-transformer


Implementation of Trajectory Transformer with attention caching and batched beam search

License: MIT License

Python 100.00%
reinforcement-learning trajectory-transformer transformer


faster-trajectory-transformer's Issues

Jax Code

Hi,

Thank you very much for your contribution.

I would like to ask: is it possible to release the code based on JAX?

Best

training input and target

Hi, I have another question regarding the training input and target.
Why is the training input x sliced from the flattened tensor? As I understand it, the input x drops the last token of the flattened sequence, while the target y drops the first token (the first dimension of obs at the first step). Shouldn't the inputs be sliced along seq_len rather than flattening everything and only removing a single token?
In the code below, shouldn't we avoid reshaping and slicing it this way?

joined_discrete = self.discretizer.encode(joined).reshape(-1).astype(np.long)
loss_pad_mask = loss_pad_mask.reshape(-1)
return joined_discrete[:-1], joined_discrete[1:], loss_pad_mask[:-1]
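
For concreteness, here is a toy sketch of the slicing I mean, with made-up shapes (the token layout below is just for illustration, not taken from the repo):

import numpy as np

# toy example: 3 transitions, 4 tokens per transition (obs, act, reward, value),
# flattened into one stream of 12 discrete tokens
joined_discrete = np.arange(12)

x = joined_discrete[:-1]   # input:  tokens 0..10
y = joined_discrete[1:]    # target: tokens 1..11, i.e. the next token for each input

# the shift is a single token over the flattened stream, so every token
# (not every transition) is trained to predict its successor
assert (x[1:] == y[:-1]).all()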

q,k,v linear layers

Hi, can I ask why q, k, v are the input itself instead of being passed through a linear layer?
Compared to the original code, this seems to be different.
Thanks!
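
For reference, this is roughly the pattern I am comparing against in the original minGPT-style code, where q, k and v each come from a learned linear layer applied to the same input (a simplified sketch; the causal mask, dropout, and output projection are omitted, and the names are illustrative):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttentionSketch(nn.Module):
    # GPT-style attention: q, k, v are linear projections of the same input x
    def __init__(self, embed_dim, n_heads):
        super().__init__()
        assert embed_dim % n_heads == 0
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.n_heads = n_heads

    def forward(self, x):
        B, T, C = x.shape
        hs = C // self.n_heads
        # project the shared input into separate query/key/value spaces
        q = self.query(x).view(B, T, self.n_heads, hs).transpose(1, 2)
        k = self.key(x).view(B, T, self.n_heads, hs).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_heads, hs).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)   # causal mask omitted for brevity
        att = F.softmax(att, dim=-1)
        return (att @ v).transpose(1, 2).reshape(B, T, C)

(If the faster implementation passes the raw input as query/key/value into something like torch.nn.MultiheadAttention, the learned projections still happen inside that module, so maybe the difference is only where the projection lives; I am not sure whether that is the case here.)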

possible issue with rewards to go?

values[t] = (rewards[t + 1:] * discounts[:-t - 1]).sum()

Hi, I have been using this repo for some experiments, and while digging into parts of it I was wondering whether this is correct. These values don't seem to match up with what I see for the sample trajectory from https://github.com/jannerm/trajectory-transformer/blob/8834a6ed04ceeab8fdb9465e145c6e041c05d71b/trajectory/datasets/sequence.py#L97

There is also a high likelihood I am wrong, but the rewards-to-go seem much larger than I would expect. If so, it is possible this line is just supposed to be (rewards[t + 1:].T @ discounts[:-t - 1]) (which then matches the original repo's RTG values for me), but I am not certain I am understanding everything correctly.

For comparison, on halfcheetah-medium-v2 the RTG sum for the first trajectory using your calculation:

(Pdb) values.sum()
242487360.0

while on the original repo:

(Pdb) self.values_segmented[0].sum()
432073.4913363826
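
Here is a small numpy sketch of how I think the two expressions can diverge, assuming rewards is stored as a (T, 1) column vector and discounts as a flat (T,) array (I may be wrong about the actual shapes in the repo):

import numpy as np

T = 5
rewards = np.ones((T, 1))          # (T, 1) column vector (assumed)
discounts = 0.99 ** np.arange(T)   # (T,)

t = 0
# (T-1, 1) * (T-1,) broadcasts to a (T-1, T-1) matrix, so .sum() adds
# every reward times every discount instead of a single discounted sum
broadcast_sum = (rewards[t + 1:] * discounts[:-t - 1]).sum()        # ~15.76

# the dot-product form proposed above reduces over a single axis
dot_sum = (rewards[t + 1:].T @ discounts[:-t - 1]).item()           # ~3.94

print(broadcast_sum, dot_sum)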

Novel Dataset Preparation

Do you have any recommendations or resources you could point me to for preparing a novel dataset for use in Trajectory Transformer?
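
In case it helps frame the question, I am imagining a D4RL-style layout of flat, aligned transition arrays, roughly like the sketch below (key names and shapes are just my assumption from that convention, not something I know this repo requires):

import numpy as np

# hypothetical dataset layout following the D4RL convention
N, obs_dim, act_dim = 10_000, 17, 6
dataset = {
    "observations": np.zeros((N, obs_dim), dtype=np.float32),
    "actions":      np.zeros((N, act_dim), dtype=np.float32),
    "rewards":      np.zeros((N,), dtype=np.float32),
    "terminals":    np.zeros((N,), dtype=bool),   # episode ended in the MDP
    "timeouts":     np.zeros((N,), dtype=bool),   # episode cut off by a time limit
}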
