当前位置: 首页>>代码示例>>Python>>正文


Python cmd_util.common_arg_parser方法代码示例

本文整理汇总了Python中baselines.common.cmd_util.common_arg_parser方法的典型用法代码示例。如果您正苦于以下问题:Python cmd_util.common_arg_parser方法的具体用法?Python cmd_util.common_arg_parser怎么用?Python cmd_util.common_arg_parser使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在baselines.common.cmd_util的用法示例。


在下文中一共展示了cmd_util.common_arg_parser方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(custom_args=None):
    """Train a model with the common arguments, then optionally save and replay it.

    Args:
        custom_args: Optional mapping of argument name to value. An entry
            whose name matches an attribute of the parsed argument namespace
            overrides that attribute; every other entry is passed to
            ``train`` as an extra argument. ``None`` (the default) or an
            empty container means "no overrides".

    Note:
        When ``args.play`` is set the playback loop never terminates on its
        own, so the function does not return in that case.
    """
    # The previous ``[]`` default was a shared mutable object, and a list
    # cannot be indexed by name the way the override loop requires.
    overrides = custom_args if custom_args else {}

    arg_parser = common_arg_parser()
    # Command-line tokens the common parser does not recognise are not used.
    args, _unknown_args = arg_parser.parse_known_args()

    # vars(args) is the namespace's own dict, so writing to it updates args.
    known = vars(args)
    extra_args = {}
    for name, value in overrides.items():
        if name in known:
            known[name] = value
        else:
            extra_args[name] = value

    # Every rank receives the same logger configuration; only the rank value
    # (used below to decide who saves the model) differs.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    logger.configure(format_strs=['stdout', 'tensorboard'])

    model, _ = train(args, extra_args)

    # Only rank 0 writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs, _, done, _ = env.step(actions)
            env.render()
            # env.step may return an array of done flags; reset if any is set.
            done = done.any() if isinstance(done, np.ndarray) else done

            if done:
                obs = env.reset()
开发者ID:MaxSobolMark,项目名称:HardRLWithYoutube,代码行数:42,代码来源:run.py

示例2: main

# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main():
    """Train a model from command-line arguments, optionally save and replay it.

    Options the common parser does not recognise are converted by
    ``parse_cmdline_kwargs`` and passed to ``train`` as extra arguments.

    Note:
        When ``args.play`` is set the playback loop has no exit condition; it
        only ends when an exception (for example ``KeyboardInterrupt``)
        propagates out of it. The playback environment is closed on the way
        out.
    """
    # configure logger, disable logging in child MPI processes (with rank > 0)
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    # The training environment is released before a fresh one is built for
    # playback below.
    env.close()

    # Only rank 0 writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        # Previously env.close() was placed after the infinite loop and could
        # never run; try/finally guarantees the environment is released.
        try:
            obs = env.reset()

            def initialize_placeholders(nlstm=128, **kwargs):
                # Zero state of shape (num_env or 1, 2 * nlstm) and a
                # one-element zero mask; nlstm is picked out of extra_args
                # when present, all other extra arguments are ignored.
                return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1))

            state, dones = initialize_placeholders(**extra_args)
            # NOTE(review): `dones` is never updated from `done`, so the mask
            # passed as M stays all-zero for the whole run — confirm intended.
            while True:
                actions, _, state, _ = model.step(obs, S=state, M=dones)
                obs, _, done, _ = env.step(actions)
                env.render()
                # env.step may return an array of done flags; reset if any is set.
                done = done.any() if isinstance(done, np.ndarray) else done

                if done:
                    obs = env.reset()
        finally:
            env.close()
开发者ID:hiwonjoon,项目名称:ICML2019-TREX,代码行数:40,代码来源:run.py

示例3: main

# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(args):
    """Parse ``args``, train a model, optionally save it and replay it.

    Args:
        args: Argument list handed to ``parse_known_args``. Options the common
            parser does not recognise are converted by
            ``parse_cmdline_kwargs`` and passed to ``train``.

    Returns:
        The trained model. This is only reached when ``args.play`` is falsy,
        because the playback loop has no exit condition.
    """
    parser = common_arg_parser()
    args, leftover = parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(leftover)

    if args.extra_import is not None:
        import_module(args.extra_import)

    # Rank 0 (or a run without MPI) uses the default logger configuration;
    # every other rank is configured with an empty list of output formats.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    model, env = train(args, extra_args)

    # Only rank 0 writes the model to disk.
    if args.save_path is not None and rank == 0:
        model.save(osp.expanduser(args.save_path))

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        # Models exposing `initial_state` are stepped with explicit state and
        # mask arguments; all others are stepped on the observation alone.
        state = getattr(model, 'initial_state', None)
        dones = np.zeros((1,))

        episode_rew = 0
        while True:
            if state is None:
                actions, _, _, _ = model.step(obs)
            else:
                actions, _, state, _ = model.step(obs, S=state, M=dones)

            obs, rew, done, _ = env.step(actions)
            # Only the first entry of the reward is accumulated.
            episode_rew += rew[0]
            env.render()

            # An array of done flags counts as "done" if any entry is set.
            if isinstance(done, np.ndarray):
                done = done.any()
            if done:
                print(f'episode_rew={episode_rew}')
                episode_rew = 0
                obs = env.reset()

    env.close()

    return model
