當前位置: 首頁>>代碼示例>>Python>>正文


Python cmd_util.common_arg_parser方法代碼示例

本文整理匯總了Python中baselines.common.cmd_util.common_arg_parser方法的典型用法代碼示例。如果您正苦於以下問題:Python cmd_util.common_arg_parser方法的具體用法?Python cmd_util.common_arg_parser怎麼用?Python cmd_util.common_arg_parser使用的例子?那麼, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在baselines.common.cmd_util的用法示例。


在下文中一共展示了cmd_util.common_arg_parser方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: main

# 需要導入模塊: from baselines.common import cmd_util [as 別名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 別名]
def main(custom_args=None):
    """Train a model, optionally save it, and optionally replay it forever.

    Command-line arguments are parsed with ``common_arg_parser()``; flags the
    parser does not recognise are discarded.

    Args:
        custom_args: Optional mapping of overrides. A key that matches an
            attribute of the parsed arguments replaces that attribute; every
            other key is forwarded to ``train`` as an extra argument.
            ``None`` (the default) or any empty container means no overrides.

    Note:
        When ``args.play`` is truthy the function enters an endless
        step/render loop and never returns.
    """
    arg_parser = common_arg_parser()
    args, _ = arg_parser.parse_known_args()

    # Build a fresh mapping per call instead of sharing a mutable default.
    overrides = custom_args if custom_args else {}

    # vars() exposes the namespace's own dict, so writes here update ``args``.
    parsed = vars(args)
    extra_args = {}
    for key, value in overrides.items():
        if key in parsed:
            parsed[key] = value
        else:
            extra_args[key] = value

    # NOTE(review): every rank gets the same stdout + tensorboard outputs;
    # logging is not silenced on ranks > 0 here.
    rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    logger.configure(format_strs=['stdout', 'tensorboard'])

    model, _ = train(args, extra_args)

    # Only the root process writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs, _, done, _ = env.step(actions)
            env.render()
            # ``done`` may be an array (one flag per env) or a plain bool.
            done = done.any() if isinstance(done, np.ndarray) else done

            if done:
                obs = env.reset()
開發者ID:MaxSobolMark,項目名稱:HardRLWithYoutube,代碼行數:42,代碼來源:run.py

示例2: main

# 需要導入模塊: from baselines.common import cmd_util [as 別名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 別名]
def main():
    """Train a model from command-line arguments, then optionally save and replay it.

    Arguments known to ``common_arg_parser()`` become ``args``; the remaining
    command-line tokens are converted by ``parse_cmdline_kwargs`` and passed
    to ``train`` as ``extra_args``.

    Note:
        When ``args.play`` is truthy the function runs an endless
        step/render loop. The playback environment is closed when that loop
        is left through an exception (e.g. ``KeyboardInterrupt``).
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    # Configure the logger; child MPI processes (rank > 0) get no outputs.
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    # Only the root process writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        try:
            obs = env.reset()

            # Zero-filled recurrent state and done-mask fed to model.step.
            # ``nlstm`` comes from the extra command-line kwargs, default 128.
            nlstm = extra_args.get('nlstm', 128)
            state = np.zeros((args.num_env or 1, 2 * nlstm))
            dones = np.zeros((1))

            while True:
                actions, _, state, _ = model.step(obs, S=state, M=dones)
                obs, _, done, _ = env.step(actions)
                env.render()
                # ``done`` may be an array (one flag per env) or a plain bool.
                done = done.any() if isinstance(done, np.ndarray) else done

                if done:
                    obs = env.reset()
        finally:
            # The loop above never exits normally, so this is the only place
            # the playback environment can be released.
            env.close()
開發者ID:hiwonjoon,項目名稱:ICML2019-TREX,代碼行數:40,代碼來源:run.py

示例3: main

# 需要導入模塊: from baselines.common import cmd_util [as 別名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 別名]
def main(args):
    """Parse ``args``, train a model, optionally save it, and optionally replay it.

    Args:
        args: Sequence of command-line tokens handed to
            ``common_arg_parser().parse_known_args``. Tokens the parser does
            not recognise are converted by ``parse_cmdline_kwargs`` and
            passed to ``train`` as extra arguments.

    Returns:
        The model returned by ``train``. When ``args.play`` is truthy the
        replay loop below never terminates, so nothing is returned.
    """
    parser = common_arg_parser()
    args, leftover = parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(leftover)

    module_name = args.extra_import
    if module_name is not None:
        import_module(module_name)

    # Configure the logger; child MPI processes (rank > 0) get no outputs.
    on_root = MPI is None or MPI.COMM_WORLD.Get_rank() == 0
    if on_root:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)

    # Only the root process writes the model to disk.
    if args.save_path is not None and rank == 0:
        target = osp.expanduser(args.save_path)
        model.save(target)

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        # Models without an ``initial_state`` attribute are stepped statelessly.
        state = getattr(model, 'initial_state', None)
        dones = np.zeros((1,))

        total_reward = 0
        while True:
            if state is None:
                actions, _, _, _ = model.step(obs)
            else:
                actions, _, state, _ = model.step(obs, S=state, M=dones)

            obs, rew, done, _ = env.step(actions)
            # Only the first entry of the reward is accumulated.
            total_reward += rew[0]
            env.render()

            # ``done`` may be an array (one flag per env) or a plain bool.
            if isinstance(done, np.ndarray):
                done = done.any()
            if done:
                print(f'episode_rew={total_reward}')
                total_reward = 0
                obs = env.reset()

    env.close()

    return model
開發者ID:ethz-asl,項目名稱:reinmav-gym,代碼行數:51,代碼來源:run.py

示例4: main

# 需要導入模塊: from baselines.common import cmd_util [as 別名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 別名]
def main(args):
    """Parse ``args``, train a model, optionally save it, and optionally replay it.

    Args:
        args: Sequence of command-line tokens handed to
            ``common_arg_parser().parse_known_args``. Tokens the parser does
            not recognise are converted by ``parse_cmdline_kwargs`` and
            passed to ``train`` as extra arguments.

    Returns:
        The model returned by ``train``. When ``args.play`` is truthy the
        replay loop never terminates normally, so nothing is returned.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    if args.extra_import is not None:
        import_module(args.extra_import)

    # Configure the logger; child MPI processes (rank > 0) get no outputs.
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    # Training phase; the training environment is released right away.
    model, env = train(args, extra_args)
    env.close()

    # Only the root process writes the model to disk.
    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    # Test run on the learned model, using a freshly built environment.
    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        try:
            obs = env.reset()

            # Models without ``initial_state`` are stepped statelessly.
            state = model.initial_state if hasattr(model, 'initial_state') else None
            dones = np.zeros((1,))

            while True:
                if state is not None:
                    actions, _, state, _ = model.step(obs, S=state, M=dones)
                else:
                    actions, _, _, _ = model.step(obs)

                obs, _, done, _ = env.step(actions)
                env.render()
                # ``done`` may be an array (one flag per env) or a plain bool.
                done = done.any() if isinstance(done, np.ndarray) else done

                if done:
                    obs = env.reset()
        finally:
            # The loop above never exits normally, so this is the only place
            # the playback environment can be released.
            env.close()

    return model
