

Python multiprocessing.get_context Method Code Examples

This article collects typical code examples of the Python method torch.multiprocessing.get_context. If you are wondering what exactly multiprocessing.get_context does, how to use it, or what real usages look like, the curated examples here may help. You can also explore further usage examples from the containing module, torch.multiprocessing.


Below are 15 code examples of multiprocessing.get_context, sorted by popularity by default.
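
All the examples below share one pattern: obtain an explicit multiprocessing context with get_context() and create Queue, Process, or Pool objects from that context rather than from the module-level defaults. As a quick orientation, here is a minimal sketch of that pattern (written for this article, not taken from any of the projects below):

import torch
import torch.multiprocessing as mp

def _worker(queue):
    # tensors put into a torch.multiprocessing queue are moved into
    # shared memory, so the parent receives them without a copy
    queue.put(torch.ones(3))

if __name__ == '__main__':
    ctx = mp.get_context('spawn')  # 'spawn', 'fork' or 'forkserver'
    queue = ctx.Queue()
    process = ctx.Process(target=_worker, args=(queue,))
    process.start()
    print(queue.get())  # tensor([1., 1., 1.])
    process.join()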

Example 1: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(self, dataset, class_num, image_mean, image_std, network,
                 multi_scales, is_flip, devices,
                 verbose=False, save_path=None, show_image=False):
        self.dataset = dataset
        self.ndata = self.dataset.get_length()
        self.class_num = class_num
        self.image_mean = image_mean
        self.image_std = image_std
        self.multi_scales = multi_scales
        self.is_flip = is_flip
        self.network = network
        self.devices = devices

        self.context = mp.get_context('spawn')
        self.val_func = None
        self.results_queue = self.context.Queue(self.ndata)

        self.verbose = verbose
        self.save_path = save_path
        if save_path is not None:
            ensure_dir(save_path)
        self.show_image = show_image 
Author: StevenGrove, Project: TreeFilter-Torch, Lines: 24, Source: evaluator.py

Example 2: test_SizedDict_shared

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def test_SizedDict_shared():
    d = SizedDict(shared=True)
    x = torch.randn(10)
    d["a"] = x

    mp = multiprocessing.get_context("forkserver")
    p = mp.Process(target=_set, args=(d,))
    p.start()
    p.join()
    assert d["a"][0] == 10 
Author: espnet, Project: espnet, Lines: 12, Source: test_sized_dict.py
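
The helper _set used by this test is defined elsewhere in the same test module and is not shown on this page. A plausible sketch, assuming it simply writes to the shared tensor from the child process so the parent's assertion can observe the mutation:

def _set(d):
    # runs in the child process; the write goes through shared memory
    d["a"][0] = 10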

Example 3: train

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def train(opt):
    torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    if not os.path.isdir(opt.saved_path):
        os.makedirs(opt.saved_path)
    mp = _mp.get_context("spawn")
    global_model = ActorCritic(num_inputs=3, num_actions=90)
    global_icm = IntrinsicCuriosityModule(num_inputs=3, num_actions=90)
    if opt.use_gpu:
        global_model.cuda()
        global_icm.cuda()
    global_model.share_memory()
    global_icm.share_memory()

    optimizer = GlobalAdam(list(global_model.parameters()) + list(global_icm.parameters()), lr=opt.lr)
    processes = []
    for index in range(opt.num_processes):
        if index == 0:
            process = mp.Process(target=local_train, args=(index, opt, global_model, global_icm, optimizer, True))
        else:
            process = mp.Process(target=local_train, args=(index, opt, global_model, global_icm, optimizer))
        process.start()
        processes.append(process)
    for process in processes:
        process.join() 
Author: uvipen, Project: Street-fighter-A3C-ICM-pytorch, Lines: 29, Source: train.py

Example 4: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(self, dataset, class_num, image_mean, image_std, network,
                 multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None,
                 verbose=False, save_path=None, show_prediction=False):
        self.dataset = dataset
        self.ndata = self.dataset.get_length()
        self.class_num = class_num
        self.image_mean = image_mean
        self.image_std = image_std
        self.multi_scales = multi_scales
        self.is_flip = is_flip
        self.network = network
        self.devices = devices
        if isinstance(self.devices, int): self.devices = [self.devices]
        self.out_idx = out_idx
        self.threds = threds
        self.config = config
        self.logger = logger

        self.context = mp.get_context('spawn')
        self.val_func = None
        self.results_queue = self.context.Queue(self.ndata)

        self.verbose = verbose
        self.save_path = save_path
        if save_path is not None:
            ensure_dir(save_path)
        self.show_prediction = show_prediction 
Author: TAMU-VITA, Project: FasterSeg, Lines: 29, Source: tester.py

Example 5: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(self, dataset, class_num, image_mean, image_std, network,
                 multi_scales, is_flip, devices=0, out_idx=0, threds=3, config=None, logger=None,
                 verbose=False, save_path=None, show_image=False, show_prediction=False):
        self.dataset = dataset
        self.ndata = self.dataset.get_length()
        self.class_num = class_num
        self.image_mean = image_mean
        self.image_std = image_std
        self.multi_scales = multi_scales
        self.is_flip = is_flip
        self.network = network
        self.devices = devices
        if isinstance(self.devices, int): self.devices = [self.devices]
        self.out_idx = out_idx
        self.threds = threds
        self.config = config
        self.logger = logger

        self.context = mp.get_context('spawn')
        self.val_func = None
        self.results_queue = self.context.Queue(self.ndata)

        self.verbose = verbose
        self.save_path = save_path
        if save_path is not None:
            ensure_dir(save_path)
        self.show_image = show_image
        self.show_prediction = show_prediction 
Author: TAMU-VITA, Project: FasterSeg, Lines: 30, Source: evaluator.py

Example 6: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(
            self,
            _agent: AgentAbstract,
            _env: EnvAbstract,
            _epoch_max: int,
            _epoch_train: int,
            _train_update_target: int,
            _train_save: int,
            _process_core: int = None,
            _save_path: str = './save',
            _use_cmd: bool = True,
    ):
        self.agent: AgentAbstract = _agent
        self.agent.training()
        self.env: EnvAbstract = _env

        # multiprocessing for sampling
        self.mp = mp.get_context('spawn')
        self.process_core = _process_core
        self.pool = self.mp.Pool(self.process_core)

        # training control
        self.epoch = 0
        self.train_times = 0
        self.epoch_max = _epoch_max
        self.epoch_train = _epoch_train
        self.train_update_target = _train_update_target
        self.train_save = _train_save

        self.total_reward_buf = []

        self.save_path = _save_path
        self.use_cmd = _use_cmd
        if self.use_cmd:
            self.shell = TrainShell(self) 
Author: ppaanngggg, Project: DeepRL, Lines: 37, Source: AsynTrainEpoch.py

Example 7: parallel_predict

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def parallel_predict(model, test_sequences, args, num_processes=4):
  """Run prediction in parallel using torch.multiprocessing.

  This is a beta feature. It makes prediction slower on CPU, but it has been
  reported to make prediction faster on GPU.

  Args:
    model: instance of UISRNN model
    test_sequences: a list of test sequences, or a single test
      sequence. Each test sequence is a 2-dim numpy array
      of real numbers. See `predict_single()` for details.
    args: Inference configurations. See `arguments.py` for details.
    num_processes: number of parallel processes.

  Returns:
    a list of the same size as test_sequences, where each element
    is a 1-dim list of strings.

  Raises:
      TypeError: If test_sequences is of wrong type.
  """
  if not isinstance(test_sequences, list):
    raise TypeError('test_sequences must be a list.')
  ctx = multiprocessing.get_context('forkserver')
  model.rnn_model.share_memory()
  pool = ctx.Pool(num_processes)
  results = pool.map(
      functools.partial(model.predict_single, args=args),
      test_sequences)
  pool.close()
  return results 
Author: google, Project: uis-rnn, Lines: 33, Source: uisrnn.py
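
A hypothetical call site for the function above (the names model, test_sequences, and inference_args are assumptions, not part of the snippet):

# each element of predicted_labels is a list of speaker label strings,
# one label per observation in the corresponding test sequence
predicted_labels = parallel_predict(
    model, test_sequences, inference_args, num_processes=4)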

Example 8: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(self, loader, prepro,
                 sort_key, batchify,
                 single_run=True, queue_size=8, fork=True):
        self._loader = loader
        self._prepro = prepro
        self._sort_key = sort_key
        self._batchify = batchify
        self._single_run = single_run
        if fork:
            ctx = mp.get_context('forkserver')
            self._queue = ctx.Queue(queue_size)
        else:
            # for easier debugging
            self._queue = None
        self._process = None 
Author: ChenRocks, Project: fast_abs_rl, Lines: 17, Source: batcher.py

Example 9: __call__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __call__(self, batch_size: int):
        def get_batches(hyper_batch):
            indexes = list(range(0, len(hyper_batch), batch_size))
            if not self._single_run:
                # random shuffle for training batches
                random.shuffle(hyper_batch)
                random.shuffle(indexes)
            hyper_batch.sort(key=self._sort_key)
            for i in indexes:
                batch = self._batchify(hyper_batch[i:i+batch_size])
                yield batch

        if self._queue is not None:
            ctx = mp.get_context('forkserver')
            self._process = ctx.Process(
                target=_batch2q,
                args=(self._loader, self._prepro,
                      self._queue, self._single_run)
            )
            self._process.start()
            while True:
                d = self._queue.get()
                if d is None:
                    break
                if isinstance(d, int):
                    print('\nepoch {} done'.format(d))
                    continue
                yield from get_batches(d)
            self._process.join()
        else:
            i = 0
            while True:
                for batch in self._loader:
                    yield from get_batches(self._prepro(batch))
                if self._single_run:
                    break
                i += 1
                print('\nepoch {} done'.format(i)) 
Author: ChenRocks, Project: fast_abs_rl, Lines: 40, Source: batcher.py
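
The producer _batch2q is defined in the same module but not shown on this page. A sketch consistent with the consumer loop above, assuming preprocessed hyper-batches are put on the queue, an int marks a finished epoch, and None is the termination sentinel:

def _batch2q(loader, prepro, q, single_run):
    epoch = 0
    while True:
        for batch in loader:
            q.put(prepro(batch))  # hyper-batch consumed by get_batches()
        if single_run:
            break
        epoch += 1
        q.put(epoch)  # int: the consumer prints 'epoch {} done'
    q.put(None)  # sentinel: the consumer breaks out of its read loop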

Example 10: get_multiprocess_batch_queue

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def get_multiprocess_batch_queue(name_prefix: str, target_function, files, conf, _logger, queue_size=100) -> Tuple[mp.Queue, List[mp.Process], mp.Event]:
    ctx = mp.get_context('spawn') # also set so that windows & linux behave the same 
    _queue = ctx.Queue(queue_size)
    _processes = []
    _finish_notification = ctx.Event()

    if len(files) == 0:
        _logger.error("No files for multiprocess loading specified, for: " + name_prefix)
        exit(1)
    else:
        _logger.info("Starting " + str(len(files)) + " data loader processes, for: " + name_prefix)

    if conf["token_embedder_type"] == "fasttext":
        global fasttext_vocab_cached_mapping
        global fasttext_vocab_cached_data
        if fasttext_vocab_cached_data is None:
            fasttext_vocab_cached_mapping, fasttext_vocab_cached_data = FastTextVocab.load_ids(conf["fasttext_vocab_mapping"],conf["fasttext_max_subwords"])
            fasttext_vocab_cached_data.share_memory_()

    for proc_number, file in enumerate(files):
        process = ctx.Process(name=name_prefix + "-" + str(proc_number),
                             target=target_function,
                             args=(proc_number, conf, _queue, _finish_notification, file,fasttext_vocab_cached_mapping,fasttext_vocab_cached_data))
        process.start()
        _processes.append(process)
    return _queue, _processes, _finish_notification


#
# training instance generator
#   - filling the _queue with ready to run training batches
#   - everything is thread local
# 
Author: sebastian-hofstaetter, Project: sigir19-neural-ir, Lines: 35, Source: multiprocess_input_pipeline.py
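
The target_function passed in is the training instance generator described in the closing comment. A schematic sketch of a compatible worker, where build_batches is a hypothetical helper and the coordination around the shared Event is project-specific:

def training_instance_generator(proc_number, conf, queue, finish_notification,
                                file, fasttext_mapping, fasttext_data):
    # read the assigned file, build ready-to-run training batches, and
    # hand them to the consumer through the shared queue
    for batch in build_batches(file, conf, fasttext_mapping, fasttext_data):
        queue.put(batch)
    finish_notification.set()  # schematic; real code must coordinate workers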

Example 11: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(
        self, spec: WorkerSpec, start_method="spawn", exit_barrier_timeout: float = 300
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        # pyre-fixme[8]: Attribute has type `ProcessContext`; used as `None`.
        self._process_context: mp.ProcessContext = None
        # a map that holds return values for each worker fn
        # ret_val[0] holds the return value for worker_0 (global rank 0)
        self._manager = mp.get_context(start_method).Manager()
        self._ret_vals = self._manager.dict() 
Author: pytorch, Project: elastic, Lines: 13, Source: local_elastic_agent.py

Example 12: get_multiprocess_batch_queue

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def get_multiprocess_batch_queue(name_prefix: str, target_function, files, conf, _logger, queue_size=100) -> Tuple[mp.Queue, List[mp.Process], mp.Event]:
    ctx = mp.get_context('spawn') # also set so that windows & linux behave the same 
    _processes = []
    _finish_notification = ctx.Event()

    if len(files) == 0:
        _logger.error("No files for multiprocess loading specified, for: " + name_prefix)
        exit(1)
    else:
        _logger.info("Starting " + str(len(files)) + " data loader processes, for: " + name_prefix)

    if conf["token_embedder_type"] == "fasttext":
        global fasttext_vocab_cached_mapping
        global fasttext_vocab_cached_data
        if fasttext_vocab_cached_data is None:
            fasttext_vocab_cached_mapping, fasttext_vocab_cached_data = FastTextVocab.load_ids(conf["fasttext_vocab_mapping"],conf["fasttext_max_subwords"])
            fasttext_vocab_cached_data.share_memory_()

    _queue_list = []
    #_queue = ctx.Queue(queue_size)
    for proc_number, file in enumerate(files):
        _queue = ctx.Queue(queue_size)
        process = ctx.Process(name=name_prefix + "-" + str(proc_number),
                             target=target_function,
                             args=(proc_number, conf, _queue, _finish_notification, file,fasttext_vocab_cached_mapping,fasttext_vocab_cached_data))
        process.start()
        _processes.append(process)
        _queue_list.append(_queue)
    return DeterministicQueue(_queue_list), _processes, _finish_notification
    #return _queue, _processes, _finish_notification


#
# training instance generator
#   - filling the _queue with ready to run training batches
#   - everything is thread local
# 
Author: sebastian-hofstaetter, Project: transformer-kernel-ranking, Lines: 39, Source: multiprocess_input_pipeline.py
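
DeterministicQueue is defined elsewhere in the project and not shown here. A minimal sketch, assuming it presents the per-process queues as a single queue by round-robining get(), so that batch order no longer depends on worker timing:

class DeterministicQueue:
    def __init__(self, queue_list):
        self.queue_list = queue_list
        self.current = 0

    def get(self):
        # always take the next item from the same worker in turn,
        # which makes the overall batch order reproducible
        item = self.queue_list[self.current].get()
        self.current = (self.current + 1) % len(self.queue_list)
        return item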

Example 13: train

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def train(opt):
    torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    if not os.path.isdir(opt.saved_path):
        os.makedirs(opt.saved_path)
    mp = _mp.get_context("spawn")
    env, num_states, num_actions = create_train_env(opt.world, opt.stage, opt.action_type)
    global_model = ActorCritic(num_states, num_actions)
    if opt.use_gpu:
        global_model.cuda()
    global_model.share_memory()
    if opt.load_from_previous_stage:
        if opt.stage == 1:
            previous_world = opt.world - 1
            previous_stage = 4
        else:
            previous_world = opt.world
            previous_stage = opt.stage - 1
        file_ = "{}/a3c_super_mario_bros_{}_{}".format(opt.saved_path, previous_world, previous_stage)
        if os.path.isfile(file_):
            global_model.load_state_dict(torch.load(file_))

    optimizer = GlobalAdam(global_model.parameters(), lr=opt.lr)
    processes = []
    for index in range(opt.num_processes):
        if index == 0:
            process = mp.Process(target=local_train, args=(index, opt, global_model, optimizer, True))
        else:
            process = mp.Process(target=local_train, args=(index, opt, global_model, optimizer))
        process.start()
        processes.append(process)
    process = mp.Process(target=local_test, args=(opt.num_processes, opt, global_model))
    process.start()
    processes.append(process)
    for process in processes:
        process.join() 
Author: uvipen, Project: Super-mario-bros-A3C-pytorch, Lines: 40, Source: train.py

Example 14: __init__

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def __init__(self, env_fns, context='spawn'):
        ctx = mp.get_context(context)
        dummy = env_fns[0]()
        observation_space, action_space = dummy.observation_space, dummy.action_space
        self.spec = dummy.spec
        dummy.close()
        del dummy
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)
        self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
        self.obs_bufs = [
            {k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
            for _ in env_fns]
        self.parent_pipes = []
        self.procs = []
        with clear_mpi_env_vars():
            for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
                wrapped_fn = CloudpickleWrapper(env_fn)
                parent_pipe, child_pipe = ctx.Pipe()
                proc = ctx.Process(
                    target=subproc_worker,
                    args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
                proc.daemon = True
                self.procs.append(proc)
                self.parent_pipes.append(parent_pipe)
                proc.start()
                child_pipe.close()
        self.waiting_step = False
        self.viewer = None 
Author: kengz, Project: SLM-Lab, Lines: 30, Source: vec_env.py

Example 15: inference

# Required import: from torch import multiprocessing [as alias]
# Or: from torch.multiprocessing import get_context [as alias]
def inference(
        detection_cfg,
        skeleton_cfg,
        dataset_cfg,
        gpus=1,
        worker_per_gpu=1,
):
    # get frame num
    video_file = dataset_cfg.video_file
    video_name = video_file.strip('\n').split('/')[-1]  # strip a trailing newline before taking the basename
    video_frames = mmcv.VideoReader(video_file)
    num_frames = len(video_frames)
    del video_frames

    data_cfg = skeleton_cfg.data_cfg
    if data_cfg.save_video:
        data_cfg.img_dir = os.path.join(data_cfg.save_dir,
                                        '{}.img'.format(video_name))

        if os.path.exists(data_cfg.img_dir):
            import shutil
            shutil.rmtree(data_cfg.img_dir)

        os.makedirs(data_cfg.img_dir)

    # cache model checkpoints
    cache_checkpoint(detection_cfg.checkpoint_file)
    cache_checkpoint(skeleton_cfg.checkpoint_file)

    # multiprocess settings
    context = mp.get_context('spawn')
    result_queue = context.Queue(num_frames)
    procs = []
    for w in range(gpus * worker_per_gpu):
        shred_list = list(range(w, num_frames, gpus * worker_per_gpu))
        p = context.Process(target=worker,
                            args=(video_file, shred_list, detection_cfg,
                                  skeleton_cfg, data_cfg, w % gpus,
                                  result_queue))
        p.start()
        procs.append(p)
    all_result = []
    print('\nPose estimation start:')
    prog_bar = ProgressBar(num_frames)
    for i in range(num_frames):
        t = result_queue.get()
        all_result.append(t)
        prog_bar.update()
    for p in procs:
        p.join()
    if len(all_result) == num_frames and data_cfg.save_video:
        print('\n\nGenerate video:')
        video_path = os.path.join(data_cfg.save_dir, video_name)
        mmcv.frames2video(data_cfg.img_dir,
                          video_path,
                          filename_tmpl='{:01d}.png')
        print('Video was saved to {}'.format(video_path)) 
Author: open-mmlab, Project: mmskeleton, Lines: 59, Source: image2skeleton.py


Note: The torch.multiprocessing.get_context examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their original authors, who retain copyright; consult each project's license before redistributing or reusing the code. Do not reproduce this article without permission.