开发者ID:ethz-asl,项目名称:reinmav-gym,代码行数:51,代码来源:run.py

示例4: main

# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(args):
    """Parse ``args``, train a model, optionally save it and replay it.

    Args:
        args: Argument list handed to ``parse_known_args``. Options the common
            parser does not recognise are converted by
            ``parse_cmdline_kwargs`` and passed to ``train``.

    Returns:
        The trained model. This is only reached when ``args.play`` is falsy,
        because the playback loop has no exit condition; it ends only when an
        exception (for example ``KeyboardInterrupt``) propagates out of it.
    """
    # configure logger, disable logging in child MPI processes (with rank > 0)
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    if args.extra_import is not None:
        import_module(args.extra_import)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    # The training environment is released before a fresh one is built for
    # playback below.
    env.close()

    # Only rank 0 writes the model to disk. The unused `tf.train.Saver()`
    # instance that used to be created here was never used to save anything
    # and has been dropped; model.save() is the only persistence step.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    # Test run on the learned model.
    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        # Previously env.close() was placed after the infinite loop and could
        # never run; try/finally guarantees the environment is released.
        try:
            obs = env.reset()

            # Models exposing `initial_state` are stepped with explicit state
            # and mask arguments; all others on the observation alone.
            state = model.initial_state if hasattr(model, 'initial_state') else None
            dones = np.zeros((1,))

            while True:
                if state is not None:
                    actions, _, state, _ = model.step(obs, S=state, M=dones)
                else:
                    actions, _, _, _ = model.step(obs)

                obs, _, done, _ = env.step(actions)
                env.render()
                # env.step may return an array of done flags; reset if any is set.
                done = done.any() if isinstance(done, np.ndarray) else done

                if done:
                    obs = env.reset()
        finally:
            env.close()

    return model
开发者ID:jiewwantan,项目名称:StarTrader,代码行数:58,代码来源:run.py

示例5: main

# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(args):
    """Parse ``args``, train a model, optionally save it and replay it.

    Args:
        args: Argument list handed to ``parse_known_args``. Options the common
            parser does not recognise are converted by
            ``parse_cmdline_kwargs`` and passed to ``train``.

    Returns:
        The trained model. This is only reached when ``args.play`` is falsy,
        because the playback loop has no exit condition.
    """
    parser = common_arg_parser()
    args, leftover = parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(leftover)

    # Rank 0 (or a run without MPI) logs with the default formats; every other
    # rank is configured with an empty list of output formats.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    if rank == 0:
        configure_logger(args.log_path)
    else:
        configure_logger(args.log_path, format_strs=[])

    model, env = train(args, extra_args)

    # Only rank 0 writes the model to disk.
    if args.save_path is not None and rank == 0:
        model.save(osp.expanduser(args.save_path))

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        # Models exposing `initial_state` are stepped with explicit state and
        # mask arguments; all others are stepped on the observation alone.
        state = getattr(model, 'initial_state', None)
        dones = np.zeros((1,))

        # One running return per sub-environment for a VecEnv, otherwise one.
        n_returns = env.num_envs if isinstance(env, VecEnv) else 1
        episode_rew = np.zeros(n_returns)
        while True:
            if state is None:
                actions, _, _, _ = model.step(obs)
            else:
                actions, _, state, _ = model.step(obs, S=state, M=dones)

            obs, rew, done, _ = env.step(actions)
            episode_rew += rew
            env.render()

            if isinstance(done, np.ndarray):
                done_any = done.any()
            else:
                done_any = done
            if done_any:
                # Report and clear the return of every index flagged as done.
                for i in np.nonzero(done)[0]:
                    print('episode_rew={}'.format(episode_rew[i]))
                    episode_rew[i] = 0

    env.close()

    return model
开发者ID:openai,项目名称:baselines,代码行数:48,代码来源:run.py

示例6: main

# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 别名]
def main(play_steps=686):
    """Train a model from command-line arguments and optionally replay it.

    Options the common parser does not recognise are converted by
    ``parse_cmdline_kwargs`` and passed to ``train`` as extra arguments.

    Args:
        play_steps: Number of environment steps to run on the test
            environment when ``args.play`` is set. The default, 686, is the
            value that used to be hard-coded as "the data length".

    Note:
        ``args.save_path`` is not honoured by this entry point: the trained
        model is never written to disk.
    """
    # configure logger, disable logging in child MPI processes (with rank > 0)
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])

    model, env = train(args, extra_args)
    # The training environment is released before the test environment is
    # built for playback below.
    env.close()

    if args.play:
        logger.log("Running trained model")
        env = build_testenv(args)
        obs = env.reset()

        # Fixed-length rollout: the done flag is ignored, the environment is
        # not rendered and it is never reset mid-run.
        for _step in range(play_steps):
            actions, _, _, _ = model.step(obs)
            obs, _, _, _ = env.step(actions)

        env.close()
开发者ID:hust512,项目名称:DQN-DDPG_Stock_Trading,代码行数:45,代码来源:run.py


注:本文中的baselines.common.cmd_util.common_arg_parser方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。