Initial state and batch size in an LSTM


import tensorflow as tf


def lstm1(X, init_tuple_state, cell_layers, cell_size=100):
    '''
    X shape:                [batch_size, input_dim]  (a single time step)
    init_tuple_state shape: [batch_size, cell_layers, 2, cell_size]
        (same layout as the returned output_state, so it can be fed back in)
    '''
    cells = [tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True) for _ in range(cell_layers)]

    # [batch_size, cell_layers, 2, cell_size] -> [cell_layers, 2, batch_size, cell_size],
    # then split into one [2, batch_size, cell_size] tensor per layer
    state_per_layer_list = tf.unstack(tf.transpose(init_tuple_state, [1, 2, 0, 3]), axis=0)

    # rebuild the tuple-of-LSTMStateTuple structure that MultiRNNCell expects
    initial_state12 = tuple(
        tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
        for idx in range(cell_layers))

    # run a single time step: expand X to [batch_size, 1, input_dim]
    outputs, output_state = tf.nn.dynamic_rnn(
        tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True),
        tf.expand_dims(X, 1), initial_state=initial_state12)

    # drop the time dimension: [batch_size, cell_size]
    outputs = tf.transpose(outputs, [1, 0, 2])[0]

    # pack the final state into a single [batch_size, cell_layers, 2, cell_size] tensor
    output_state = tf.stack(
        [tf.stack([output_state[i].c, output_state[i].h], axis=1) for i in range(cell_layers)],
        axis=1)
    return outputs, output_state
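
A minimal usage sketch of lstm1, assuming TensorFlow 1.x; input_dim, n_layers, the batch size and the random feed data below are hypothetical values. Because output_state has the same layout as init_tuple_state, it can be fed straight back in at the next step:

import numpy as np
import tensorflow as tf

input_dim, n_layers, cell_size = 8, 3, 100

# one step of input; the batch dimension of both placeholders stays None,
# so the same graph handles any batch size
x_step = tf.placeholder(tf.float32, [None, input_dim])
state_in = tf.placeholder(tf.float32, [None, n_layers, 2, cell_size])

step_out, state_out = lstm1(x_step, state_in, n_layers, cell_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_size = 4
    # start from an all-zero state and feed the returned state back in each step
    state = np.zeros([batch_size, n_layers, 2, cell_size], dtype=np.float32)
    for t in range(10):
        x_t = np.random.rand(batch_size, input_dim).astype(np.float32)
        out, state = sess.run([step_out, state_out],
                              feed_dict={x_step: x_t, state_in: state})
    print(out.shape)    # (4, 100)
    print(state.shape)  # (4, 3, 2, 100)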


# dimension 1: batch size
# dimension 2: sequence length
# dimension 3: input dimension
X = tf.placeholder(tf.float32, [None, None, input_dim])
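
When a whole sequence is fed at once, the initial state does not have to be a placeholder at all: cell.zero_state accepts the batch size as a tensor, so it can be derived from X at run time. A short TF 1.x sketch; input_dim and the layer/cell sizes below are hypothetical values:

import tensorflow as tf

input_dim = 8                      # hypothetical value
cell_layers, cell_size = 2, 100    # hypothetical sizes

X = tf.placeholder(tf.float32, [None, None, input_dim])   # same layout as the placeholder above

cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True) for _ in range(cell_layers)],
    state_is_tuple=True)

# zero_state takes the batch size as a tensor, so both batch size and
# sequence length only need to be known at run time
initial_state = cell.zero_state(tf.shape(X)[0], tf.float32)

outputs, final_state = tf.nn.dynamic_rnn(cell, X, initial_state=initial_state)
# outputs:     [batch_size, seq_len, cell_size]
# final_state: one LSTMStateTuple(c, h) per layer, each part of shape [batch_size, cell_size]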


Understanding the state dimensions in an LSTM: with state_is_tuple=True, each layer's state is an LSTMStateTuple(c, h) whose c and h both have shape [batch_size, cell_size]. A MultiRNNCell therefore carries one such tuple per layer, which is why the packed state used above has shape [batch_size, cell_layers, 2, cell_size].
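
A quick way to inspect this structure (TF 1.x sketch with hypothetical sizes):

import tensorflow as tf

cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(100, state_is_tuple=True) for _ in range(2)],
    state_is_tuple=True)

# state_size reports one LSTMStateTuple per layer
print(cell.state_size)
# (LSTMStateTuple(c=100, h=100), LSTMStateTuple(c=100, h=100))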
