当前位置: 首页>>代码示例>>Python>>正文


Python utils.parse_args方法代码示例

本文整理汇总了Python中utils.parse_args方法的典型用法代码示例。如果您想了解Python utils.parse_args方法的具体用法、调用方式或使用示例，那么这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils的用法示例。


在下文中一共展示了utils.parse_args方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# Requires: import utils [as alias]
# or: from utils import parse_args [as alias]
def main():
    """Train and/or test a MobileNetV2 model as described by the parsed config.

    Steps: read the configuration with ``parse_args()``, create the
    experiment directories (storing the summary and checkpoint paths on the
    config), build the model, optionally move it to CUDA and enable cudnn,
    load the data through ``CIFAR10Data``, and then train and/or test
    according to ``config_args.to_train`` / ``config_args.to_test``.

    Pressing Ctrl-C during training stops training early. Execution then
    continues with the test phase (if enabled) instead of aborting the script.
    """
    # Parse the JSON arguments
    config_args = parse_args()

    # Create the experiment directories
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    model = MobileNetV2(config_args)

    if config_args.cuda:
        model.cuda()
        cudnn.enabled = True
        # Let cudnn pick the fastest algorithms for the (fixed) input sizes.
        cudnn.benchmark = True

    print("Loading Data...")
    data = CIFAR10Data(config_args)
    print("Data loaded successfully\n")

    trainer = Train(model, data.trainloader, data.testloader, config_args)

    if config_args.to_train:
        try:
            print("Training...")
            trainer.train()
            print("Training Finished\n")
        except KeyboardInterrupt:
            # Ctrl-C is a deliberate "stop training now": report it instead of
            # swallowing it silently, then fall through to the test phase.
            print("Training interrupted by user\n")

    if config_args.to_test:
        print("Testing...")
        trainer.test(data.testloader)
        print("Testing Finished\n")
开发者ID:MG2033,项目名称:MobileNet-V2,代码行数:35,代码来源:main.py

示例2: main

# Requires: import utils [as alias]
# or: from utils import parse_args [as alias]
def main():
    """Parse the command-line arguments and start training.

    The parsed arguments are also published as the module-level global
    ``args``. Other code in this module may read it, although that code is
    not visible here.
    """
    global args
    parsed_args = parse_args()
    args = parsed_args
    train_net(parsed_args)
开发者ID:foamliu,项目名称:InsightFace-PyTorch,代码行数:6,代码来源:train.py

示例3: main

# Requires: import utils [as alias]
# or: from utils import parse_args [as alias]
def main(_):
    """Parse arguments, then train or test the STPN model in a TF session.

    In ``'train'`` mode the RGB stream is trained first and then the FLOW
    stream. All global variables are re-initialized before each stream. In
    ``'test'`` mode the variables are initialized and ``test`` is run.

    Args:
        _: Ignored positional argument. Presumably this is the argv passed in
            by ``tf.app.run``; TODO confirm.

    Raises:
        ValueError: If ``args.mode`` is neither ``'train'`` nor ``'test'``.
    """
    # Parsing Arguments
    args = utils.parse_args()
    mode = args.mode
    # Fail fast on an unsupported mode. Previously such a value built the
    # graph, opened a session and then silently did nothing.
    if mode not in ('train', 'test'):
        raise ValueError(
            "Unsupported mode {!r}: expected 'train' or 'test'".format(mode))

    train_iter = args.training_num
    test_iter = args.test_iter
    ckpt = utils.ckpt_path(args.ckpt)
    input_list = {
        'batch_size': args.batch_size,
        'beta': args.beta,
        'learning_rate': args.learning_rate,
        'ckpt': ckpt,
        'class_threshold': args.class_th,
        'scale': args.scale}

    config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    config.gpu_options.allow_growth = True

    tf.reset_default_graph()
    model = StpnModel()

    # Run Model
    with tf.Session(config=config) as sess:
        init = tf.global_variables_initializer()
        if mode == 'train':
            sess.run(init)
            train(sess, model, input_list, 'rgb', train_iter)  # Train RGB stream
            sess.run(init)
            train(sess, model, input_list, 'flow', train_iter)  # Train FLOW stream
        else:  # mode == 'test' (validated above)
            sess.run(init)
            test(sess, model, init, input_list, test_iter)  # Test
开发者ID:bellos1203,项目名称:STPN,代码行数:35,代码来源:stpn_main.py

示例4: main

