本文整理汇总了Python中torch.set_num_threads方法的典型用法代码示例。如果您正苦于以下问题:Python torch.set_num_threads方法的具体用法?Python torch.set_num_threads怎么用?Python torch.set_num_threads使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.set_num_threads方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def main(config_file, dataset_name,
         outer_k, outer_processes, inner_k, inner_processes, result_folder, debug=False):
    """Run an outer k-fold risk assessment with hold-out model selection.

    Args:
        config_file: configuration file handed to ``Grid`` together with
            ``dataset_name`` to enumerate model configurations.
        dataset_name: dataset identifier passed to ``Grid``.
        outer_k: number of outer folds given to ``KFoldAssessment``.
        outer_processes: process budget for the outer assessment.
        inner_k: accepted for interface compatibility; not used here.
        inner_processes: process budget for the ``HoldOutSelector``.
        result_folder: directory under which ``<exp_name>_assessment`` is used.
        debug: forwarded to ``risk_assessment``.
    """
    # Keep torch single-threaded: its own thread spawning conflicts with the
    # multi-processing used below. A value > 1 is possible, but take the number
    # of processes running on the machine into account.
    torch.set_num_threads(1)

    experiment_cls = EndToEndExperiment
    configs = Grid(config_file, dataset_name)
    first_config = Config(**configs[0])
    assessment_dir = os.path.join(result_folder, f'{first_config.exp_name}_assessment')

    selector = HoldOutSelector(max_processes=inner_processes)
    assessment = KFoldAssessment(
        outer_k,
        selector,
        assessment_dir,
        configs,
        outer_processes=outer_processes,
    )
    assessment.risk_assessment(experiment_cls, debug=debug)
示例2: _ms_loop
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    """Worker loop for a multi-scale data loader.

    Pulls ``(idx, batch_indices)`` tasks from ``index_queue`` until a ``None``
    sentinel arrives. For each task it optionally picks a random scale (only
    when more than one scale exists and ``dataset.train`` is true), collates
    the samples, appends the chosen scale index, and puts ``(idx, samples)``
    on ``data_queue``. If processing raises, it puts
    ``(idx, ExceptionWrapper(...))`` instead.

    Args:
        dataset: object providing ``train``, ``set_scale`` and ``__getitem__``.
        index_queue: source of ``(idx, batch_indices)`` tasks; ``None`` stops
            the loop.
        data_queue: sink for results.
        collate_fn: merges a list of samples; its result must support ``append``.
        scale: sequence of scales to choose from.
        seed: seed for this worker's ``random`` and ``torch`` generators
            (presumably distinct per worker — set by the caller).
        init_fn: optional worker-init callback, called once with ``worker_id``.
        worker_id: index of this worker.
    """
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    # Seed Python's ``random`` as well as torch. ``random.randrange`` picks the
    # scale below, and workers started with fork inherit the parent's
    # ``random`` state. Without this, every worker would draw the same scale
    # sequence.
    random.seed(seed)
    torch.manual_seed(seed)

    # Honour the user-supplied worker init function, as PyTorch's own worker
    # loop does. Previously it was accepted but silently ignored.
    if init_fn is not None:
        init_fn(worker_id)

    while True:
        task = index_queue.get()
        if task is None:
            break
        idx, batch_indices = task
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)
            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
示例3: _worker_loop
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def _worker_loop(dataset, index_queue, data_queue, collate_fn, rng_seed):
global _use_shared_memory
_use_shared_memory = True
np.random.seed(rng_seed)
torch.set_num_threads(1)
while True:
r = index_queue.get()
if r is None:
data_queue.put(None)
break
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
示例4: main
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def main():
    """Build the model chosen by ``args.set`` and either train or test it.

    ``args.set`` selects the model module ('gta' or 'kitti'); any other value
    raises ``ValueError``. ``args.phase`` selects training or testing; any
    other phase does nothing after the model is built.
    """
    # Allow torch to use every CPU the OS reports.
    torch.set_num_threads(multiprocessing.cpu_count())
    args = parse_args()

    if args.set == 'gta':
        from model.model import Model
    elif args.set == 'kitti':
        from model.model_cen import Model
    else:
        raise ValueError("Model not found")

    net = Model(args.arch, args.roi_name, args.down_ratio, args.roi_kernel)
    net = nn.DataParallel(net).to(args.device)

    if args.phase == 'train':
        run_training(net, args)
    elif args.phase == 'test':
        test_model(net, args)
示例5: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def __init__(
        self,
        env,
        policy,
        exploration_policy,
        max_path_length,
        train_rollout_function,
        eval_rollout_function,
):
    """Store the environment and policies and load the rollout functions.

    ``train_rollout_function`` and ``eval_rollout_function`` are received
    serialized and are deserialized here with ``cloudpickle.loads``.
    """
    # One torch thread per instance (presumably several instances run in
    # parallel workers — confirm against the caller).
    torch.set_num_threads(1)
    self._env, self._policy, self._exploration_policy = env, policy, exploration_policy
    self._max_path_length = max_path_length
    self.train_rollout_function = cloudpickle.loads(train_rollout_function)
    self.eval_rollout_function = cloudpickle.loads(eval_rollout_function)
