
Comments (5)

mlech26l commented on September 27, 2024

The LTC usually requires higher learning rates than feedforward networks.

So, in our examples, 0.01 or 0.005 (autonomous driving) worked okay for supervised learning.

As you mentioned, RL typically works better with smaller learning rates, so I expect 3e-3 to 3e-4 would be a good start.
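For example, a rough sketch with placeholder models and a standard PyTorch optimizer (not the exact setup from the notebooks):

import torch
import torch.nn as nn

# Placeholder models, just to make the optimizer calls runnable.
ltc_model = nn.Linear(64, 8)   # stands in for an LTC-based supervised model
ltc_policy = nn.Linear(64, 2)  # stands in for an LTC-based RL policy / Q-network

# Supervised learning: relatively high learning rates (0.01 or 0.005) worked okay.
supervised_opt = torch.optim.Adam(ltc_model.parameters(), lr=0.01)

# RL: start smaller, somewhere around 3e-3 to 3e-4.
rl_opt = torch.optim.Adam(ltc_policy.parameters(), lr=3e-4)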

mlech26l commented on September 27, 2024

Hi,
[batch size, 1, 256] is the correct shape, i.e., a batch of length-1 sequences, each having 256 features.
The output tensor of the LTC network should then be [batch size, 1, action_dim]. The squeeze operation then removes the sequence dimension, leaving a [batch size, action_dim] tensor.

In the above code, you are using 8 neurons for the LTC. This is very small and could be one of the reasons why the model is not learning. Try 64 or 128 instead. If that also does not work, try changing the learning rate.
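To make the shapes concrete, here is a rough sketch with the LTC call replaced by a stand-in tensor (this only illustrates the dimensions, not the actual ncps API):

import torch

batch_size, features, action_dim = 32, 256, 4

x = torch.randn(batch_size, features)       # output of the feedforward layers
x = x.unsqueeze(1)                           # -> [batch_size, 1, features]: a batch of length-1 sequences
# q = ltc_sequence(x)                        # the LTC would map this to [batch_size, 1, action_dim]
q = torch.randn(batch_size, 1, action_dim)   # stand-in for the LTC output, shapes only
q = q.squeeze(1)                             # -> [batch_size, action_dim]
print(x.shape, q.shape)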

mlech26l commented on September 27, 2024

Hi @sprakashdash

Glad that you liked our work!

There are a few issues in the QNetwork_w_LTC:

First, in_features is the size of the input tensor (i.e., the number of input features), which is 64 in your case.

Second, RNNSequence processes entire sequences of samples instead of just individual samples.
So, the input should have the shape (batch size, sequence length, in_features).
What your code does is represent the input as (batch size, observation dim + action dim, 1), i.e., instead of processing the input at once, it looks at each component one after another. This is probably the reason why your code is so slow. (Note that the LTC is expected to be about 10x slower than a single layer due to the ODE solver.)

Generally, the Q-learning you are doing does not seem to have a temporal component, i.e., it processes individual samples instead of also providing past observations. In such a context, I would not expect any advantage of RNNs (such as the LTC) over normal feedforward networks.
If you want to give it a try anyway, you can create sequences of length 1 by replacing

x = x.unsqueeze(-1)

with

x = x.unsqueeze(1)
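A quick shape check of the two variants (plain PyTorch, independent of the LTC code) makes the difference concrete:

import torch

batch_size, in_features = 32, 64
x = torch.randn(batch_size, in_features)

seq_of_scalars = x.unsqueeze(-1)  # [32, 64, 1]: 64 time steps of 1 feature each (one ODE solve per component, slow)
len_one_seq = x.unsqueeze(1)      # [32, 1, 64]: a single time step with all 64 features (what you want here)

print(seq_of_scalars.shape, len_one_seq.shape)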

sprakashdash commented on September 27, 2024

@mlech26l Thanks a lot for the heads-up. Now the training is about 10x slower than with generic nn.Linear layers! I understand that, since LTC cells are similar to RNN cells (in their abstraction), just training on a random batch would not produce an out-of-the-box result. But I was thinking of keeping the causality of the replay buffer and picking a random causal batch of <s, a, r, s'> for training.
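For concreteness, one hypothetical way to sample such a causal window from a replay buffer (the buffer layout here is an assumption, not the code discussed above):

import random
from collections import deque

# Assumed layout: the buffer stores (s, a, r, s_next) tuples in the order they were collected.
buffer = deque(maxlen=100_000)

def sample_causal_batch(buffer, seq_len):
    # Pick a random starting index and return seq_len consecutive transitions,
    # preserving their temporal order.
    start = random.randint(0, len(buffer) - seq_len)
    return [buffer[i] for i in range(start, start + seq_len)]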

The last thing I would like to ask about is the learning rate of LTC cells. I found that in all the notebooks the learning rate is 0.01, while most RL function approximators are trained with a much lower learning rate (3e-4). So does the learning rate depend on the application (supervised learning vs. RL) or on the type of network (LTC vs. ANN)?

Could you also tell me what learning rate you used when you trained LTC neurons on real car data for autonomous driving?

sprakashdash commented on September 27, 2024

As you previously mentioned, there won't be any advantage of LTCs over ANNs since there is no temporal component here and the batch is chosen at random, but I was at least able to train the QNet with an LTC and the ActorNet with an ANN. However, when I tried to run my experiments with both the QNet and the Actor as LTCs, the actor is not able to learn.
Here is the previous ActorNet:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Actor(nn.Module):
    def __init__(self, env):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc_mu(x))

Here is the ActorNet with LTC:

class Actor(nn.Module):
    def __init__(self, env):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        # self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))
        ###########################
        self.wiring = kncp.wirings.FullyConnected(units=8, output_dim=np.prod(env.single_action_space.shape))
        self.ltc_cell = LTCCell(wiring=self.wiring, in_features=256)
        #replace the self.fc_mu layer
        ##########################

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.unsqueeze(1)
        ltc_sequence = RNNSequence(self.ltc_cell)
        x = ltc_sequence.forward(x)
        x = torch.tanh(x).squeeze()
        return x

As you pointed out, I have set in_features to the hidden-state dimension of the network, i.e., 256. You also mentioned that the shape of the input tensor to the LTC sequence should be [batch size, sequence length, in_features]. Right now the input tensor to ltc_sequence has shape [batch size, 1, 256], but I think it should be [batch size, action_dim, 256]. Could you help me understand how to create a tensor with that shape?
