Comments (3)
@morestart, you probably already know the answer, but in case anyone else is wondering: nn.LSTM in PyTorch processes a whole sequence in one call and can stack several layers, which is why you can choose the number of layers. nn.LSTMCell, on the other hand, is just a single cell that computes one time step. The author uses the latter here because the attention has to be recomputed at every time step from the previous hidden state. With nn.LSTM you could not do that, because the layer connections and the forward pass over the sequence are hard-coded.
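To make the point about per-step attention concrete, here is a minimal sketch of a decoder loop built on nn.LSTMCell. It is not the tutorial's code: the class name `TinyAttentionDecoder`, the layer names and sizes, the tanh-based additive attention and the zero initial states are all simplifying assumptions. What it shows is that the attention weights depend on the current hidden state `h`, so they can only be computed inside an explicit loop over time steps.

```python
import torch
import torch.nn as nn


class TinyAttentionDecoder(nn.Module):
    """Hypothetical, stripped-down attention decoder (illustration only)."""

    def __init__(self, encoder_dim=32, embed_dim=16, hidden_dim=24, vocab_size=50):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.enc_att = nn.Linear(encoder_dim, hidden_dim)  # projects encoder features
        self.dec_att = nn.Linear(hidden_dim, hidden_dim)   # projects decoder hidden state
        self.full_att = nn.Linear(hidden_dim, 1)           # one score per pixel
        self.cell = nn.LSTMCell(embed_dim + encoder_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, encoder_out, captions):
        # encoder_out: (batch, num_pixels, encoder_dim); captions: (batch, T) token ids
        batch, T = captions.shape
        h = encoder_out.new_zeros(batch, self.hidden_dim)  # zero initial state (assumption)
        c = encoder_out.new_zeros(batch, self.hidden_dim)
        embeddings = self.embedding(captions)              # (batch, T, embed_dim)
        enc_proj = self.enc_att(encoder_out)               # does not change over time, so compute once
        logits = []
        for t in range(T):
            # The scores use the CURRENT h, so attention must be recomputed at every step.
            scores = self.full_att(torch.tanh(enc_proj + self.dec_att(h).unsqueeze(1)))
            alpha = torch.softmax(scores.squeeze(2), dim=1)          # (batch, num_pixels)
            context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch, encoder_dim)
            h, c = self.cell(torch.cat([embeddings[:, t], context], dim=1), (h, c))
            logits.append(self.fc(h))
        return torch.stack(logits, dim=1)                  # (batch, T, vocab_size)


torch.manual_seed(0)
decoder = TinyAttentionDecoder()
encoder_out = torch.randn(4, 49, 32)   # e.g. a 7x7 feature map flattened to 49 "pixels"
captions = torch.randint(0, 50, (4, 7))
out = decoder(encoder_out, captions)
print(out.shape)
```

Because each step's LSTMCell input includes a context vector that depends on the previous step's output, a single call to nn.LSTM over the whole caption cannot express this dependency.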
from a-pytorch-tutorial-to-image-captioning.
@thanhtvt Exactly!
And this is precisely what the examples in the PyTorch docs show. On the nn.LSTM page:
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)  # input_size=10, hidden_size=20, num_layers=2
input = torch.randn(5, 3, 10)  # (seq_len=5, batch=3, input_size=10); the last dimension must equal input_size
h0 = torch.randn(2, 3, 20)  # (num_layers=2, batch=3, hidden_size=20); the batch and hidden sizes must match
c0 = torch.randn(2, 3, 20)  # same shape as h0
output, (hn, cn) = rnn(input, (h0, c0))  # output is [5, 3, 20]: same seq_len and batch as the input, but the last dimension is hidden_size
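A quick way to see what the two return values hold (a small sketch, not from the docs page; it relies on h0/c0 defaulting to zeros when omitted): `output` contains only the top layer's hidden state, but for every time step, while `hn` contains every layer's hidden state, but only for the last time step. So the last row of `output` should match the last layer of `hn`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(10, 20, 2)      # input_size=10, hidden_size=20, num_layers=2
x = torch.randn(5, 3, 10)     # (seq_len, batch, input_size)
output, (hn, cn) = rnn(x)     # h0 and c0 default to zeros when not passed

print(output.shape)  # torch.Size([5, 3, 20]): top layer only, every time step
print(hn.shape)      # torch.Size([2, 3, 20]): every layer, last time step only

# The last time step of `output` is the top layer's final hidden state.
print(torch.allclose(output[-1], hn[-1]))
```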
Then on the LSTMCell page, it's pretty much the same thing, but with an explicit for loop over the time steps:
import torch
import torch.nn as nn

rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
hx = torch.randn(3, 20)  # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
for i in range(input.size(0)):  # one iteration per time step
    hx, cx = rnn(input[i], (hx, cx))
    output.append(hx)
output = torch.stack(output, dim=0)  # output.size() is [2, 3, 20]: the [3, 20] hx tensors stacked along a new first (time) dimension
@AndreiMoraru123 So if I set the number of layers in nn.LSTM to 2, is that the same as running a for loop twice with LSTMCell?
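One way to check this empirically is a small sketch (not from the thread) that copies the weights of a 2-layer nn.LSTM into two nn.LSTMCell modules and compares the outputs. It relies on the documented parameter names `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, `bias_hh_l{k}` of nn.LSTM and `weight_ih`, `weight_hh`, `bias_ih`, `bias_hh` of nn.LSTMCell, which use the same gate layout. The sizes are arbitrary. It suggests that a layer corresponds to an extra cell stacked inside the loop, while the loop itself still runs once per time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 3, 10, 20

# A 2-layer nn.LSTM: loops over time and stacks the layers internally.
lstm = nn.LSTM(input_size, hidden_size, num_layers=2)

# Two LSTMCells, one per layer: layer 1 reads the input, layer 2 reads layer 1's hidden state.
cell1 = nn.LSTMCell(input_size, hidden_size)
cell2 = nn.LSTMCell(hidden_size, hidden_size)

# Copy each layer's parameters into the matching cell so both compute the same function.
with torch.no_grad():
    for k, cell in enumerate((cell1, cell2)):
        cell.weight_ih.copy_(getattr(lstm, f"weight_ih_l{k}"))
        cell.weight_hh.copy_(getattr(lstm, f"weight_hh_l{k}"))
        cell.bias_ih.copy_(getattr(lstm, f"bias_ih_l{k}"))
        cell.bias_hh.copy_(getattr(lstm, f"bias_hh_l{k}"))

x = torch.randn(seq_len, batch, input_size)
h0 = torch.zeros(2, batch, hidden_size)
c0 = torch.zeros(2, batch, hidden_size)

with torch.no_grad():
    out_lstm, (hn, cn) = lstm(x, (h0, c0))

    # Manual version: the loop runs seq_len times (over time steps), not num_layers times.
    h1, c1 = h0[0], c0[0]
    h2, c2 = h0[1], c0[1]
    outputs = []
    for t in range(seq_len):
        h1, c1 = cell1(x[t], (h1, c1))
        h2, c2 = cell2(h1, (h2, c2))  # the second layer consumes the first layer's output
        outputs.append(h2)            # nn.LSTM's `output` holds only the top layer
    out_cells = torch.stack(outputs, dim=0)

print(out_lstm.shape, out_cells.shape)
print(torch.allclose(out_lstm, out_cells, atol=1e-5))
```

In other words, num_layers=2 is closer to two LSTMCells chained inside one loop over the sequence than to running a single LSTMCell loop twice.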