本文整理汇总了Python中baselines.common.cmd_util.atari_arg_parser方法的典型用法代码示例。如果您正苦于以下问题：Python cmd_util.atari_arg_parser方法的具体用法？Python cmd_util.atari_arg_parser怎么用？Python cmd_util.atari_arg_parser使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块baselines.common.cmd_util的用法示例。
在下文中一共展示了cmd_util.atari_arg_parser方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Parse the standard Atari CLI flags and launch training on 32 CPUs.

    Reads ``env``, ``seed`` and ``num_timesteps`` from the parsed arguments,
    calls ``logger.configure()`` with no arguments, then starts training.
    """
    opts = atari_arg_parser().parse_args()
    logger.configure()
    train(
        opts.env,
        num_timesteps=opts.num_timesteps,
        seed=opts.seed,
        num_cpu=32,
    )
示例2: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Parse the standard Atari CLI flags and hand them straight to ``train``.

    No logger configuration happens here (presumably done elsewhere -- TODO
    confirm).
    """
    parsed = atari_arg_parser().parse_args()
    env_id, steps, seed = parsed.env, parsed.num_timesteps, parsed.seed
    train(env_id, num_timesteps=steps, seed=seed)
示例3: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent with a selectable policy architecture.

    Adds ``--policy`` (one of cnn / lstm / lnlstm, default cnn) to the
    standard Atari parser, calls ``logger.configure()`` with no arguments,
    then passes env, timesteps, seed and policy to ``train``.
    """
    policy_choices = ['cnn', 'lstm', 'lnlstm']
    arg_parser = atari_arg_parser()
    arg_parser.add_argument(
        '--policy',
        help='Policy architecture',
        choices=policy_choices,
        default='cnn',
    )
    opts = arg_parser.parse_args()
    logger.configure()
    train(
        opts.env,
        num_timesteps=opts.num_timesteps,
        seed=opts.seed,
        policy=opts.policy,
    )
示例4: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent with configurable policy, LR schedule and log dir.

    Extends the standard Atari parser with ``--policy``, ``--lrschedule`` and
    ``--logdir``. The logger is pointed at ``--logdir``. That flag has no
    default, so argparse yields ``None`` when it is omitted. Training uses
    ``num_cpu=16``.
    """
    arg_parser = atari_arg_parser()
    arg_parser.add_argument(
        '--policy',
        help='Policy architecture',
        choices=['cnn', 'lstm', 'lnlstm'],
        default='cnn',
    )
    arg_parser.add_argument(
        '--lrschedule',
        help='Learning rate schedule',
        choices=['constant', 'linear'],
        default='constant',
    )
    arg_parser.add_argument('--logdir', help='Directory for logging')
    opts = arg_parser.parse_args()

    logger.configure(opts.logdir)
    train(
        opts.env,
        num_timesteps=opts.num_timesteps,
        seed=opts.seed,
        policy=opts.policy,
        lrschedule=opts.lrschedule,
        num_cpu=16,
    )
示例5: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent with a selectable policy and LR schedule.

    Adds ``--policy`` and ``--lrschedule`` to the standard Atari parser,
    calls ``logger.configure()`` with no arguments, and trains with
    ``num_env=16``.
    """
    # (flag, add_argument keyword options), registered in this order.
    extra_flags = (
        ('--policy', {'help': 'Policy architecture',
                      'choices': ['cnn', 'lstm', 'lnlstm'],
                      'default': 'cnn'}),
        ('--lrschedule', {'help': 'Learning rate schedule',
                          'choices': ['constant', 'linear'],
                          'default': 'constant'}),
    )
    parser = atari_arg_parser()
    for flag, options in extra_flags:
        parser.add_argument(flag, **options)
    parsed = parser.parse_args()

    logger.configure()
    train(parsed.env, num_timesteps=parsed.num_timesteps, seed=parsed.seed,
          policy=parsed.policy, lrschedule=parsed.lrschedule, num_env=16)
示例6: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent from a JSON hyper-parameter file.

    Most settings come from the JSON file named by ``--hparams_path``: env id,
    total timesteps, seed, num_env, checkpoint path, log location and the
    ``hparams`` dict itself. From the command line, ``--lrschedule`` is
    always used. ``--gpu_num`` sets ``CUDA_VISIBLE_DEVICES``; otherwise the
    JSON ``gpu_num`` key does. ``--policy`` is used only when the JSON has no
    ``policy`` key.

    Exits with an argparse usage error if ``--hparams_path`` is not given or
    ``--gpu_num`` is not an integer in [-1, 8].
    """
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant')
    parser.add_argument('--hparams_path', help='Load json hparams from this file', type=str, default='')
    parser.add_argument('--gpu_num', help='cuda gpu #', type=str, default='')
    args = parser.parse_args()

    # Without this check an omitted flag surfaces as an opaque
    # FileNotFoundError for the empty path ''.
    if not args.hparams_path:
        parser.error('--hparams_path is required')
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(args.hparams_path, 'r', encoding='utf-8') as f:
        hparams = json.load(f)

    if args.gpu_num:
        # Validate explicitly rather than with `assert`, which is stripped
        # under `python -O`. A non-numeric value also gets a usage error
        # instead of a ValueError traceback.
        try:
            gpu_index = int(args.gpu_num)
        except ValueError:
            gpu_index = None
        if gpu_index is None or not -1 <= gpu_index <= 8:
            parser.error('--gpu_num must be an integer in [-1, 8], got {!r}'.format(args.gpu_num))
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num
    elif 'gpu_num' in hparams:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(hparams['gpu_num'])

    log_path = os.path.join(hparams['base_dir'], 'logs', hparams['experiment_name'])
    print('experiment_params: {}'.format(hparams))
    print('chosen env: {}'.format(hparams['env_id']))
    # A missing or falsy 'atari_seed' (e.g. 0 or null) falls back to seed 0.
    seed = hparams.get('atari_seed') or 0
    logger.configure(dir=log_path)
    train(
        env_id=hparams['env_id'],
        num_timesteps=hparams['total_timesteps'],
        seed=seed,
        # The JSON value wins. --policy is the fallback, so the flag is no
        # longer silently ignored.
        policy=hparams.get('policy', args.policy),
        lrschedule=args.lrschedule,
        num_env=hparams['num_env'],
        ckpt_path=hparams['restore_from_ckpt_path'],
        hparams=hparams,
    )
