本文整理汇总了Python中envs.create_env方法的典型用法代码示例。如果您正苦于以下问题:Python envs.create_env方法的具体用法?Python envs.create_env怎么用?Python envs.create_env使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块envs的用法示例。
在下文中一共展示了envs.create_env方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_penalty_env
# 需要导入模块: import envs [as 别名]
# 或者: from envs import create_env [as 别名]
def test_penalty_env(env=None):
    """Step a Pong penalty environment with random actions and show each frame.

    Takes at most 20 randomly sampled actions. After each step it displays
    channel 0 (last axis) of the observation with matplotlib, then prints the
    ``'frame/is_catastrophe'`` entry of ``info`` and the step reward. The loop
    stops early once the environment reports ``done``.

    Args:
        env: Environment to exercise. If None (the default), a Pong environment
            is built via ``envs.create_env`` with ``location="bottom"``,
            ``catastrophe_type="1"`` and the classifier checkpoint at
            ``save_classifier_path + '/0/final.ckpt'``. ``save_classifier_path``
            is a module-level name defined elsewhere. Previously any argument
            passed here was silently replaced by that default environment.
    """
    if env is None:
        # Only needed when we have to build the default environment ourselves.
        import envs
        env = envs.create_env("Pong", location="bottom", catastrophe_type="1",
                              classifier_file=save_classifier_path + '/0/final.ckpt')
    import matplotlib.pyplot as plt

    env.reset()
    for _ in range(20):
        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)
        # Indexing [:, :, 0] requires a rank-3 observation; show its first channel.
        plt.imshow(observation[:, :, 0])
        plt.show()  # blocks until the figure window is closed
        print('Cat: ', info['frame/is_catastrophe'])
        print('reward: ', reward)
        if done:
            break
示例2: run
# 需要导入模块: import envs [as 别名]
# 或者: from envs import create_env [as 别名]
def run(args, num_episodes=10):
    """Evaluate an A3C agent for several episodes and save each one as a GIF.

    Builds the environment from ``args.env_id`` and an ``A3C`` trainer. It then
    opens a ``tf.train.Supervisor`` session rooted at ``<args.log_dir>/train``.
    Per Supervisor semantics, a checkpoint in that directory is restored when
    present; otherwise ``init_op`` and ``init_fn`` initialize the variables.
    Periodic checkpointing and summaries are disabled (both ``*_secs`` are 0).
    Each evaluation episode is written to
    ``<args.log_dir>/test_videos_<intrinsic_type>/<env_id>_<ii>_<reward>.gif``
    and the average reward and length are printed.

    Args:
        args: Namespace providing ``env_id``, ``visualise``, ``intrinsic_type``,
            ``bptt`` and ``log_dir``.
        num_episodes: Number of evaluation episodes to run and average over.
            Defaults to 10, the previously hard-coded value.

    Raises:
        ValueError: If ``num_episodes`` is less than 1.
    """
    if num_episodes < 1:
        raise ValueError("num_episodes must be >= 1, got %r" % (num_episodes,))

    env = create_env(args.env_id)
    trainer = A3C(env, None, args.visualise, args.intrinsic_type, args.bptt)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info(' %s %s', v.name, v.get_shape())

    def init_fn(ses):
        # Initializes every global variable, including the "local*" ones
        # that init_op (restricted to variables_to_save) leaves out.
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir)
    logger.info("Events directory: %s", logdir)

    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=None,
                             save_model_secs=0,
                             save_summaries_secs=0)

    video_dir = os.path.join(args.log_dir, 'test_videos_' + args.intrinsic_type)
    # EAFP: create the directory and tolerate it already existing.
    # Works on Python 2 and 3, and avoids the exists/create race.
    try:
        os.makedirs(video_dir)
    except OSError:
        if not os.path.isdir(video_dir):
            raise
    video_filename = os.path.join(video_dir, "%s_%02d_%d.gif")
    print("Video saved at %s" % video_dir)

    with sv.managed_session() as sess, sess.as_default():
        trainer.start(sess, summary_writer)
        rewards = []
        lengths = []
        for i in range(num_episodes):
            frames, reward, length = trainer.evaluate(sess)
            rewards.append(reward)
            lengths.append(length)
            imageio.mimsave(video_filename % (args.env_id, i, reward), frames, fps=30)
        # float() keeps the division true division on Python 2 as well.
        print('Evaluation: avg. reward %.2f avg.length %.2f' %
              (sum(rewards) / float(num_episodes), sum(lengths) / float(num_episodes)))

    # Ask for all the services to stop.
    sv.stop()