

Python logger.configure Method Code Examples

This article collects typical usage examples of the Python method baselines.logger.configure. If you are wondering what logger.configure does in practice or how to call it, the curated examples below should help. You can also explore further usage examples from the baselines.logger module.


The following presents 15 code examples of the logger.configure method, sorted by popularity by default. A minimal usage sketch is given first, followed by the examples themselves.
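Before the individual examples, here is a minimal sketch of the pattern they all share. This is a sketch only: the placeholder path /tmp/logger_demo is an assumption, and the exact keyword arguments accepted by configure() can differ slightly across baselines versions.

from baselines import logger

def main():
    # With no arguments, configure() writes to a default log directory
    # (the OPENAI_LOGDIR environment variable is honored when set).
    logger.configure()

    # Alternatively, pass an explicit directory, as Example 14 does via --logdir.
    # logger.configure(dir='/tmp/logger_demo')

    # MPI workers other than rank 0 typically silence logging entirely,
    # as in Examples 3, 11 and 13:
    # logger.configure(format_strs=[])
    # logger.set_level(logger.DISABLED)

    logger.log("logging to", logger.get_dir())
    logger.logkv("timesteps", 0)
    logger.dumpkvs()

if __name__ == '__main__':
    main()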

Example 1: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', help='Environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'cnn_int'], default='cnn_int')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='linear')
    parser.add_argument('--num-timesteps', type=int, default=int(50E6))
    parser.add_argument('--v-ex-coef', type=float, default=0.1)
    parser.add_argument('--r-ex-coef', type=float, default=1)
    parser.add_argument('--r-in-coef', type=float, default=0.01)
    parser.add_argument('--lr-alpha', type=float, default=7E-4)
    parser.add_argument('--lr-beta', type=float, default=7E-4)
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
          policy=args.policy, lrschedule=args.lrschedule, num_env=16,
          v_ex_coef=args.v_ex_coef, r_ex_coef=args.r_ex_coef, r_in_coef=args.r_in_coef,
          lr_alpha=args.lr_alpha, lr_beta=args.lr_beta) 
Developer: Hwhitetooth, Project: lirpg, Lines: 21, Source: run_atari.py

Example 2: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', help='Environment ID', default='Walker2d-v2')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--policy', help='Policy architecture', choices=['mlp', 'mlp_int'], default='mlp_int')
    parser.add_argument('--num-timesteps', type=int, default=int(1E6))
    parser.add_argument('--r-ex-coef', type=float, default=0)
    parser.add_argument('--r-in-coef', type=float, default=1)
    parser.add_argument('--lr-alpha', type=float, default=3E-4)
    parser.add_argument('--lr-beta', type=float, default=1E-4)
    parser.add_argument('--reward-freq', type=int, default=20)
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, policy=args.policy,
          r_ex_coef=args.r_ex_coef, r_in_coef=args.r_in_coef,
          lr_alpha=args.lr_alpha, lr_beta=args.lr_beta,
          reward_freq=args.reward_freq) 
Developer: Hwhitetooth, Project: lirpg, Lines: 20, Source: run_mujoco.py

Example 3: train

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def train(env_id, num_timesteps, seed):
    import baselines.common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
        logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    def policy_fn(name, ob_space, ac_space):
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=32, num_hid_layers=2)
    env = make_mujoco_env(env_id, workerseed)
    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
        max_timesteps=num_timesteps, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
    env.close() 
Developer: bowenliu16, Project: rl_graph_generation, Lines: 21, Source: run_mujoco.py

Example 4: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    logger.configure()
    env = make_atari('PongNoFrameskip-v4')
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        lr=1e-4,
        total_timesteps=int(1e7),
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
    )

    model.save('pong_model.pkl')
    env.close() 
Developer: hiwonjoon, Project: ICML2019-TREX, Lines: 27, Source: train_pong.py

Example 5: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset() 
Developer: hiwonjoon, Project: ICML2019-TREX, Lines: 26, Source: run_humanoid.py

Example 6: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--feature_type', type=str, default='sensor')
    parser.add_argument('--tcn_run_idx', type=int, default=1)
    parser.add_argument('--split_idx', type=int, default=1)
    parser.add_argument('--run_idx', type=int, default=1)

    args = parser.parse_args()
    logger.configure()

    rng_seed = randint(0, 1000)
    print(rng_seed)

    if args.feature_type not in ['sensor', 'visual']:
        raise Exception('Invalid Feature Type')

    train(seed=rng_seed,
          feature_type=args.feature_type,
          tcn_run_idx=args.tcn_run_idx,
          split_idx=args.split_idx,
          run_idx=args.run_idx) 
Developer: Finspire13, Project: RL-Surgical-Gesture-Segmentation, Lines: 25, Source: trpo_train.py

Example 7: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    parser = mujoco_arg_parser()
    parser.add_argument('--lr', type=float, default=3e-4, help="Learning rate")
    parser.add_argument('--sil-update', type=float, default=10, help="Number of updates per iteration")
    parser.add_argument('--sil-value', type=float, default=0.01, help="Weight for value update")
    parser.add_argument('--sil-alpha', type=float, default=0.6, help="Alpha for prioritized replay")
    parser.add_argument('--sil-beta', type=float, default=0.1, help="Beta for prioritized replay")

    args = parser.parse_args()
    logger.configure()
    model, env = train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
            lr=args.lr,
            sil_update=args.sil_update, sil_value=args.sil_value,
            sil_alpha=args.sil_alpha, sil_beta=args.sil_beta)

    if args.play:
        logger.log("Running trained model")
        obs = np.zeros((env.num_envs,) + env.observation_space.shape)
        obs[:] = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs[:] = env.step(actions)[0]
            env.render() 
Developer: junhyukoh, Project: self-imitation-learning, Lines: 25, Source: run_mujoco_sil.py

Example 8: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized)
    )
    # act.save("pong_model.pkl") XXX
    env.close() 
Developer: Hwhitetooth, Project: lirpg, Lines: 36, Source: run_atari.py

Example 9: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    args = atari_arg_parser().parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32) 
Developer: Hwhitetooth, Project: lirpg, Lines: 6, Source: run_atari.py

Example 10: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    args = mujoco_arg_parser().parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) 
Developer: Hwhitetooth, Project: lirpg, Lines: 6, Source: run_mujoco.py

Example 11: train

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def train(env_id, num_timesteps, seed):
    from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
    from baselines.trpo_mpi import trpo_mpi
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return CnnPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space)
    env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3,
        max_timesteps=int(num_timesteps * 1.1), gamma=0.98, lam=1.0, vf_iters=3, vf_stepsize=1e-4, entcoeff=0.00)
    env.close() 
Developer: Hwhitetooth, Project: lirpg, Lines: 28, Source: run_atari.py

Example 12: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
        policy=args.policy) 
Developer: Hwhitetooth, Project: lirpg, Lines: 9, Source: run_atari.py

Example 13: train

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def train(env_id, num_timesteps, seed):
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = make_atari(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)

    env = wrap_deepmind(env)
    env.seed(workerseed)

    pposgd_simple.learn(env, policy_fn,
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=256,
        clip_param=0.2, entcoeff=0.01,
        optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
        gamma=0.99, lam=0.95,
        schedule='linear'
    )
    env.close() 
Developer: Hwhitetooth, Project: lirpg, Lines: 33, Source: run_atari.py

Example 14: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant')
    parser.add_argument('--logdir', help ='Directory for logging')
    args = parser.parse_args()
    logger.configure(args.logdir)
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
          policy=args.policy, lrschedule=args.lrschedule, num_cpu=16) 
Developer: Hwhitetooth, Project: lirpg, Lines: 11, Source: run_atari.py

Example 15: main

# Required import: from baselines import logger [as alias]
# Or: from baselines.logger import configure [as alias]
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--prioritized', type=int, default=1)
    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
    parser.add_argument('--dueling', type=int, default=1)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    parser.add_argument('--checkpoint-freq', type=int, default=10000)
    parser.add_argument('--checkpoint-path', type=str, default=None)

    args = parser.parse_args()
    logger.configure()
    set_global_seeds(args.seed)
    env = make_atari(args.env)
    env = bench.Monitor(env, logger.get_dir())
    env = deepq.wrap_atari_dqn(env)

    deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=bool(args.dueling),
        lr=1e-4,
        total_timesteps=args.num_timesteps,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=bool(args.prioritized),
        prioritized_replay_alpha=args.prioritized_replay_alpha,
        checkpoint_freq=args.checkpoint_freq,
        checkpoint_path=args.checkpoint_path,
    )

    env.close() 
Developer: MaxSobolMark, Project: HardRLWithYoutube, Lines: 42, Source: run_atari.py


Note: The baselines.logger.configure examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective authors, and copyright of the source code remains with them. Please consult each project's license before distributing or using the code, and do not reproduce this article without permission.