Coder Social home page Coder Social logo

能否介绍一下在机器学习,尤其是keras或者TensorFlow 当中,参数的shape具体该如何设置?谢谢啦 about tutorials HOT 12 CLOSED

morvanzhou avatar morvanzhou commented on May 16, 2024
能否介绍一下在机器学习,尤其是keras或者TensorFlow 当中,参数的shape具体该如何设置?谢谢啦

from tutorials.

Comments (12)

charli2014 avatar charli2014 commented on May 16, 2024 9

1.1 官方文档里面里面说明了以下3种写法是等价的:

model = Sequential()
model.add(LSTM(32, input_shape=(10, 64)))
model = Sequential()
model.add(LSTM(32, batch_input_shape=(None, 10, 64)))
model = Sequential()
model.add(LSTM(32, input_length=10, input_dim=64))

我使用input_shape是因为我在model.fit里面设置了batch_size参数,个人理解应该和直接设置batch_input_shape等价,不知道对不对

另外,你需要训练的数据一共500组,还是每次训练500组,这个是两个概率,batch_size的理解可以看下深度学习中的 Batch_Size

1.2 如果输入x是(3,1)维的话,time_step=3,怎么设置成20?


  1. 没太明白你问的是什么。。。。还是说下

stateful,默认的也是 false,意义是批和批之间是否有联系。直观的理解就是我们在读完二十步,第21步开始是接着前面二十步的。也就是第一个 batch中的最后一步与第二个 batch 中的第一步之间是有联系的。

return_sequences:布尔值,默认False,控制返回类型。若为True则返回整个序列,否则仅返回输出序列的最后一个输出

对于return_sequences:如果return_sequences=True:返回形如(samples,timesteps,output_dim)的3D张量;否则,返回形如(samples,output_dim)的2D张量

from tutorials.

charli2014 avatar charli2014 commented on May 16, 2024 5

最近我也在用keras,我说下我的理解:
你说了使用LSTM时序预测,那么肯定是用递归层,其输入输出格式为:

输入shape:形如(samples,timesteps,input_dim)的3D张量
输出shape:
1)如果return_sequences=True:返回形如(samples,timesteps,output_dim)的3D张量
2)否则,返回形如(samples,output_dim)的2D张量

你的输入输出为many to one,那么return_sequences=False,也就是输入x的shape为(500,3,1),输入y的shape为(500,1)

model = Sequential()
# LSTM layer
model.add(LSTM(input_shape=(3, 1),               #输入x的维度
                        output_dim = 5,          #隐藏层数量
                        return_sequences = False))
# dense layer
model.add(Dense(output_dim=1))                   #输出y的维度

from tutorials.

Wall-ee avatar Wall-ee commented on May 16, 2024

谢谢谢谢, 但我还是有2个问题没有理解。

1 关于many to one的问题

我对比了一下莫烦老师的rnn lstm的代码。 他的代码如下。

# build a LSTM RNN
model.add(LSTM(
    batch_input_shape=(BATCH_SIZE, TIME_STEPS, INPUT_SIZE),       # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
    output_dim=CELL_SIZE,
    return_sequences=True,      # True: output at all steps. False: output as last step.
    stateful=True,              # True: the final state of batch1 is feed into the initial state of batch2
))

我对比了一下你的代码。发下有一个问题:

LSTM()的参数设置里,莫烦老师用的是batch_input_shape, 或者是 input_dim。 你的例子当中用的是input_shape

那有一个小问题就是何时采用batch_input_shape ,何时采用input_shape,何时采用input_dim呢?

即我遇到问题和莫烦老师的区别就是,他的例子是说Time_Step是说每一个x就是一个1*1的矩阵,然后对应一个Y。即x1=(1),y1=(1); x2=(2) , y2=(2)...

如果time_step=20,即每一组x1,x2...x20,一次输入20个时间点的(x,y),然后让模型自己去找这20个时间点的(x,y)之间x1,x2,x3...x20y影响大小的参数。

但是我的问题是每一个x是 一个(1*3)的矩阵,即X1=(1,2,3),Y1=(1); X2=(4,5,6),Y2=(2) 那这样的话,如果time_step=20, batch_size=500,每次的矩阵大小应该是: batch_size*time_step*x.shape=500*20*3*1 这样子的话那该如何设置呢?

