

Python models.load_model Method Code Examples

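This article collects typical usage examples of the models.load_model method in Python. If you are wondering how models.load_model is used in practice, or are looking for working examples of it, the curated examples here may help. You can also explore further usage examples of the models module, where the method is defined.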


Four code examples of the models.load_model method are shown below, sorted by popularity by default.

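Example 1: main

# Required import: import models [as alias]
# Or: from models import load_model [as alias]
import sys
from subprocess import call

from models import load_model
# The remaining names (load_demo_images, download_model, DEFAULT_WEIGHTS,
# Solver, voxel2obj, cfg, cmd_exists) come from the project itself.
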
def main():
    '''Main demo function'''
    # Save prediction into a file named 'prediction.obj' or the given argument
    pred_file_name = sys.argv[1] if len(sys.argv) > 1 else 'prediction.obj'

    # load images
    demo_imgs = load_demo_images()

    # Download and load pretrained weights
    download_model(DEFAULT_WEIGHTS)

    # Use the default network model
    NetClass = load_model('ResidualGRUNet')

    # Define a network and a solver. Solver provides a wrapper for the test function.
    net = NetClass(compute_grad=False)  # instantiate a network
    net.load(DEFAULT_WEIGHTS)           # load downloaded weights
    solver = Solver(net)                # instantiate a solver

    # Run the network
    voxel_prediction, _ = solver.test_output(demo_imgs)

    # Save the prediction to an OBJ file (mesh file).
    voxel2obj(pred_file_name, voxel_prediction[0, :, 1, :, :] > cfg.TEST.VOXEL_THRESH)

    # Use meshlab or other mesh viewers to visualize the prediction.
    # For Ubuntu>=14.04, you can install meshlab using
    # `sudo apt-get install meshlab`
    if cmd_exists('meshlab'):
        call(['meshlab', pred_file_name])
    else:
        print('Meshlab not found: please use visualization of your choice to view %s' %
              pred_file_name) 
Developer ID: chrischoy, Project: 3D-R2N2, Lines of code: 35, Source file: demo.py
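
The examples on this page call load_model with a class name and then instantiate the returned class. A minimal sketch of that pattern follows. It assumes load_model is a name-to-class lookup; the registry `_MODELS` and the empty stand-in classes are hypothetical and are not taken from the 3D-R2N2 source.

```python
# Hypothetical stand-ins for the network classes a project might define.
class GRUNet:
    pass


class ResidualGRUNet:
    pass


# Assumed registry mapping class names to classes (illustration only).
_MODELS = {cls.__name__: cls for cls in (GRUNet, ResidualGRUNet)}


def load_model(name):
    """Return the network class registered under `name`."""
    try:
        return _MODELS[name]
    except KeyError:
        options = ', '.join(sorted(_MODELS))
        raise ValueError('Unknown model %r; options are: %s' % (name, options)) from None


# The caller receives a class, not an instance, and instantiates it itself.
NetClass = load_model('ResidualGRUNet')
net = NetClass()
```

Returning the class rather than an instance lets each caller choose its own constructor arguments, as Example 1 does with compute_grad=False.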

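Example 2: main

# Required import: import models [as alias]
# Or: from models import load_model [as alias]
import sys
from subprocess import call

import torch

from models import load_model
# The remaining names (load_demo_images, DEFAULT_WEIGHTS, Solver, voxel2obj,
# cfg, cmd_exists) come from the project itself.
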
def main():
    '''Main demo function'''
    # Save prediction into a file named 'prediction.obj' or the given argument
    pred_file_name = sys.argv[1] if len(sys.argv) > 1 else 'prediction.obj'

    # load images
    demo_imgs = load_demo_images()

    # Use the default network model
    NetClass = load_model('ResidualGRUNet')

    # Define a network and a solver. Solver provides a wrapper for the test function.
    net = NetClass()  # instantiate a network
    if torch.cuda.is_available():
        net.cuda()

    net.eval()

    solver = Solver(net)                # instantiate a solver
    solver.load(DEFAULT_WEIGHTS)

    # Run the network
    voxel_prediction, _ = solver.test_output(demo_imgs)
    voxel_prediction = voxel_prediction.detach().cpu().numpy()

    # Save the prediction to an OBJ file (mesh file).
    voxel2obj(pred_file_name, voxel_prediction[0, 1] > cfg.TEST.VOXEL_THRESH)

    # Use meshlab or other mesh viewers to visualize the prediction.
    # For Ubuntu>=14.04, you can install meshlab using
    # `sudo apt-get install meshlab`
    if cmd_exists('meshlab'):
        call(['meshlab', pred_file_name])
    else:
        print('Meshlab not found: please use visualization of your choice to view %s' %
              pred_file_name) 
Developer ID: heromanba, Project: 3D-R2N2-PyTorch, Lines of code: 38, Source file: demo.py
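
Both demo functions only launch meshlab when cmd_exists reports that it is installed. The project's own cmd_exists is not shown on this page, so the version below is an assumption: one possible way to write such a check with the standard library's shutil.which. The helper viewer_command is likewise hypothetical.

```python
import shutil


def cmd_exists(cmd):
    """Return True if `cmd` resolves to an executable on the PATH."""
    return shutil.which(cmd) is not None


def viewer_command(mesh_path, viewer='meshlab'):
    """Return the argv list for opening `mesh_path`, or None if the viewer is missing."""
    if cmd_exists(viewer):
        return [viewer, mesh_path]
    return None


command = viewer_command('prediction.obj')
if command is None:
    print('Meshlab not found: please use a viewer of your choice to open prediction.obj')
```

Building the argument list separately from running it keeps the check testable without starting an external program; the list can be passed to subprocess.call when it is not None.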

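Example 3: train_net

# Required import: import models [as alias]
# Or: from models import load_model [as alias]
import inspect

from models import load_model
# The remaining names (cfg, Solver, make_data_processes, kill_processes,
# category_model_id_pair) come from the project itself. Queue is presumably
# multiprocessing.Queue, since it feeds worker processes.
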
def train_net():
    '''Main training function'''
    # Set up the model and the solver
    NetClass = load_model(cfg.CONST.NETWORK_CLASS)
    print('Network definition: \n')
    print(inspect.getsource(NetClass.network_definition))
    net = NetClass()

    # Check that single view reconstruction net is not used for multi view
    # reconstruction.
    if net.is_x_tensor4 and cfg.CONST.N_VIEWS > 1:
        # The trailing space matters: adjacent string literals are joined
        # with no separator.
        raise ValueError('Do not set the config.CONST.N_VIEWS > 1 when using '
                         'single-view reconstruction network')

    # Generate the solver
    solver = Solver(net)

    # Prefetching data processes
    #
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    global train_queue, val_queue, train_processes, val_processes
    train_queue = Queue(cfg.QUEUE_SIZE)
    val_queue = Queue(cfg.QUEUE_SIZE)

    train_processes = make_data_processes(
        train_queue,
        category_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION),
        cfg.TRAIN.NUM_WORKER,
        repeat=True)
    val_processes = make_data_processes(
        val_queue,
        category_model_id_pair(dataset_portion=cfg.TEST.DATASET_PORTION),
        1,
        repeat=True,
        train=False)

    # Train the network
    solver.train(train_queue, val_queue)

    # Cleanup the processes and the queue.
    kill_processes(train_queue, train_processes)
    kill_processes(val_queue, val_processes) 
Developer ID: chrischoy, Project: 3D-R2N2, Lines of code: 46, Source file: train_net.py
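
The guard against pairing a single-view network with a multi-view configuration can be isolated into a small helper. The sketch below uses a hypothetical function check_view_config; treating is_x_tensor4 as a "single-view network" flag is inferred from the comment in the example, not confirmed by the project. It also shows that adjacent string literals concatenate with no separator, so the message needs an explicit space at the join.

```python
def check_view_config(is_single_view_net, n_views):
    """Raise ValueError if a single-view network is configured with several views."""
    if is_single_view_net and n_views > 1:
        # Adjacent literals are concatenated as-is, hence the trailing space.
        raise ValueError('Do not set N_VIEWS > 1 when using '
                         'single-view reconstruction network')


# Valid combinations pass silently.
check_view_config(is_single_view_net=False, n_views=5)
check_view_config(is_single_view_net=True, n_views=1)

# The invalid combination is reported with a readable message.
try:
    check_view_config(is_single_view_net=True, n_views=2)
except ValueError as err:
    message = str(err)
```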

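Example 4: train_net

# Required import: import models [as alias]
# Or: from models import load_model [as alias]
from torch.utils.data import DataLoader

from models import load_model
# The remaining names (cfg, Solver, ShapeNetDataset, ShapeNetCollateFn) come
# from the project itself.
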
def train_net():
    '''Main training function'''
    # Set up the model and the solver
    NetClass = load_model(cfg.CONST.NETWORK_CLASS)

    net = NetClass()
    print('\nNetwork definition: ')
    print(net)

    # Check that single view reconstruction net is not used for multi view
    # reconstruction.
    if net.is_x_tensor4 and cfg.CONST.N_VIEWS > 1:
        # The trailing space matters: adjacent string literals are joined
        # with no separator.
        raise ValueError('Do not set the config.CONST.N_VIEWS > 1 when using '
                         'single-view reconstruction network')

    # Data loading
    #
    # Build DataLoaders for the training and validation sets. Training uses
    # cfg.TRAIN.NUM_WORKER worker processes to speed up loading; validation
    # uses a single worker.

    train_dataset = ShapeNetDataset(cfg.TRAIN.DATASET_PORTION)
    train_collate_fn = ShapeNetCollateFn()
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=cfg.CONST.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.TRAIN.NUM_WORKER,
        collate_fn=train_collate_fn,
        pin_memory=True
    )

    val_dataset = ShapeNetDataset(cfg.TEST.DATASET_PORTION)
    val_collate_fn = ShapeNetCollateFn(train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=cfg.CONST.BATCH_SIZE,
        shuffle=True,
        num_workers=1,
        collate_fn=val_collate_fn,
        pin_memory=True
    )

    net.cuda()

    # Generate the solver
    solver = Solver(net)

    # Train the network
    solver.train(train_loader, val_loader) 
Developer ID: heromanba, Project: 3D-R2N2-PyTorch, Lines of code: 52, Source file: train_net.py


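Note: The models.load_model examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright in the source code remains with the original authors. Refer to each project's License before distributing or using the code. Do not reproduce this article without permission.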