示例6: run
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def run(self):
    """Keep ``self._queue`` topped up with samples until the close event is set.

    While the queue holds ``self._max_pending`` or more entries, waits on
    ``self._queue_empty_event``. Otherwise it fetches one sample batch from the
    parent class's ``get_samples(1)``, appends it to the queue, and sets
    ``self._queue_fill_event``.
    """
    # TODO Fix this dependency. The policy itself limits torch to one thread,
    # but that setting appears to be per-thread in pytorch, so it must be
    # repeated in this thread too.
    import torch
    torch.set_num_threads(1)

    while not self._close_event.is_set():
        # Back-pressure: block until the queue drops below the cap.
        # NOTE(review): the close event is not re-checked while waiting, so
        # shutdown presumably also sets the empty event — confirm.
        while len(self._queue) >= self._max_pending:
            self._queue_empty_event.wait()
            self._queue_empty_event.clear()

        fresh_batch = super().get_samples(1)
        self._queue.extend(fresh_batch)
        self._queue_fill_event.set()
示例7: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def __init__(self, observation_space):
    """Initialize the model for ``observation_space``.

    Supported observation spaces:
      - a Box space;
      - a tuple of Box spaces, where the first is the 'main' observation and
        the others hold additional 1D vectors of linear features that are fed
        to one of the non-convolutional layers (usually the RNN layer).
    """
    super().__init__()
    # With multiple actors, each holding its own CPU copy of the model, every
    # copy is limited to a single thread so the actors do not slow each other
    # down. This should not affect training time if training runs on the GPU.
    torch.set_num_threads(1)
    self._setup_inputs(observation_space)
示例8: _collector_worker
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def _collector_worker(statistics, buffer, distributor,
collector, done, piecewise):
torch.set_num_threads(1)
while True:
if done.value:
break
result = collector.sample_trajectory()
trajectory_statistics = collector.compute_statistics(result)
trajectory = distributor.commit_trajectory(result)
if piecewise:
for item in trajectory:
buffer.append(item)
else:
buffer.append(trajectory)
statistics.update(trajectory_statistics)
示例9: _worker_loop
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def _worker_loop(dataset, index_queue, data_queue, collate_fn):
global _use_shared_memory
_use_shared_memory = True
torch.set_num_threads(1)
while True:
r = index_queue.get()
if r is None:
data_queue.put(None)
break
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
data_queue.put((idx, samples))
示例10: share
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def share(self):
    """Share model parameters.

    Extends the parent's shared dict with the model, metrics, the fixed/vocab
    candidate caches and the optimizer. When ``opt['numthreads']`` (default 1)
    exceeds 1 and ``self.metrics`` is still a plain dict, torch is limited to
    one thread, metrics are wrapped in ``SharedTable``, and the model is moved
    to shared memory first.
    """
    shared = super().share()
    shared['model'] = self.model
    if self.opt.get('numthreads', 1) > 1 and isinstance(self.metrics, dict):
        torch.set_num_threads(1)
        # Move metrics and model to shared memory.
        self.metrics = SharedTable(self.metrics)
        self.model.share_memory()
    shared['metrics'] = self.metrics
    for attr_name in ('fixed_candidates', 'fixed_candidate_vecs', 'fixed_candidate_encs',
                      'vocab_candidates', 'vocab_candidate_vecs', 'optimizer'):
        shared[attr_name] = getattr(self, attr_name)
    return shared
示例11: main
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def main(env='MinitaurTrottingEnv-v0'):
    """Train an actor-critic policy with PPO on the gym environment ``env``.

    Performs ``TOTAL_STEPS // PPO_STEPS + 1`` updates. Each update collects
    ``PPO_STEPS`` steps with the current policy and then calls ``update``.
    The learning rate is scaled by ``1 - epoch / num_updates``.
    """
    runner = gym.make(env)
    runner = envs.AddTimestep(runner)
    runner = envs.Logger(runner, interval=PPO_STEPS)
    runner = envs.Normalizer(runner, states=True, rewards=True)
    runner = envs.Torch(runner)
    runner = envs.Runner(runner)
    runner.seed(SEED)
    th.set_num_threads(1)

    policy = ActorCriticNet(runner)
    optimizer = optim.Adam(policy.parameters(), lr=LR, eps=1e-5)
    num_updates = TOTAL_STEPS // PPO_STEPS + 1

    def linear_decay(epoch):
        return 1 - epoch / num_updates

    lr_schedule = optim.lr_scheduler.LambdaLR(optimizer, linear_decay)

    def get_action(state):
        return get_action_value(state, policy)

    for _ in range(num_updates):
        # Collect experience with the Runner wrapper, then update the policy.
        replay = runner.run(get_action, steps=PPO_STEPS, render=RENDER)
        update(replay, optimizer, policy, runner, lr_schedule)
