当前位置: 首页>>代码示例>>Python>>正文


Python args.get_args方法代码示例

本文整理汇总了Python中args.get_args方法的典型用法代码示例。如果您正苦于以下问题:Python args.get_args方法的具体用法?Python args.get_args怎么用?Python args.get_args使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在args的用法示例。


在下文中一共展示了args.get_args方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args, "morph")

    morph(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:7,代码来源:morph.py

示例2: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    # command line args
    args = get_args()
    save_dir = os.path.join("checkpoints", args.log_name)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        os.makedirs(os.path.join(save_dir, 'images'))

    with open(os.path.join(save_dir, 'command.sh'), 'w') as f:
        f.write('python -X faulthandler ' + ' '.join(sys.argv))
        f.write('\n')

    if args.seed is None:
        args.seed = random.randint(0, 1000000)
    set_random_seed(args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    if args.sync_bn:
        assert args.distributed

    print("Arguments:")
    print(args)

    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(save_dir, ngpus_per_node, args))
    else:
        main_worker(args.gpu, save_dir, ngpus_per_node, args) 
开发者ID:stevenygd,项目名称:PointFlow,代码行数:37,代码来源:train.py

示例3: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    train(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:8,代码来源:train.py

示例4: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args, "generate")

    generate(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:7,代码来源:generate.py

示例5: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args, "match")

    match(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:7,代码来源:match.py

示例6: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args, "train")

    train(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:7,代码来源:train_with_mgpu.py

示例7: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args)
    train(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:6,代码来源:train_mgpu.py

示例8: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args, "generate")
    generate(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:6,代码来源:generate.py

示例9: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    train = True
    if args.data_type == "train":
        train = True
    elif args.data_type == "val":
        train = False

    prepare_pix2pix_dataset(args.dataset, train) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:11,代码来源:prepare_datasets.py

示例10: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    save_args(args, "generate")
    interpolate(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:6,代码来源:interpolate.py

示例11: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    # Context
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    train(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:11,代码来源:train.py

示例12: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    iterations = []
    mean_iou = []
    model_dir = args.model_load_path
    for filename in os.listdir(model_dir):
        args.model_load_path = model_dir+filename
        miou = eval.validate(args)
        iterations.append(filename.split('.')[0])
        mean_iou.append(miou)

    for i in range(len(iterations)):
        iterations[i] = iterations[i].replace('param_', '')

    itr = list(map(int, iterations))

    # Plot Iterations Vs mIOU
    plt.axes([0, max(itr), 0.0, 1.0])
    plt.xlabel('Iterations')
    plt.ylabel('Accuracy - mIOU')
    plt.scatter(itr, mean_iou)
    plt.show()

    print(iterations)
    print(mean_iou)
    with open('iterations.txt', 'w') as f:
        for item in iterations:
            f.write('%s\n' % item)
    with open('miou.txt', 'w') as f2:
        for item in mean_iou:
            f2.write('%s\n' % item)

    #plt.plot(iterations, mean_iou) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:43,代码来源:plot_accuracy.py

示例13: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    '''
    Arguments:
    train-file = txt file containing randomly selected image filenames to be taken as training set.
    val-file = txt file containing randomly selected image filenames to be taken as validation set.
    data-dir = dataset directory
    Usage: python dataset_utils.py --train-file="" --val-file="" --data_dir=""
    '''

    args = get_args()
    data_dir = args.data_dir

    generate_path_files(data_dir, args.train_file, args.val_file) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:15,代码来源:prepare_lfw_data.py

示例14: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    miou = validate(args) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:14,代码来源:eval.py

示例15: main

# 需要导入模块: import args [as 别名]
# 或者: from args import get_args [as 别名]
def main():
    ''' 
    Main

    Usage: python convert_tf_nnabla.py --input-ckpt-file=/path to ckpt file --output-nnabla-file=/output .h5 file

    '''

    # Parse the arguments
    args = get_args()

    # convert the input file(.ckpt) to the output file(.h5)
    convert(args.input_ckpt_file, args.output_nnabla_file) 
开发者ID:sony,项目名称:nnabla-examples,代码行数:15,代码来源:convert_tf_nnabla.py


注:本文中的args.get_args方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。