示例7: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent from a JSON hyper-parameter file.

    The env id, policy, seed, log location and the ``hparams`` dict passed to
    ``train`` come from the JSON file named by ``--hparams_path``. The number
    of timesteps still comes from the parsed ``num_timesteps`` argument.
    ``--gpu_num`` sets ``CUDA_VISIBLE_DEVICES``; otherwise the JSON
    ``gpu_num`` key does.

    Exits with an argparse usage error if ``--hparams_path`` is not given or
    ``--gpu_num`` is not an integer in [-1, 8].
    """
    parser = atari_arg_parser()
    parser.add_argument('--hparams_path', help='Load json hparams from this file', type=str, default='')
    parser.add_argument('--gpu_num', help='cuda gpu #', type=str, default='')
    args = parser.parse_args()

    # Without this check an omitted flag surfaces as an opaque
    # FileNotFoundError for the empty path ''.
    if not args.hparams_path:
        parser.error('--hparams_path is required')
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(args.hparams_path, 'r', encoding='utf-8') as f:
        hparams = json.load(f)

    if args.gpu_num:
        # Validate explicitly rather than with `assert`, which is stripped
        # under `python -O`. A non-numeric value also gets a usage error
        # instead of a ValueError traceback.
        try:
            gpu_index = int(args.gpu_num)
        except ValueError:
            gpu_index = None
        if gpu_index is None or not -1 <= gpu_index <= 8:
            parser.error('--gpu_num must be an integer in [-1, 8], got {!r}'.format(args.gpu_num))
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num
    elif 'gpu_num' in hparams:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(hparams['gpu_num'])

    log_path = os.path.join(hparams['base_dir'], 'logs', hparams['experiment_name'])
    logger.configure(dir=log_path)
    print('experiment_params: {}'.format(hparams))
    print('chosen env: {}'.format(hparams['env_id']))
    # A missing or falsy 'atari_seed' (e.g. 0 or null) falls back to seed 0.
    seed = hparams.get('atari_seed') or 0
    train(hparams['env_id'], num_timesteps=args.num_timesteps, seed=seed,
          policy=hparams['policy'], hparams=hparams)
示例8: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent; here ``--policy`` also accepts ``mlp``.

    Adds ``--policy`` (cnn / lstm / lnlstm / mlp, default cnn) to the
    standard Atari parser, calls ``logger.configure()`` with no arguments,
    then calls ``train``.
    """
    architectures = ['cnn', 'lstm', 'lnlstm', 'mlp']
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture',
                        choices=architectures, default='cnn')
    ns = parser.parse_args()
    logger.configure()
    train(ns.env, num_timesteps=ns.num_timesteps, seed=ns.seed,
          policy=ns.policy)
示例9: main
# 需要导入模块: from baselines.common import cmd_util [as 别名]
# 或者: from baselines.common.cmd_util import atari_arg_parser [as 别名]
def main():
    """Train an Atari agent with self-imitation-learning (SIL) options.

    On top of the standard Atari parser this adds:

    * ``--policy``
    * ``--lrschedule``
    * ``--sil-update`` (int, default 4)
    * ``--sil-beta`` (float, default 0.1)
    * ``--log`` (default ``/tmp/a2c``)

    The logger writes to ``--log``, and training runs with ``num_env=16``.
    """
    # (flag, add_argument keyword options), registered in this order.
    extra_flags = [
        ('--policy', dict(help='Policy architecture',
                          choices=['cnn', 'lstm', 'lnlstm'], default='cnn')),
        ('--lrschedule', dict(help='Learning rate schedule',
                              choices=['constant', 'linear'], default='constant')),
        ('--sil-update', dict(type=int, default=4,
                              help="Number of updates per iteration")),
        ('--sil-beta', dict(type=float, default=0.1,
                            help="Beta for weighted IS")),
        ('--log', dict(default='/tmp/a2c')),
    ]
    parser = atari_arg_parser()
    for flag, options in extra_flags:
        parser.add_argument(flag, **options)
    opts = parser.parse_args()

    logger.configure(dir=opts.log)
    train(
        opts.env,
        num_timesteps=opts.num_timesteps,
        seed=opts.seed,
        policy=opts.policy,
        lrschedule=opts.lrschedule,
        sil_update=opts.sil_update,
        sil_beta=opts.sil_beta,
        num_env=16,
    )