以下代码尽管无法运行，但可供阅读参考。
class BCModel(object):
    """Skeleton showing how to log scalars to TensorBoard with TF1 summaries.

    The logging pattern has three steps:

    1. Create ``tf.summary`` ops while building the graph, then merge them
       with ``tf.summary.merge_all()`` and open a ``tf.summary.FileWriter``
       (see ``summary``).
    2. Run the merged summary op with ``sess.run`` together with the
       training op (see ``train_step``).
    3. Write each serialized summary with ``FileWriter.add_summary``,
       passing the step counter as ``global_step`` (see ``train``).

    The network, placeholders, loss and data pipeline are not part of this
    example; the places where they belong are marked with TODO notes and the
    ``make_feed_dict`` / ``next_batch`` hooks.
    """

    def __init__(self, work_path=''):
        """Create the session, build the graph and set up summary logging.

        Args:
            work_path: Directory under which ``log/BC/<timestamp>`` event
                files are written (defaults to the current directory).
        """
        self.work_path = work_path
        self.sess = tf.Session()
        # The summary ops must exist before ``merge_all()`` runs, otherwise it
        # returns None and nothing is logged -- so build the graph first.
        self.build_graph()
        self.summary()
        self.sess.run(tf.global_variables_initializer())

    @staticmethod
    def _log_dir(work_path, stamp=None):
        """Return the event-file directory ``<work_path>/log/BC/<stamp>``.

        ``stamp`` defaults to the current local time formatted as
        ``%m%d%H%M%S``, so runs started at different seconds write to
        different directories.
        """
        import os
        import time

        if stamp is None:
            stamp = time.strftime('%m%d%H%M%S', time.localtime())
        # os.path.join inserts the separator that plain ``+`` concatenation
        # silently dropped when work_path had no trailing slash.
        return os.path.join(work_path, 'log', 'BC', stamp)

    def summary(self):
        """First step: merge all summary ops and create the summary writer.

        Raises:
            ValueError: if no summary op has been defined yet
                (``build_graph`` must run first).
        """
        self.summary_merged = tf.summary.merge_all()
        if self.summary_merged is None:
            raise ValueError(
                'No summaries defined; call build_graph() before summary().')
        self.summary_write = tf.summary.FileWriter(
            self._log_dir(self.work_path), graph=self.sess.graph)

    def build_graph(self):
        """Build the model graph and register the scalars to monitor."""
        # TODO: define the network, its placeholders and ``self.bc_loss``
        # here; the original example omits them.
        # tf.summary.scalar needs a rank-0 tensor; reshaping to [] assumes
        # bc_loss holds exactly one element.
        tf.summary.scalar('bc_loss', tf.reshape(self.bc_loss, []))
        # Add any other scalar you want to monitor the same way, e.g.
        # tf.summary.scalar('<name>', <scalar tensor>).

    def make_feed_dict(self, state, action):
        """Map a ``(state, action)`` batch onto the graph's placeholders.

        The placeholders are not part of this example, so this hook must be
        implemented (or overridden) to return a ``feed_dict`` for
        ``sess.run``.
        """
        raise NotImplementedError(
            'make_feed_dict() must map state/action to the graph placeholders')

    def train_step(self, state, action):
        """Second step: run ``summary_merged`` with ``sess`` for one batch.

        Returns:
            The dict produced by ``sess.run(fetches)``; ``results['summary']``
            holds the serialized summary to pass to ``add_summary``.
        """
        fetches = {
            'summary': self.summary_merged,
            # TODO: also fetch the training op so the step actually trains,
            # e.g. 'train_op': <optimizer.minimize(...)>.
        }
        return self.sess.run(fetches,
                             feed_dict=self.make_feed_dict(state, action))

    def next_batch(self):
        """Return the next ``(state, action)`` training batch.

        The data pipeline is not part of this example; implement or override.
        """
        raise NotImplementedError('next_batch() must return (state, action)')

    def train(self, n_steps=50):
        """Third step: write each step's summary with ``summary_write``.

        Args:
            n_steps: Number of training steps (50 in the original example).
        """
        for i_train in range(1, n_steps + 1):
            state, action = self.next_batch()
            results = self.train_step(state, action)
            # Do not forget i_train: it is recorded as the summary's step,
            # which TensorBoard uses for the x-axis.
            self.summary_write.add_summary(results['summary'], i_train)
        # Push any buffered events to disk once training ends.
        self.summary_write.flush()
墨之科技，版权所有 © Copyright 2017-2027
湘ICP备14012786号 邮箱：ai@inksci.com