2 关于return_sequences的问题

return_sequences的意思,我查了一下手册,应该说的是是否每一个x对应一个y,如果为False则表示一个time_step的x只返回一个y. 那这个变量和stateful的区别又在哪里呢?

这几个问题一直困扰着我,因为实际问题当中,这几个弄不明白的话,我的模型总感觉和普通的regressor结果差了太多。甚至还没有朴素贝叶斯的结果好。

from tutorials.

Wall-ee avatar Wall-ee commented on May 16, 2024

太感谢了,但是我好像还有些混淆shape的含义了。 那我再问一个问题哈~

莫烦老师的time_step的真实含义我现在有些混乱, 是以下2种方式的哪一种呢?:

  1. 如果time_step=3 表示的是说 x是一个(3,1)维的矩阵 即: x1=(1,2,3) x2=(2,3,4) 对应y1,y2?即 time_step表示的是x的数据维度。

  2. 如果time_step=3 表示的是说 x还是1维数据(x1=(1)), 但是建立模型的时候是用(x1,x2,x3) 对应一个y1?即 time_step 表示的是每次有X的三个lag 数据。 即 Xn, Xn-1, Xn-2。如果time_step是这种意思,那如果我的X 本身是一个 (3*1)的矩阵,即 X1=(1,2,3), 那当time_step=3的时候,input_shape 是不是应该等于 input_length=time_step*3, input_dim=1

请问正确的time_step理解方式应该是哪一种呢?

from tutorials.

charli2014 avatar charli2014 commented on May 16, 2024

第二种,相当于就是用(Xn-2,Xn-1,Xn)预测Xn+1时刻的值。
如果Xn是3维,例如X1=(1,2,3),那么input_length=3, input_dim=3,前者表示time_step,后者表示数据维度

from tutorials.

Wall-ee avatar Wall-ee commented on May 16, 2024

好的,终于弄明白了!谢谢谢谢!!谢谢。那也就是说如果 Xn 是一个 n*m 的矩阵 那么 input_dim=[n,m] 这样子了~~~ 太谢谢了~~~~~~

from tutorials.

charli2014 avatar charli2014 commented on May 16, 2024

input_dim=[n,m]没这么用过,感觉不行。具体看看文档。

from tutorials.

hgdwxx avatar hgdwxx commented on May 16, 2024

麻烦老师,看了老师与上面那个朋友的对话,我还是没太明白,这个input_shape该怎么设置。
我说下我的project背景。一个文件,总共有20000行,每一行有264个特征值,再加一个label值(0/1).
lstm的输入是个3D,确实特别懵逼,不知道该如何转,input_shape 该如何设置。还有input_length以及input_dim都表示什么?out_dim又表示什么,麻烦老师了。

from tutorials.

zhkmxx9302013 avatar zhkmxx9302013 commented on May 16, 2024

@Wall-ee
这东西是不是可以两种处理法,
第一种 是把数据按time_steps=3这样传进去
第二种 是把三个时间步上的数据整理成一个大矩阵,然后网络的time_steps=1,这样传进去

from tutorials.

Jasdent avatar Jasdent commented on May 16, 2024
         o  o   o  o   
[ ] [ ] [ ] [ ] [ ] [ ]
 i   i   i  

如果是这样的 many to many 呢

from tutorials.

Wall-ee avatar Wall-ee commented on May 16, 2024

@Wall-ee
这东西是不是可以两种处理法,
第一种 是把数据按time_steps=3这样传进去
第二种 是把三个时间步上的数据整理成一个大矩阵,然后网络的time_steps=1,这样传进去

用第二种, 第一种的time_steps=3 的做法是不对的 因为这样的话,会造成逻辑上训练数据的缺失~因为 每一个timestep 理论上应该是一对 x,y

from tutorials.

Wall-ee avatar Wall-ee commented on May 16, 2024
         o  o   o  o   
[ ] [ ] [ ] [ ] [ ] [ ]
 i   i   i  

如果是这样的 many to many 呢

这个就是y的形状问题

from tutorials.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.