# Requires: import utils [as alias]
# or: from utils import parse_args [as alias]
def main():
    """Train and/or test the TensorFlow MobileNet model from a JSON config.

    Steps:

    1. Parse the configuration with ``parse_args()``.
    2. Create the experiment directories.
    3. Reset the default TensorFlow graph and open a session.
    4. Load the data. This also stores image geometry and dataset sizes on
       ``config_args``.
    5. Build the model, summarizer and trainer.
    6. Train and/or test according to ``config_args.to_train`` /
       ``config_args.to_test``.

    If training is interrupted with Ctrl-C, the model is saved and execution
    continues with testing. The session is always closed before returning.

    Exits with status 1 (message on stderr) if ``parse_args()`` raises an
    ordinary exception.
    """
    import sys

    # Parse the JSON arguments. Catch only ``Exception``. A bare ``except``
    # would also swallow SystemExit raised inside parse_args (turning any
    # deliberate exit, even a successful one, into status 1) and Ctrl-C.
    try:
        config_args = parse_args()
    except Exception:
        print("Add a config file using \'--config file_name.json\'", file=sys.stderr)
        sys.exit(1)

    # Create the experiment directories
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    # Reset the default Tensorflow graph
    tf.reset_default_graph()

    # Tensorflow specific configuration
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        # Data loading. load_data() returns image geometry and dataset sizes.
        # They are stored on config_args before the model is built,
        # presumably because MobileNet(config_args) reads them; TODO confirm.
        data = DataLoader(config_args.batch_size, config_args.shuffle)
        print("Loading Data...")
        (config_args.img_height, config_args.img_width, config_args.num_channels,
         config_args.train_data_size, config_args.test_data_size) = data.load_data()
        print("Data loaded\n\n")

        # Model creation
        print("Building the model...")
        model = MobileNet(config_args)
        print("Model is built successfully\n\n")

        # Summarizer creation
        summarizer = Summarizer(sess, config_args.summary_dir)
        # Train class
        trainer = Train(sess, model, data, summarizer)

        if config_args.to_train:
            try:
                print("Training...")
                trainer.train()
                print("Training Finished\n\n")
            except KeyboardInterrupt:
                # Ctrl-C stops training early but keeps the progress made so far.
                trainer.save_model()

        if config_args.to_test:
            print("Final test!")
            trainer.test('val')
            print("Testing Finished\n\n")
    finally:
        # Release the session's resources even if loading/training/testing fails.
        sess.close()
开发者ID:MG2033,项目名称:MobileNet,代码行数:50,代码来源:main.py

示例5: main

# Requires: import utils [as alias]
# or: from utils import parse_args [as alias]
def main():
    """Train or test the TensorFlow ShuffleNet model from a JSON config.

    ``config_args.train_or_test`` selects the mode:

    - ``'train'``: the data loader uses the configured batch size and the
      model is trained. On Ctrl-C the model is saved.
    - ``'test'``: the data loader uses batch size 1, to simulate the real
      experiment, and the model is evaluated on the ``'val'`` split.

    The mode is validated right after parsing, so an invalid value fails
    before any directories are created or any data/model is loaded. The
    session is always closed before returning.

    Raises:
        ValueError: If ``config_args.train_or_test`` is neither ``'train'``
            nor ``'test'``.
    """
    # Parse the JSON arguments
    config_args = parse_args()

    # Validate the mode before doing any expensive or side-effecting work.
    mode = config_args.train_or_test
    if mode not in ('train', 'test'):
        raise ValueError("Train or Test options only are allowed")

    # Create the experiment directories
    _, config_args.summary_dir, config_args.checkpoint_dir = create_experiment_dirs(
        config_args.experiment_dir)

    # Reset the default Tensorflow graph
    tf.reset_default_graph()

    # Tensorflow specific configuration
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        # Data loading
        # The batch size is equal to 1 when testing to simulate the real experiment.
        data_batch_size = config_args.batch_size if mode == "train" else 1
        data = DataLoader(data_batch_size, config_args.shuffle)
        print("Loading Data...")
        (config_args.img_height, config_args.img_width, config_args.num_channels,
         config_args.train_data_size, config_args.test_data_size) = data.load_data()
        print("Data loaded\n\n")

        # Model creation
        print("Building the model...")
        model = ShuffleNet(config_args)
        print("Model is built successfully\n\n")

        # Parameters visualization
        show_parameters()

        # Summarizer creation
        summarizer = Summarizer(sess, config_args.summary_dir)
        # Train class
        trainer = Train(sess, model, data, summarizer)

        # NOTE: calculate_flops() can be called at the start of either branch
        # to report FLOPs (per training batch, or per single inference).
        if mode == 'train':
            try:
                print("Training...")
                trainer.train()
                print("Training Finished\n\n")
            except KeyboardInterrupt:
                # Ctrl-C stops training early but keeps the progress made so far.
                trainer.save_model()
        else:  # mode == 'test' (validated above)
            # This can be 'val' or 'test' or even 'train' according to the needs.
            print("Testing...")
            trainer.test('val')
            print("Testing Finished\n\n")
    finally:
        # Release the session's resources even if loading/training/testing fails.
        sess.close()
开发者ID:MG2033,项目名称:ShuffleNet,代码行数:59,代码来源:main.py


注:本文中的utils.parse_args方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。