

Python config.num_workers Code Examples

This article collects typical usage examples of config.num_workers from the Python module config.config. If you are unsure what config.num_workers does or how to use it, the curated code examples below should help. You can also explore further usage examples from the enclosing config.config module.


The following shows five code examples of config.num_workers, sorted by popularity by default.
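
In these projects, num_workers is a plain attribute on a shared configuration object rather than a method in the strict sense. The sketch below shows one hypothetical way such a config module could look, built on the standard library's SimpleNamespace; the attribute names mirror the examples that follow, but every value is an illustrative assumption, not taken from any of the excerpted projects.

# config/config.py -- hypothetical stand-in for the shared config module;
# attribute names mirror the examples below, all values are assumptions.
from types import SimpleNamespace

config = SimpleNamespace(
    num_workers=8,                       # DataLoader worker subprocesses
    batch_size=16,
    niters_per_epoch=1000,
    image_mean=[0.485, 0.456, 0.406],    # common ImageNet statistics
    image_std=[0.229, 0.224, 0.225],
)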

Example 1: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import num_workers [as alias]
import torch
from torch.utils import data

# TrainPre is the project's training-time preprocessing; config is the
# shared configuration object noted in the import comment above.
def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std,
                                config.target_size)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.niters_per_epoch * config.batch_size)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    # In distributed mode each process sees a shard of the dataset and a
    # proportional slice of the global batch size; DistributedSampler
    # shuffles internally, so DataLoader-level shuffling must be disabled.
    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Developer: StevenGrove, Project: TreeFilter-Torch, Lines of code: 32, Source file: dataloader.py
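
A hypothetical call site for the loader above might look like the sketch below. Engine and CityscapesDataset are placeholders for whatever engine and dataset classes the project actually provides, and config.nepochs is an assumed attribute; the set_epoch call is the one real requirement in the distributed branch, since DistributedSampler only reshuffles its shards when told the epoch has changed.

# Hypothetical usage -- Engine, CityscapesDataset and config.nepochs are
# placeholders, not names taken from the excerpted project.
engine = Engine()
train_loader, train_sampler = get_train_loader(engine, CityscapesDataset)

for epoch in range(config.nepochs):
    if engine.distributed and train_sampler is not None:
        train_sampler.set_epoch(epoch)   # reshuffle shards for this epoch
    for minibatch in train_loader:
        pass  # forward/backward pass goes here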

Example 2: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import num_workers [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'train_root': config.train_root_folder,
                    'val_root': config.eval_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.batch_size * config.niters_per_epoch)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False


    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Developer: JaminFong, Project: FNA, Lines of code: 34, Source file: dataloader.py

Example 3: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import num_workers [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std,
                                config.target_size)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.niters_per_epoch * config.batch_size)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=False,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Developer: ycszen, Project: TorchSeg, Lines of code: 32, Source file: dataloader.py
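
The only substantive difference from Example 1 is drop_last=False: this loader keeps the final incomplete batch of each epoch instead of discarding it.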

Example 4: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import num_workers [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.batch_size * config.niters_per_epoch)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Developer: ycszen, Project: TorchSeg, Lines of code: 31, Source file: dataloader.py

Example 5: __init__

# Required import: from config import config [as alias]
# Or: from config.config import num_workers [as alias]
def __init__(self, symbol, data_names, label_names,
             logger=logging, context=ctx.cpu(), work_load_list=None,
             asymbol=None, args=None):
        super(ParallModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._asymbol = asymbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list
        self._num_classes = config.num_classes
        self._batch_size = args.batch_size
        self._verbose = args.verbose
        self._emb_size = config.emb_size
        self._local_class_start = args.local_class_start
        self._iter = 0

        self._curr_module = None

        self._num_workers = config.num_workers
        self._num_ctx = len(self._context)
        self._ctx_num_classes = args.ctx_num_classes
        self._nd_cache = {}
        self._ctx_cpu = mx.cpu()
        self._ctx_single_gpu = self._context[-1]
        self._fixed_param_names = None
        self._curr_module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                        context=self._context, work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        self._arcface_modules = []
        self._ctx_class_start = []
        for i in range(len(self._context)):
          args._ctxid = i
          _module = Module(self._asymbol(args), self._data_names, self._label_names, logger=self.logger,
                          context=mx.gpu(i), work_load_list=self._work_load_list,
                          fixed_param_names=self._fixed_param_names)
          self._arcface_modules.append(_module)
          _c = args.local_class_start + i*args.ctx_num_classes
          self._ctx_class_start.append(_c)
        self._usekv = False
        if self._usekv:
          self._distkv = mx.kvstore.create('dist_sync')
          self._kvinit = {} 
Developer: deepinsight, Project: insightface, Lines of code: 47, Source file: parall_module_local_v1.py
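
One caveat on this last example: in the PyTorch examples above, num_workers is handed to DataLoader, where it means the number of data-loading subprocesses. Here it is stored on an MXNet training module, and judging from the dist_sync kvstore setup at the end, it most likely denotes the number of distributed training workers, so the same attribute name carries a different meaning in this codebase.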


Note: The config.config.num_workers examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their developers; copyright remains with the original authors, and distribution and use should follow each project's license. Do not reproduce without permission.