def smooth(scalars, weight):
    """Return an exponentially smoothed copy of *scalars*.

    Each output value is ``weight * previous_output + (1 - weight) * point``.
    The running value is seeded with ``scalars[0]`` so the smoothed curve
    starts at the first data point instead of being pulled towards zero.

    Args:
        scalars: sized, indexable sequence of numbers (e.g. a list or a 1-D
            array), one value per timestep.
        weight: smoothing factor in [0, 1]; larger values give a smoother,
            more lagging curve, 0 returns the input values unchanged.

    Returns:
        A list of smoothed values with the same length as *scalars*; an empty
        list when *scalars* is empty.

    Raises:
        ValueError: if *weight* is outside [0, 1].
    """
    if not 0 <= weight <= 1:
        raise ValueError(f"weight must be between 0 and 1, got {weight!r}")
    # len() rather than truthiness so 1-D NumPy arrays are accepted too.
    if len(scalars) == 0:
        return []
    last = scalars[0]  # Seed with the first value in the plot (first timestep).
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed


# One colour per run in path_list (zip stops at the shorter of the two).
RUN_COLORS = ['#ff7043', '#0077bb', '#cc3311', '#33bbee']
NUM_EPISODES = 99       # Only the first 99 episodes of each run are plotted.
SMOOTHING_WEIGHT = 0.8  # Weight passed to smooth() for the bold curves.

path_list = [
    # '/media/inksci/doc/deepolicy/deepfire-yun-data/yun113-core32/bf-nn3-0909-1945-1e-5-bs128-rerun2/battle-framework/agent/deepfire/data/reward/',
    # '/media/inksci/doc/deepolicy/data-deepfire-paper/bf-nn3-0909-1945-1e-5-bs128-1217-rerun4.reward/reward/',
    '/media/inksci/doc/deepolicy/data-deepfire-paper/4-rewards-99/bf-nn3-0909-1945.1e-5.reward.ip222/reward/',
    '/media/inksci/doc/deepolicy/data-deepfire-paper/rewards-7envs-paper/bf-nn3-0909-1945.1e-5.reward.ip223/reward/',
    '/media/inksci/doc/deepolicy/data-deepfire-paper/4-rewards-99/bf-nn3-0909-1945.1e-5.reward.ip224/reward/',
    '/media/inksci/doc/deepolicy/data-deepfire-paper/4-rewards-99/bf-nn3-0909-1945-1e-5-bs32-1217-rerun2.reward.ip224/reward/',
]

# Load each run's rewards once; both plotting passes below reuse this data
# instead of calling get_last_reward_list() twice per path.
runs = [
    (color, get_last_reward_list(path)[:NUM_EPISODES])
    for color, path in zip(RUN_COLORS, path_list)
]

# Faint raw curves are drawn first so the smoothed curves added afterwards
# are rendered on top of them.
for color, rewards in runs:
    plt.plot(rewards, color=color, alpha=0.2)

for color, rewards in runs:
    plt.plot(smooth(rewards, SMOOTHING_WEIGHT), color=color)

plt.xlabel('Episode')
plt.ylabel('Reward')
plt.grid()
plt.ylim(-2, 4)
plt.savefig('im.pdf')
plt.show()
墨之科技,版权所有 © Copyright 2017-2027
湘ICP备14012786号 邮箱:ai@inksci.com