Comments (12)
1.1 官方文档里面里面说明了以下3种写法是等价的:
model = Sequential()
model.add(LSTM(32, input_shape=(10, 64)))
model = Sequential()
model.add(LSTM(32, batch_input_shape=(None, 10, 64)))
model = Sequential()
model.add(LSTM(32, input_length=10, input_dim=64))
我使用input_shape
是因为我在model.fit
里面设置了batch_size
参数,个人理解应该和直接设置batch_input_shape
等价,不知道对不对
另外,你需要训练的数据一共500组,还是每次训练500组,这个是两个概率,batch_size
的理解可以看下深度学习中的 Batch_Size
1.2 如果输入x是(3,1)维的话,time_step=3,怎么设置成20?
- 没太明白你问的是什么。。。。还是说下
stateful,默认的也是 false,意义是批和批之间是否有联系。直观的理解就是我们在读完二十步,第21步开始是接着前面二十步的。也就是第一个 batch中的最后一步与第二个 batch 中的第一步之间是有联系的。
return_sequences:布尔值,默认False,控制返回类型。若为True则返回整个序列,否则仅返回输出序列的最后一个输出
对于return_sequences
:如果return_sequences=True
:返回形如(samples,timesteps,output_dim)的3D张量;否则,返回形如(samples,output_dim)的2D张量
from tutorials.
最近我也在用keras,我说下我的理解:
你说了使用LSTM时序预测
,那么肯定是用递归层,其输入输出格式为:
输入shape:形如(samples,timesteps,input_dim)的3D张量
输出shape:
1)如果return_sequences=True:返回形如(samples,timesteps,output_dim)的3D张量
2)否则,返回形如(samples,output_dim)的2D张量
你的输入输出为many to one
,那么return_sequences=False
,也就是输入x的shape为(500,3,1),输入y的shape为(500,1)
model = Sequential()
# LSTM layer
model.add(LSTM(input_shape=(3, 1), #输入x的维度
output_dim = 5, #隐藏层数量
return_sequences = False))
# dense layer
model.add(Dense(output_dim=1)) #输出y的维度
from tutorials.
谢谢谢谢, 但我还是有2个问题没有理解。
1 关于many to one
的问题
我对比了一下莫烦老师的rnn lstm的代码。 他的代码如下。
# build a LSTM RNN
model.add(LSTM(
batch_input_shape=(BATCH_SIZE, TIME_STEPS, INPUT_SIZE), # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
output_dim=CELL_SIZE,
return_sequences=True, # True: output at all steps. False: output as last step.
stateful=True, # True: the final state of batch1 is feed into the initial state of batch2
))
我对比了一下你的代码。发下有一个问题:
在LSTM()
的参数设置里,莫烦老师用的是batch_input_shape
, 或者是 input_dim
。 你的例子当中用的是input_shape
。
那有一个小问题就是何时采用batch_input_shape
,何时采用input_shape
,何时采用input_dim
呢?
即我遇到问题和莫烦老师的区别就是,他的例子是说Time_Step
是说每一个x
就是一个1*1
的矩阵,然后对应一个Y
。即x1=(1),y1=(1); x2=(2) , y2=(2)...
。
如果time_step=20
,即每一组x1,x2...x20
,一次输入20个时间点的(x,y)
,然后让模型自己去找这20个时间点的(x,y)
之间x1,x2,x3...x20
对y
影响大小的参数。
但是我的问题是每一个x
是 一个(1*3)的矩阵,即X1=(1,2,3),Y1=(1); X2=(4,5,6),Y2=(2)
那这样的话,如果time_step=20, batch_size=500
,每次的矩阵大小应该是: batch_size*time_step*x.shape=500*20*3*1
这样子的话那该如何设置呢?
2 关于return_sequences
的问题
return_sequences
的意思,我查了一下手册,应该说的是是否每一个x对应一个y,如果为False
则表示一个time_step
的x只返回一个y. 那这个变量和stateful
的区别又在哪里呢?
这几个问题一直困扰着我,因为实际问题当中,这几个弄不明白的话,我的模型总感觉和普通的regressor结果差了太多。甚至还没有朴素贝叶斯的结果好。
from tutorials.
太感谢了,但是我好像还有些混淆shape的含义了。 那我再问一个问题哈~
莫烦老师的time_step的真实含义我现在有些混乱, 是以下2种方式的哪一种呢?:
-
如果
time_step=3
表示的是说 x是一个(3,1)维的矩阵 即: x1=(1,2,3) x2=(2,3,4) 对应y1,y2?即time_step
表示的是x的数据维度。 -
如果
time_step=3
表示的是说 x还是1维数据(x1=(1)), 但是建立模型的时候是用(x1,x2,x3) 对应一个y1?即time_step
表示的是每次有X的三个lag 数据。 即 Xn, Xn-1, Xn-2。如果time_step
是这种意思,那如果我的X 本身是一个(3*1)
的矩阵,即 X1=(1,2,3), 那当time_step=3的时候,input_shape 是不是应该等于 input_length=time_step*3, input_dim=1
请问正确的time_step理解方式应该是哪一种呢?
from tutorials.
第二种,相当于就是用(Xn-2,Xn-1,Xn)预测Xn+1时刻的值。
如果Xn是3维,例如X1=(1,2,3),那么input_length=3, input_dim=3
,前者表示time_step
,后者表示数据维度
from tutorials.
好的,终于弄明白了!谢谢谢谢!!谢谢。那也就是说如果 Xn 是一个 n*m
的矩阵 那么 input_dim=[n,m]
这样子了~~~ 太谢谢了~~~~~~
from tutorials.
input_dim=[n,m]没这么用过,感觉不行。具体看看文档。
from tutorials.
麻烦老师,看了老师与上面那个朋友的对话,我还是没太明白,这个input_shape该怎么设置。
我说下我的project背景。一个文件,总共有20000行,每一行有264个特征值,再加一个label值(0/1).
lstm的输入是个3D,确实特别懵逼,不知道该如何转,input_shape 该如何设置。还有input_length以及input_dim都表示什么?out_dim又表示什么,麻烦老师了。
from tutorials.
@Wall-ee
这东西是不是可以两种处理法,
第一种 是把数据按time_steps=3这样传进去
第二种 是把三个时间步上的数据整理成一个大矩阵,然后网络的time_steps=1,这样传进去
from tutorials.
o o o o
[ ] [ ] [ ] [ ] [ ] [ ]
i i i
如果是这样的 many to many 呢
from tutorials.
@Wall-ee
这东西是不是可以两种处理法,
第一种 是把数据按time_steps=3这样传进去
第二种 是把三个时间步上的数据整理成一个大矩阵,然后网络的time_steps=1,这样传进去
用第二种, 第一种的time_steps=3 的做法是不对的 因为这样的话,会造成逻辑上训练数据的缺失~因为 每一个timestep 理论上应该是一对 x,y
from tutorials.
o o o o [ ] [ ] [ ] [ ] [ ] [ ] i i i
如果是这样的 many to many 呢
这个就是y的形状问题
from tutorials.
Related Issues (20)
- visualize cpu history HOT 3
- lstm的batch_size要刚好被数据集整除么 HOT 1
- tensorflowTUT/tf12_plot_result/full_code.py NameError: name 'time' is not defined HOT 3
- ssl.SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:777)5-classifier_example.py
- 有没有读取自己的图像训练集的代码呀
- 您好,sk9_cross_validation2.py中的实例在python3.6.6中运行报错
- Classification那一课 HOT 1
- 为什么我的没有画线 HOT 9
- 在tensorflow中的dropout中修改建议 HOT 1
- 建议
- 激活函数
- pandas values HOT 4
- 推荐增加一些简易的新算法
- 求助一下 tensorflow是2.0.0的 报错train中没有GradientDescentOptimizer
- 请问强化学习是怎么影响神经网络参数的
- Is this your website? https://www.echenshe.com/class/tensorflow/ HOT 1
- Qw
- tensorflow 还支持另外一种batch normalization 的方法 HOT 2
- tensorflow 2.1.0 error HOT 3
- 网站进不去了 HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from tutorials.