示例12: set_random_seed
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def set_random_seed(random_seed):
    """Seed all random number generators and enable deterministic cuDNN.

    Does nothing when ``random_seed`` is None. Otherwise it:
    - prints the seed and sets ``PYTHONHASHSEED``;
    - seeds ``random``, NumPy, and torch (CPU and all CUDA devices);
    - limits torch to one thread;
    - disables cuDNN benchmarking and enables cuDNN determinism;
    - warns about the possible slowdown.
    """
    if random_seed is None:
        return
    print("Set random seed as {}".format(random_seed))
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this affects
    # only child processes started afterwards, not hashing in this process.
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(random_seed)
    torch.set_num_threads(1)
    cudnn.benchmark = False
    cudnn.deterministic = True
    warnings.warn('You have chosen to seed training. '
                  'This will turn on the CUDNN deterministic setting, '
                  'which can slow down your training considerably! '
                  'You may see unexpected behavior when restarting '
                  'from checkpoints.')
示例13: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def __init__(self,
             domain_name, env_seed, policy_producer,
             max_num_epoch_paths_saved=None,
             render=False,
             render_kwargs=None,
             ):
    """Create the environment and initialize the parent class with it.

    The environment is built with ``env_producer(domain_name, env_seed)``.
    ``policy_producer`` is stored on the instance. The path-saving and
    rendering options are forwarded unchanged to the parent constructor.
    """
    torch.set_num_threads(1)
    parent_kwargs = dict(
        max_num_epoch_paths_saved=max_num_epoch_paths_saved,
        render=render,
        render_kwargs=render_kwargs,
    )
    env = env_producer(domain_name, env_seed)
    self._policy_producer = policy_producer
    super().__init__(env, **parent_kwargs)
示例14: get_example_outputs
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def get_example_outputs(agent, env, examples, subprocess=False):
    """Record one example of each environment and agent output in ``examples``.

    Resets ``env``, steps it once with a random action, then steps ``agent``
    once on the result. It writes observation, reward, done, env_info, action
    and agent_info into ``examples`` by key; nothing is returned. Do this in a
    sub-process to avoid setup conflicts between master and workers
    (e.g. MKL).
    """
    if subprocess:  # i.e. running inside the sub-process.
        import torch
        torch.set_num_threads(1)  # Some fix to prevent MKL hang.
    env.reset()
    random_action = env.action_space.sample()
    observation, reward, done, env_info = env.step(random_action)
    reward = np.asarray(reward, dtype="float32")  # Must match torch float dtype here.
    agent.reset()
    agent_inputs = torchify_buffer(AgentInputs(observation, random_action, reward))
    action, agent_info = agent.step(*agent_inputs)
    if "prev_rnn_state" in agent_info:
        # Agent leaves B dimension in, strip it: [B,N,H] --> [N,H]
        agent_info = agent_info._replace(prev_rnn_state=agent_info.prev_rnn_state[0])
    recorded = (
        ("observation", observation),
        ("reward", reward),
        ("done", done),
        ("env_info", env_info),
        ("action", action),  # OK to put a torch tensor here; could numpify.
        ("agent_info", agent_info),
    )
    for key, value in recorded:
        examples[key] = value
示例15: initialize_worker
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_num_threads [as 别名]
def initialize_worker(rank, seed=None, cpu=None, torch_threads=None):
    """Assign CPU affinity, set random seed, set torch_threads if needed to
    prevent MKL deadlock.

    ``cpu`` may be an int or a list of CPU ids. When ``cpu`` is given and
    ``torch_threads`` is not, torch is limited to one thread. A one-line
    summary of the resulting setup is written with ``logger.log``.
    """
    parts = [f"Sampler rank {rank} initialized"]
    if isinstance(cpu, int):
        cpu = [cpu]
    process = psutil.Process()
    try:
        if cpu is not None:
            process.cpu_affinity(cpu)
        cpu_affin = process.cpu_affinity()
    except AttributeError:
        cpu_affin = "UNAVAILABLE MacOS"
    parts.append(f", CPU affinity {cpu_affin}")

    if torch_threads is None and cpu is not None:
        torch_threads = 1  # Default to 1 to avoid possible MKL hang.
    if torch_threads is not None:
        torch.set_num_threads(torch_threads)
    parts.append(f", Torch threads {torch.get_num_threads()}")

    if seed is not None:
        set_seed(seed)
        time.sleep(0.3)  # (so the printing from set_seed is not intermixed)
        parts.append(f", Seed {seed}")
    logger.log("".join(parts))