

Python comm.get_rank Method Code Examples

This article collects typical usage examples of the Python method maskrcnn_benchmark.utils.comm.get_rank. If you are wondering what exactly comm.get_rank does, how to call it, or what real-world uses look like, the curated examples below should help. You can also explore further usage examples from the containing module, maskrcnn_benchmark.utils.comm.


The following presents 15 code examples of the comm.get_rank method, ordered by popularity by default.
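Before the examples, here is the pattern nearly all of them share: comm.get_rank() returns the rank of the current process in the default torch.distributed process group, and 0 when distributed training has not been initialized, so side effects such as logging and checkpoint writing are typically gated on rank 0. The helper below is a minimal illustrative sketch and does not appear in maskrcnn_benchmark itself:

import logging

from maskrcnn_benchmark.utils.comm import get_rank

def log_on_main_process(message):
    # get_rank() is 0 for the main process, and also 0 when
    # torch.distributed is unavailable or not initialized, so this
    # works for both single-GPU and multi-GPU runs.
    if get_rank() == 0:
        logging.getLogger("maskrcnn_benchmark").info(message)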

Example 1: train

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model 
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 50, Source file: train_net.py
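Design note: in this example every process builds a DetectronCheckpointer and loads cfg.MODEL.WEIGHT, but save_to_disk = get_rank() == 0 means only the main process actually writes checkpoints to OUTPUT_DIR, avoiding concurrent writes when several GPUs train the same model. The same pattern recurs in the variants below.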

Example 2: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed) 
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines: 62, Source file: train_net.py
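Example 2 also shows the usual bootstrap order: read WORLD_SIZE, bind the process to the GPU given by --local_rank (supplied by the launcher), initialize the NCCL process group, and then pass get_rank() to setup_logger so that non-zero ranks stay quiet. For reference, get_rank in maskrcnn_benchmark.utils.comm is essentially the thin wrapper sketched below (paraphrased, not copied verbatim from the repository):

import torch.distributed as dist

def get_rank():
    # fall back to rank 0 when torch.distributed is unavailable
    # or the default process group has not been initialized
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()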

Example 3: train

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)
    arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model 
Developer: Xiangyu-CAS, Project: R2CNN.pytorch, Lines: 51, Source file: train_net.py

Example 4: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="../configs/e2e_r2cnn_R_50_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        default=True,
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed) 
Developer: Xiangyu-CAS, Project: R2CNN.pytorch, Lines: 63, Source file: train_net.py
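Note: in this variant --skip-test is declared with both default=True and action="store_true", so args.skip_test is True whether or not the flag is passed and run_test is effectively never reached; the other main() variants on this page rely on the usual store_true default of False.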

Example 5: fitness

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def fitness(gpu, ngpus_per_node, cfg, args, rngs, salt, conn):
    num_gpus = int(os.environ["WORLD_SIZE"]) \
        if "WORLD_SIZE" in os.environ else 1
    args["distributed"] = num_gpus > 1

    args["local_rank"] = gpu

    if args["distributed"]:
        torch.cuda.set_device(args["local_rank"])
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
            world_size=num_gpus, rank=args["local_rank"]
        )

    model = GeneralizedRCNN(cfg)
    model.to(cfg.MODEL.DEVICE)

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer2(
        cfg, model, save_dir=cfg.OUTPUT_DIR, save_to_disk=save_to_disk)
    extra_checkpoint_data = checkpointer.load(os.path.join(cfg.OUTPUT_DIR, salt+".pth"))

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    output_folders = [None] * len(cfg.DATASETS.TEST)
    if cfg.OUTPUT_DIR:
        dataset_names = cfg.DATASETS.TEST
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(
                cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    data_loaders_val = make_data_loader(
        cfg, is_train=False, is_distributed=args["distributed"])
    for output_folder, data_loader_val in zip(output_folders, data_loaders_val):
        results = inference(
            model,
            rngs,
            data_loader_val,
            iou_types=iou_types,
            box_only=False,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()

    if get_rank() == 0:
        conn.send(results.results['bbox']['AP'])
        conn.close() 
Developer: megvii-model, Project: DetNAS, Lines: 55, Source file: test_server.py
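Example 5 evaluates one candidate architecture and reports its bbox AP to a parent process through a pipe; gating conn.send on get_rank() == 0 guarantees exactly one result is sent per evaluation. A sketch of how such a worker might be launched is shown below; the driver function, the spawn wiring, and the environment assumptions are illustrative and are not taken from DetNAS:

import torch
import torch.multiprocessing as mp

def evaluate_candidate(cfg, args, rngs, salt):
    # hypothetical driver: one worker process per visible GPU;
    # assumes MASTER_ADDR, MASTER_PORT and WORLD_SIZE are already
    # set in the environment for init_method="env://"
    ngpus_per_node = torch.cuda.device_count()
    parent_conn, child_conn = mp.Pipe()
    # mp.spawn passes the process index as the first argument (gpu),
    # followed by the entries of args; only rank 0 sends the result
    mp.spawn(
        fitness,
        args=(ngpus_per_node, cfg, args, rngs, salt, child_conn),
        nprocs=ngpus_per_node,
    )
    return parent_conn.recv()  # bbox AP sent by rank 0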

Example 6: train

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )



    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(  # clw note: create the dataset
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model 
Developer: clw5180, Project: remote_sensing_object_detection_2019, Lines: 53, Source file: train_net.py

Example 7: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="../configs/rrpn/e2e_rrpn_X_101_32x8d_FPN_1x_DOTA.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        test(cfg, model, args.distributed) 
Developer: clw5180, Project: remote_sensing_object_detection_2019, Lines: 62, Source file: train_net.py

Example 8: train

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if cfg.SOLVER.ENABLE_FP16:
        model.half()

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            # broadcast_buffers=False,
        )
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    print(cfg.MODEL.WEIGHT)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        cfg
    )

    return model 
Developer: HRNet, Project: HRNet-MaskRCNN-Benchmark, Lines: 53, Source file: train_net.py

Example 9: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="configs/free_anchor_R-50-FPN_8gpu_1x.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank % torch.cuda.device_count())
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        test(cfg, model, args.distributed) 
Developer: zhangxiaosong18, Project: FreeAnchor, Lines: 61, Source file: train_net.py

Example 10: train

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.deprecated.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model 
Developer: chengyangfu, Project: retinamask, Lines: 50, Source file: train_net.py

Example 11: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.deprecated.init_process_group(
            backend="nccl", init_method="env://"
        )

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        test(cfg, model, args.distributed) 
Developer: chengyangfu, Project: retinamask, Lines: 61, Source file: train_net.py
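Examples 10 and 11 come from an older codebase: torch.nn.parallel.deprecated.DistributedDataParallel and torch.distributed.deprecated were removed in later PyTorch releases, so on a current install the non-deprecated equivalents used by the other examples apply; the get_rank() usage itself is unchanged.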

Example 12: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        test(cfg, model, args.distributed) 
Developer: mlperf, Project: training, Lines: 62, Source file: train_net.py

Example 13: __init__

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def __init__(
        self, ann_file, root, mode, resolution, backbone, post_branch, 
        remove_images_without_annotations, transforms=None,
        special_deal=False
    ):
        super(COCODataset, self).__init__(root, ann_file)
        self.mode = mode
        self.rank = get_rank()
        self.backbone = backbone
        self.transforms = transforms
        self.resolution = resolution
        self.post_branch = post_branch
        self.special_deal = special_deal
        self._img_search_set = os.path.join(ann_file, "%s.txt")

        self._split_feature = os.path.join(ann_file, "features_%s_%s", "%s.pth")
        self._split_target  = os.path.join(ann_file, self.post_branch, "targets_%s_%s", "%s.pth")
        self._split_label   = os.path.join(ann_file, self.post_branch, "labels_%s_%s", "%s.pth")
        self._split_reg     = os.path.join(ann_file, self.post_branch, "regs_%s_%s", "%s.pth")

        if self.mode == 0:
            # sort indices for reproducible results
            self.ids = sorted(self.ids)

            # filter images without detection annotations
            if remove_images_without_annotations:
                self.ids = [
                    img_id
                    for img_id in self.ids
                    if len(self.coco.getAnnIds(imgIds=img_id, iscrowd=None)) > 0
                ]

            self.json_category_id_to_contiguous_id = {
                v: i + 1 for i, v in enumerate(self.coco.getCatIds())
            }
            self.contiguous_category_id_to_json_id = {
                v: k for k, v in self.json_category_id_to_contiguous_id.items()
            }
            self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}

        elif self.mode == 1:
            try:
                image_set_name = "coco_%s_%s_%s_search" % (
                    self.backbone, self.resolution, self.post_branch
                )
                with open(self._img_search_set % image_set_name) as f:
                    self.ids = f.readlines()
                self.ids = [x.strip("\n") for x in self.ids]  # strip trailing newlines from each image id
            except FileNotFoundError:
                raise ValueError("Can not open file {}.txt".format(image_set_name))

        else:
            raise ValueError("Mode {} is not available for COCO Dataset".format(self.mode)) 
Developer: Lausannen, Project: NAS-FCOS, Lines: 55, Source file: coco.py

Example 14: main

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def main():  
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")  
    parser.add_argument(  
        "--config-file",  
        default="",  
        metavar="FILE",  
        help="path to config file",  
        type=str,  
    )  
    parser.add_argument("--local_rank", type=int, default=0)  
    parser.add_argument(  
        "--skip-test",  
        dest="skip_test",  
        help="Do not test the final model",  
        action="store_true",  
    )  
    parser.add_argument(  
        "opts",  
        help="Modify config options using the command-line",  
        default=None,  
        nargs=argparse.REMAINDER,  
    )  

    args = parser.parse_args()  

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1  
    args.distributed = num_gpus > 1  

    if args.distributed:  
        torch.cuda.set_device(args.local_rank)  
        torch.distributed.init_process_group(  
            backend="nccl", init_method="env://"  
        )  
        synchronize()  

    cfg.merge_from_file(args.config_file)  
    cfg.merge_from_list(args.opts)  
    cfg.freeze()  

    output_dir = cfg.OUTPUT_DIR  
    if output_dir:  
        mkdir(output_dir)  

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())  
    logger.info("Using {} GPUs".format(num_gpus))  
    logger.info(args)  

    logger.info("Collecting env info (might take some time)")  
    logger.info("\n" + collect_env_info())  

    logger.info("Loaded configuration file {}".format(args.config_file))  
    with open(args.config_file, "r") as cf:  
        config_str = "\n" + cf.read()  
        logger.info(config_str)  
    logger.info("Running with config:\n{}".format(cfg))  

    logger.info("Arch Decoder: {}".format(cfg.SEARCH.DECODER.CONFIG))  
    model = train(cfg, args.local_rank, args.distributed)  

    if not args.skip_test:  
        run_test(cfg, model, args.distributed) 
Developer: Lausannen, Project: NAS-FCOS, Lines: 63, Source file: train_net.py

Example 15: train

# Module to import: from maskrcnn_benchmark.utils import comm [as alias]
# Or alternatively: from maskrcnn_benchmark.utils.comm import get_rank [as alias]
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )



    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model 
Developer: mjq11302010044, Project: RRPN_pytorch, Lines: 53, Source file: train_net.py


Note: the maskrcnn_benchmark.utils.comm.get_rank examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers; copyright of the source code remains with the original authors. Please consult the license of the corresponding project before redistributing or using the code, and do not reproduce this article without permission.