開發者ID:jiewwantan,項目名稱:StarTrader,代碼行數:58,代碼來源:run.py

示例5: main

# 需要導入模塊: from baselines.common import cmd_util [as 別名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 別名]
def main(args):
    """Parse ``args``, train a model, optionally save it, and optionally replay it.

    Args:
        args: Sequence of command-line tokens handed to
            ``common_arg_parser().parse_known_args``. Tokens the parser does
            not recognise are converted by ``parse_cmdline_kwargs`` and
            passed to ``train`` as extra arguments.

    Returns:
        The model returned by ``train``. When ``args.play`` is truthy the
        replay loop below never terminates, so nothing is returned.
    """
    parser = common_arg_parser()
    args, leftover = parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(leftover)

    # Configure the logger; child MPI processes (rank > 0) get no outputs.
    is_root = MPI is None or MPI.COMM_WORLD.Get_rank() == 0
    if is_root:
        rank = 0
        configure_logger(args.log_path)
    else:
        rank = MPI.COMM_WORLD.Get_rank()
        configure_logger(args.log_path, format_strs=[])

    model, env = train(args, extra_args)

    # Only the root process writes the model to disk.
    if args.save_path is not None and rank == 0:
        destination = osp.expanduser(args.save_path)
        model.save(destination)

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        # Models without an ``initial_state`` attribute are stepped statelessly.
        state = getattr(model, 'initial_state', None)
        dones = np.zeros((1,))

        # One running reward total per environment (a single slot otherwise).
        slots = env.num_envs if isinstance(env, VecEnv) else 1
        episode_rew = np.zeros(slots)
        while True:
            if state is None:
                actions, _, _, _ = model.step(obs)
            else:
                actions, _, state, _ = model.step(obs, S=state, M=dones)

            obs, rew, done, _ = env.step(actions)
            episode_rew += rew
            env.render()

            if isinstance(done, np.ndarray):
                done_any = done.any()
            else:
                done_any = done
            if done_any:
                # Report and clear the total of each environment that finished.
                for idx in np.nonzero(done)[0]:
                    print('episode_rew={}'.format(episode_rew[idx]))
                    episode_rew[idx] = 0

    env.close()

    return model
開發者ID:openai,項目名稱:baselines,代碼行數:48,代碼來源:run.py

示例6: main

# 需要導入模塊: from baselines.common import cmd_util [as 別名]
# 或者: from baselines.common.cmd_util import common_arg_parser [as 別名]
def main(num_test_steps=686):
    """Train a model from command-line arguments, then optionally roll it out.

    Arguments known to ``common_arg_parser()`` become ``args``; the remaining
    command-line tokens are converted by ``parse_cmdline_kwargs`` and passed
    to ``train`` as ``extra_args``.

    Args:
        num_test_steps: Number of steps to run in the test environment when
            ``args.play`` is truthy. Defaults to 686, the data length that
            was previously hard-coded in the loop.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    # Configure the logger; child MPI processes (rank > 0) get no outputs.
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    # NOTE: the model is not saved here; args.save_path is ignored.

    if args.play:
        logger.log("Running trained model")
        env = build_testenv(args)
        obs = env.reset()

        # Fixed-length rollout: no rendering, and the environment is not
        # reset when an episode reports done.
        for _ in range(num_test_steps):
            actions, _, _, _ = model.step(obs)
            obs, _, _, _ = env.step(actions)

        env.close()
開發者ID:hust512,項目名稱:DQN-DDPG_Stock_Trading,代碼行數:45,代碼來源:run.py


注:本文中的baselines.common.cmd_util.common_arg_parser方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。