def lstm1(X, init_tuple_state, cell_layers, cell_size=100):
    """Run a stacked (multi-layer) LSTM for a single time step.

    Args:
        X: Input for one time step. A time axis of length 1 is inserted at
            axis 1 before calling ``tf.nn.dynamic_rnn``, so X is presumably
            ``[batch_size, input_dim]`` -- TODO confirm with callers.
        init_tuple_state: Initial state of every layer packed into one tensor
            of shape ``[batch_size, cell_layers, 2, cell_size]``. On the third
            axis, index 0 is the cell state ``c`` and index 1 is the hidden
            state ``h``. This is the same layout as the returned
            ``output_state``, so it can be fed straight back in on the next
            step.
        cell_layers: Number of stacked ``LSTMCell`` layers.
        cell_size: Number of units in each ``LSTMCell``.

    Returns:
        A pair ``(outputs, output_state)``:
          * ``outputs``: the ``MultiRNNCell`` output for this step, with the
            length-1 time axis removed, i.e. ``[batch_size, cell_size]``.
          * ``output_state``: the final state of every layer, packed as
            ``[batch_size, cell_layers, 2, cell_size]`` (c at index 0, h at
            index 1 on the third axis).

    NOTE(review): ``tf.nn.dynamic_rnn`` creates its variables under a default
    scope, so calling this twice in the same graph presumably needs an
    enclosing ``tf.variable_scope`` -- confirm with callers.
    """
    # A separate LSTMCell instance per layer (one per iteration).
    cells = [
        tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True)
        for _ in range(cell_layers)
    ]
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

    # [batch, layers, 2, units] -> [layers, 2, batch, units], then split into
    # one [2, batch, units] tensor per layer. Passing num= lets tf.unstack work
    # even when the layer axis is not statically known (e.g. a None dim).
    per_layer_states = tf.unstack(
        tf.transpose(init_tuple_state, [1, 2, 0, 3]), num=cell_layers, axis=0
    )
    initial_state = tuple(
        tf.nn.rnn_cell.LSTMStateTuple(layer_state[0], layer_state[1])  # (c, h)
        for layer_state in per_layer_states
    )

    # Feed X as a sequence of length 1: [batch, input_dim] -> [batch, 1, input_dim].
    outputs, final_state = tf.nn.dynamic_rnn(
        stacked_cell, tf.expand_dims(X, 1), initial_state=initial_state
    )
    # Drop the length-1 time axis: [batch, 1, units] -> [batch, units].
    outputs = outputs[:, 0, :]

    # Re-pack the per-layer (c, h) tuples into [batch, layers, 2, units],
    # the same layout as init_tuple_state.
    output_state = tf.stack(
        [tf.stack([layer.c, layer.h], axis=1) for layer in final_state], axis=1
    )
    return outputs, output_state
# Input placeholder laid out as [batch_size, sequence_length, input_dim]:
#   axis 0 -- batch size (None: any batch size can be fed)
#   axis 1 -- sequence length (None: any number of time steps)
#   axis 2 -- input (feature) dimension
# NOTE(review): input_dim is not defined in this snippet; it is assumed to be
# set earlier -- confirm.
# NOTE(review): lstm1 inserts its own length-1 time axis via
# tf.expand_dims(X, 1), so it appears to expect a 2-D [batch_size, input_dim]
# input rather than this 3-D placeholder -- confirm which is intended.
X = tf.placeholder(dtype=tf.float32, shape=[None, None, input_dim])
对于 LSTM 中 state 维度的理解：
墨之科技,版权所有 © Copyright 2017-2027
湘ICP备14012786号 邮箱:ai@inksci.com