Comments (6)

HeyZero commented on August 23, 2024

Hello.

I had some problems with the DeepQNetwork implementation using PyTorch.

I ran the code shown in your YouTube video and got this error:

~/PROJECTS/PYTORCH_TUTORIAL/main_DQN_file.py in <module>
     35             brain.store_transition(observation, action, reward, observation_, done)
     36 
---> 37             brain.learn()
     38             observation = observation_
     39         scores.append(score)

~/PROJECTS/PYTORCH_TUTORIAL/simple_DQN.py in learn(self)
    123             print("Q_Target slice: ",q_target[batch_index,actions_random])
    124             q_target[batch_index, action_indices] = reward_batch + \
--> 125                 self.gamma*T.max(q_next,dim=1)[0]*terminal_batch
    126 
    127             self.epsilon = self.epsilon*self.eps_dec if self.epsilon > \

IndexError: The shape of the mask [64] at index 0 does not match the shape of the indexed tensor [64, 4] at index 1

This error appears when the action indices are calculated using the dot operator.

When I use the np.argmax function instead, the whole network works properly.

Have you encountered this type of problem?

I have encountered the same error. I don't quite understand this sentence:
"When I use np.argmax function whole network works properly."
Can you tell me how to modify the code so that it runs?

Thank you!
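
The two ways of computing the action indices mentioned above can be compared in isolation. This is a minimal sketch, not the repository's code: the one-hot `action_batch`, the `uint8` dtypes, and the sample actions are assumptions chosen to match the dtype reported later in this thread. It shows that both routes produce the same index values but different integer dtypes.

```python
import numpy as np

n_actions = 4
action_space = list(range(n_actions))

# Hypothetical one-hot action memory: each row marks the action that was taken.
actions_taken = np.array([2, 0, 3, 1, 1, 2, 0, 3])
action_batch = np.zeros((len(actions_taken), n_actions), dtype=np.uint8)
action_batch[np.arange(len(actions_taken)), actions_taken] = 1

# "Dot operator" route: multiply the one-hot rows by the action values.
# With uint8 inputs the result is also uint8 (assumed dtypes, see lead-in).
action_values = np.array(action_space, dtype=np.uint8)
indices_dot = np.dot(action_batch, action_values)

# np.argmax route: position of the 1 in each row, returned as a signed integer.
indices_argmax = np.argmax(action_batch, axis=1)

print(indices_dot, indices_dot.dtype)
print(indices_argmax, indices_argmax.dtype)
```

The values agree, so any difference in behavior when these arrays are used to index a tensor comes from the dtype, not from the indices themselves.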

from youtube-code-repository.

KuKuXia commented on August 23, 2024

Hello, guys!
I found a solution to this problem: change line 124 from

 q_target[batch_index, action_indices] = reward_batch + self.gamma*T.max(q_next,dim=1)[0]*terminal_batch

to:

q_target[action_batch] = reward_batch + \
                self.GAMMA*T.max(q_next, dim=1)[0]*terminal_batch

I hope this could help.

BumjunJung9287 commented on August 23, 2024

Hello!
I got the same index error as you guys.
I fixed it by assigning each element one by one, as follows:

target_update = reward_batch + \
                            self.gamma*T.max(q_next, dim=1)[0]*terminal_batch
for i in range(len(batch_index)):
    q_target[batch_index[i], action_indices[i]] = target_update[i]

This should fix the error for other people too I hope.
It worked for me.
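
For comparison, the element-by-element loop above can be checked against a single vectorized assignment. This is only a sketch under stated assumptions: every tensor below is a hypothetical stand-in for what `learn()` builds, the index tensors are explicitly created as 64-bit integers, and `terminal_batch` is assumed to hold 0 for terminal steps and 1 otherwise.

```python
import torch as T

T.manual_seed(0)
batch_size, n_actions, gamma = 8, 4, 0.99

# Hypothetical stand-ins for the tensors built inside learn().
q_eval = T.rand(batch_size, n_actions)
q_next = T.rand(batch_size, n_actions)
reward_batch = T.rand(batch_size)
# Assumption: 0 marks a terminal step, so the bootstrap term is dropped there.
terminal_batch = T.tensor([1., 1., 0., 1., 1., 1., 0., 1.])
action_indices = T.tensor([2, 0, 3, 1, 1, 2, 0, 3], dtype=T.long)
batch_index = T.arange(batch_size, dtype=T.long)

target_update = reward_batch + gamma * T.max(q_next, dim=1)[0] * terminal_batch

# Element-by-element assignment, as in the comment above.
q_loop = q_eval.clone()
for i in range(batch_size):
    q_loop[batch_index[i], action_indices[i]] = target_update[i]

# Vectorized assignment: one (row, column) pair per sample.
q_vec = q_eval.clone()
q_vec[batch_index, action_indices] = target_update

print(T.equal(q_loop, q_vec))
```

With integer index tensors the two forms write the same entries, so the loop is a workaround for the index dtype rather than a different update rule.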

philtabor commented on August 23, 2024

The code still functions on my local machine. I'm scratching my head trying to find out where the issues are cropping up. I have no doubt you guys are having problems, but just posting a snippet isn't super helpful.

Can you guys post your version in a git and then link so I can view the code? There could be something subtle elsewhere that leads to an issue.

Using just the action batch will not work as you will end up with the wrong dimensions (you get batch_size x batch_size, I believe).

In hindsight, there is no reason to go to a one-hot encoding and then back. It's needlessly complex and just introduces the potential for bugs. It's been so long since I made the video that I can't remember my thought process behind it.

philtabor commented on August 23, 2024

KuKuXia wrote:

Hello, guys!
I found a solution to this problem: change line 124 from

 q_target[batch_index, action_indices] = reward_batch + self.gamma*T.max(q_next,dim=1)[0]*terminal_batch

to:

q_target[action_batch] = reward_batch + \
                self.GAMMA*T.max(q_next, dim=1)[0]*terminal_batch

I hope this could help.

This will fix the dimensional mismatch, but gives the incorrect values for q_target. You can verify this by setting the batch size to something small (say 8), printing q_target before you index it with action_indices, and then setting q_target[action_indices] = dummy_value and printing q_target again. You will see that you don't get what you expect (or want).
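
The check described above can be sketched as follows. This is an illustration under assumptions, not the repository's code: a NumPy array stands in for the `q_target` tensor, the indices are plain 64-bit integers, and the action values are made up. Integer-array indexing follows the same rule in PyTorch for this case.

```python
import numpy as np

batch_size, n_actions = 8, 4
dummy_value = -1.0

# Hypothetical actions for a batch of 8 samples.
action_indices = np.array([2, 0, 3, 1, 1, 2, 0, 3], dtype=np.int64)
batch_index = np.arange(batch_size, dtype=np.int64)

# Indexing with action_indices alone treats each value as a ROW number,
# so entire rows 0..3 are overwritten and rows 4..7 are never touched.
q_rows = np.zeros((batch_size, n_actions))
q_rows[action_indices] = dummy_value

# Pairing each sample with its action writes exactly one entry per sample.
q_pairs = np.zeros((batch_size, n_actions))
q_pairs[batch_index, action_indices] = dummy_value

print(q_rows)
print(q_pairs)
```

The first result has the right shape but the wrong entries, which is why the assignment runs without an error yet produces incorrect targets.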

philtabor commented on August 23, 2024

In a strange turn of events, I hosed my Anaconda install while trying to install manimlib. I had to reinstall Anaconda, and after doing so I get the same error.

The error comes in because the data type of the action_indices is np.uint8. Combining the uint8 with the int32 of the batch_index causes the error. Switching the datatype of action_indices to np.int32 fixes the problem with the dimensional mismatch and yields the expected results when using the test I propose in the comment above.
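
A minimal sketch of the dtype fix, under stated assumptions: the one-hot `action_batch` and the sample actions are hypothetical, and the cast goes to 64-bit integers rather than the `np.int32` used in the repository fix, because 64-bit integer tensors are the index type PyTorch documents for this kind of indexing. A `uint8` index may be interpreted as a mask rather than as positions, which matches the "shape of the mask" wording in the error above.

```python
import numpy as np
import torch as T

batch_size, n_actions = 8, 4

# Hypothetical one-hot action memory stored as uint8.
action_batch = np.eye(n_actions, dtype=np.uint8)[[2, 0, 3, 1, 1, 2, 0, 3]]
action_values = np.arange(n_actions, dtype=np.uint8)
action_indices = np.dot(action_batch, action_values)  # dtype is uint8 here

# Cast to a signed integer type before using the values as tensor indices.
action_idx = T.as_tensor(action_indices.astype(np.int64))
batch_idx = T.arange(batch_size)

q_target = T.zeros(batch_size, n_actions)
update = T.arange(1.0, batch_size + 1.0)  # made-up target values 1..8

# One (row, column) pair per sample, so exactly batch_size entries change.
q_target[batch_idx, action_idx] = update
print(q_target)
```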

Fixed code is up on the repo.
