本文整理汇总了Python中torch.optim.Adam.load_state_dict方法的典型用法代码示例。如果您正苦于以下问题：Python Adam.load_state_dict方法的具体用法？Python Adam.load_state_dict怎么用？Python Adam.load_state_dict使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.optim.Adam的用法示例。
在下文中一共展示了Adam.load_state_dict方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: fit
# Required import: from torch.optim import Adam [as alias]
# Method demonstrated: torch.optim.Adam.load_state_dict, an instance method of the
# Adam optimizer (it cannot be imported via "from torch.optim.Adam import ...").
# NOTE(review): this listing was scraped with its indentation stripped and its
# tail elided, so it is not runnable as shown.
def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
epochs=1000, restart=0, gpu=False):
"""Train an embedding model jointly with a classifier on top of it.

The embedding model and the classifier are optimized together by a single
Adam optimizer. The list of training labels is written to CLASSES_TXT
(one label per line).

Parameters
----------
model : nn.Module
Embedding model
feature_extraction :
Feature extraction.
protocol : pyannote.database.Protocol
log_dir : str
Directory where models and other log files are stored.
subset : {'train', 'development', 'test'}, optional
Defaults to 'train'.
epochs : int, optional
Index of the last epoch to train (inclusive): the loop stops once the
epoch counter exceeds it. Defaults to 1000.
restart : int, optional
When > 0, reload the model weights and optimizer state saved at that
epoch and resume at epoch `restart + 1`. Defaults to 0 (train from
scratch, starting at epoch 0).
gpu : bool, optional
When True, move the model, the classifier and the restored optimizer
state to CUDA. Defaults to False.
"""
import tensorboardX
# NOTE(review): `writer` is unused in the visible part; presumably used in the elided code.
writer = tensorboardX.SummaryWriter(log_dir=log_dir)
checkpoint = Checkpoint(log_dir=log_dir,
restart=restart > 0)
batch_generator = self.get_batch_generator(feature_extraction)
batches = batch_generator(protocol, subset=subset)
# The first batch is pulled but not used below; presumably this makes the
# generator initialize attributes such as `labels` — TODO confirm.
batch = next(batches)
batches_per_epoch = batch_generator.batches_per_epoch
# save list of classes (one speaker per line)
labels = batch_generator.labels
classes_txt = self.CLASSES_TXT.format(log_dir=log_dir)
with open(classes_txt, mode='w') as fp:
for label in labels:
fp.write(f'{label}\n')
# initialize classifier
n_classes = batch_generator.n_classes
classifier = Classifier(model.output_dim, n_classes,
linear=self.linear)
# load precomputed weights in case of restart
if restart > 0:
weights_pt = checkpoint.WEIGHTS_PT.format(
log_dir=log_dir, epoch=restart)
model.load_state_dict(torch.load(weights_pt))
# NOTE(review): `classifier_pt` is computed but never loaded. On restart the
# classifier therefore starts from fresh weights while the optimizer state
# is restored. This is likely missing
# classifier.load_state_dict(torch.load(classifier_pt)).
classifier_pt = self.CLASSIFIER_PT.format(
log_dir=log_dir, epoch=restart)
# send models to GPU
if gpu:
model = model.cuda()
classifier = classifier.cuda(device=None)
model.internal = False
# A single optimizer updates both the embedding model and the classifier.
# Optimizer state is restored by parameter position, so this order must match
# the order used when the checkpoint was saved.
optimizer = Adam(list(model.parameters()) + \
list(classifier.parameters()))
if restart > 0:
optimizer_pt = checkpoint.OPTIMIZER_PT.format(
log_dir=log_dir, epoch=restart)
optimizer.load_state_dict(torch.load(optimizer_pt))
# Explicitly move every tensor of the restored optimizer state to CUDA.
# NOTE(review): recent PyTorch versions already cast state to the
# parameters' device inside load_state_dict, so this may be redundant — verify.
if gpu:
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
# From scratch the first epoch is 0; on restart training resumes at restart + 1.
epoch = restart if restart > 0 else -1
while True:
epoch += 1
# `epochs` is an inclusive upper bound on the epoch index.
if epoch > epochs:
break
loss_avg = 0.
# True for the first 10 epochs, then for every 5th epoch.
log_epoch = (epoch < 10) or (epoch % 5 == 0)
# NOTE(review): no-op as written (the body is only `pass`).
if log_epoch:
pass
desc = 'Epoch #{0}'.format(epoch)
for i in tqdm(range(batches_per_epoch), desc=desc):
# NOTE(review): this resets only the model's gradients. The classifier's
# parameters are also in the optimizer, so confirm the elided code zeroes
# them too (e.g. optimizer.zero_grad()); otherwise they accumulate across batches.
model.zero_grad()
batch = next(batches)
X = batch['X']
y = batch['y']
# For sequence-first models, swap the first two axes of X
# (presumably (batch, time, features) -> (time, batch, features); TODO confirm).
if not getattr(model, 'batch_first', True):
X = np.rollaxis(X, 0, 2)
X = np.array(X, dtype=np.float32)
# Variable is the legacy (pre-0.4) autograd wrapper.
X = Variable(torch.from_numpy(X))
y = Variable(torch.from_numpy(y))
#......... (remainder of the code omitted in the original listing) .........
示例2: fit
# Required import: from torch.optim import Adam [as alias]
# Method demonstrated: torch.optim.Adam.load_state_dict, an instance method of the
# Adam optimizer (it cannot be imported via "from torch.optim.Adam import ...").
# NOTE(review): this listing was scraped with its indentation stripped and its
# tail elided, so it is not runnable as shown.
def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
epochs=1000, restart=0, gpu=False):
"""Train `model` with Adam, adding batch-norm layers selected by `self.variant`.

Parameters
----------
model : nn.Module
Model to train.
feature_extraction :
Feature extraction, passed to SpeechSegmentGenerator.
protocol :
Protocol passed to the batch generator.
log_dir : str
Directory used for checkpoints and tensorboardX logs.
subset : str, optional
Protocol subset to draw batches from. Defaults to 'train'.
epochs : int, optional
Index of the last epoch to train (inclusive). Defaults to 1000.
restart : int, optional
When > 0, reload the model weights and optimizer state saved at that
epoch and resume at epoch `restart + 1`. Defaults to 0.
gpu : bool, optional
When True, move the model, the extra batch-norm layers and the restored
optimizer state to CUDA. Defaults to False.
"""
import tensorboardX
writer = tensorboardX.SummaryWriter(log_dir=log_dir)
checkpoint = Checkpoint(log_dir=log_dir,
restart=restart > 0)
batch_generator = SpeechSegmentGenerator(
feature_extraction,
per_label=self.per_label, per_fold=self.per_fold,
duration=self.duration, parallel=self.parallel)
batches = batch_generator(protocol, subset=subset)
# The first batch is pulled but not used in the visible code.
batch = next(batches)
batches_per_epoch = batch_generator.batches_per_epoch
if restart > 0:
weights_pt = checkpoint.WEIGHTS_PT.format(
log_dir=log_dir, epoch=restart)
model.load_state_dict(torch.load(weights_pt))
if gpu:
model = model.cuda()
model.internal = False
parameters = list(model.parameters())
# Extra batch-norm layers depend on the variant:
#   2-8 -> norm_bn (affine); 9 -> norm_bn (non-affine);
#   5-7 -> positive_bn and negative_bn; 8-9 -> delta_bn.
# NOTE(review): BatchNorm1d(affine=False) has no learnable parameters, so only
# the affine norm_bn actually adds entries to `parameters`. These layers are
# created fresh on every call. On restart, their weights and running statistics
# are not restored in the visible code.
# The optimizer checkpoint is mapped onto `parameters` by position, so
# self.variant must match the run that saved it.
if self.variant in [2, 3, 4, 5, 6, 7, 8]:
# norm batch-normalization (with learnable affine weight and bias)
self.norm_bn = nn.BatchNorm1d(
1, eps=1e-5, momentum=0.1, affine=True)
if gpu:
self.norm_bn = self.norm_bn.cuda()
parameters += list(self.norm_bn.parameters())
if self.variant in [9]:
# norm batch-normalization (no learnable parameters)
self.norm_bn = nn.BatchNorm1d(
1, eps=1e-5, momentum=0.1, affine=False)
if gpu:
self.norm_bn = self.norm_bn.cuda()
parameters += list(self.norm_bn.parameters())
if self.variant in [5, 6, 7]:
self.positive_bn = nn.BatchNorm1d(
1, eps=1e-5, momentum=0.1, affine=False)
self.negative_bn = nn.BatchNorm1d(
1, eps=1e-5, momentum=0.1, affine=False)
if gpu:
self.positive_bn = self.positive_bn.cuda()
self.negative_bn = self.negative_bn.cuda()
parameters += list(self.positive_bn.parameters())
parameters += list(self.negative_bn.parameters())
if self.variant in [8, 9]:
self.delta_bn = nn.BatchNorm1d(
1, eps=1e-5, momentum=0.1, affine=False)
if gpu:
self.delta_bn = self.delta_bn.cuda()
parameters += list(self.delta_bn.parameters())
optimizer = Adam(parameters)
if restart > 0:
optimizer_pt = checkpoint.OPTIMIZER_PT.format(
log_dir=log_dir, epoch=restart)
optimizer.load_state_dict(torch.load(optimizer_pt))
# Explicitly move every tensor of the restored optimizer state to CUDA.
# NOTE(review): may be redundant with load_state_dict's own device casting
# in recent PyTorch — verify.
if gpu:
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
# From scratch the first epoch is 0; on restart training resumes at restart + 1.
epoch = restart if restart > 0 else -1
while True:
epoch += 1
# `epochs` is an inclusive upper bound on the epoch index.
if epoch > epochs:
break
loss_avg, tloss_avg, closs_avg = 0., 0., 0.
# The log_* lists are (re)initialized only on every 5th epoch
# (presumably filled in the elided code).
if epoch % 5 == 0:
log_positive = []
log_negative = []
log_delta = []
log_norm = []
desc = 'Epoch #{0}'.format(epoch)
for i in tqdm(range(batches_per_epoch), desc=desc):
# NOTE(review): this resets only the model's gradients. The affine norm_bn
# parameters are also in the optimizer, so confirm the elided code zeroes them.
model.zero_grad()
batch = next(batches)
X = batch['X']
if not getattr(model, 'batch_first', True):
#......... (remainder of the code omitted in the original listing) .........
示例3: fit
# Required import: from torch.optim import Adam [as alias]
# Method demonstrated: torch.optim.Adam.load_state_dict, an instance method of the
# Adam optimizer (it cannot be imported via "from torch.optim.Adam import ...").
# NOTE(review): this listing was scraped with its indentation stripped and its
# tail elided, so it is not runnable as shown.
def fit(self, model, feature_extraction, protocol, log_dir, subset='train',
epochs=1000, restart=None, gpu=False):
"""Train `model` together with a DomainClassifier on top of its output.

Parameters
----------
model : nn.Module
Model to train.
feature_extraction :
Feature extraction, passed to SpeechSegmentGenerator.
protocol :
Protocol passed to the batch generator.
log_dir : str
Directory used for checkpoints and tensorboardX logs.
subset : str, optional
Protocol subset to draw batches from. Defaults to 'train'.
epochs : int, optional
Number of epochs to run (a count, not an inclusive upper bound).
Defaults to 1000.
restart : int or None, optional
None (default) trains from scratch, running epochs 0 .. epochs - 1.
Otherwise, reload the model weights and optimizer state saved at that
epoch and run epochs restart + 1 .. restart + epochs.
gpu : bool, optional
When True, move the model, the domain classifier, the input batches and
the restored optimizer state to CUDA. Defaults to False.

Raises
------
ValueError
If the batch generator exposes fewer than two values for `self.domain`.
"""
import tensorboardX
writer = tensorboardX.SummaryWriter(log_dir=log_dir)
checkpoint = Checkpoint(
log_dir=log_dir, restart=(False if restart is None else True))
# If building or iterating the default generator raises OSError,
# retry with fast=False.
# NOTE(review): if the OSError is raised by the SpeechSegmentGenerator(...) call
# itself, `batch_generator` is unbound and the `del` below raises
# UnboundLocalError. `e` is unused.
try:
batch_generator = SpeechSegmentGenerator(
feature_extraction,
per_label=self.per_label, per_fold=self.per_fold,
duration=self.duration)
batches = batch_generator(protocol, subset=subset)
batch = next(batches)
except OSError as e:
del batch_generator.data_
batch_generator = SpeechSegmentGenerator(
feature_extraction,
per_label=self.per_label, per_fold=self.per_fold,
duration=self.duration, fast=False)
batches = batch_generator(protocol, subset=subset)
batch = next(batches)
# one minute per speaker: an epoch covers 60 (presumably seconds) per label,
# split into batches of self.duration * n_sequences_per_batch, rounded up.
duration_per_epoch = 60. * batch_generator.n_labels
duration_per_batch = self.duration * batch_generator.n_sequences_per_batch
batches_per_epoch = int(np.ceil(duration_per_epoch / duration_per_batch))
if restart is not None:
weights_pt = checkpoint.WEIGHTS_PT.format(
log_dir=log_dir, epoch=restart)
model.load_state_dict(torch.load(weights_pt))
if gpu:
model = model.cuda()
model.internal = False
n_domains = len(batch_generator.domains_[self.domain])
if n_domains < 2:
raise ValueError('There must be more than one domain.')
domain_clf = DomainClassifier(model.output_dim, n_domains, alpha=1.)
if gpu:
domain_clf = domain_clf.cuda()
# NOTE(review): `domain_loss` is unused in the visible part; presumably used in the elided code.
domain_loss = nn.CrossEntropyLoss()
# A single optimizer updates the model and the domain classifier. The saved
# optimizer state is mapped onto these parameters by position.
optimizer = Adam(list(model.parameters()) + list(domain_clf.parameters()))
if restart is not None:
optimizer_pt = checkpoint.OPTIMIZER_PT.format(
log_dir=log_dir, epoch=restart)
optimizer.load_state_dict(torch.load(optimizer_pt))
# Explicitly move every tensor of the restored optimizer state to CUDA.
# NOTE(review): may be redundant with load_state_dict's own device casting
# in recent PyTorch — verify.
if gpu:
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
# First epoch to run: 0 from scratch, otherwise the one after the checkpoint.
restart = 0 if restart is None else restart + 1
for epoch in range(restart, restart + epochs):
tloss_avg = 0.
dloss_avg = 0.
loss_avg = 0.
dacc_avg = 0.
positive, negative = [], []
if not model.normalize:
norms = []
desc = 'Epoch #{0}'.format(epoch)
for i in tqdm(range(batches_per_epoch), desc=desc):
# NOTE(review): this resets only the model's gradients. domain_clf parameters
# are also in the optimizer, so confirm the elided code zeroes them too.
model.zero_grad()
batch = next(batches)
X = batch['X']
# For sequence-first models, swap the first two axes of X (TODO confirm layout).
if not getattr(model, 'batch_first', True):
X = np.rollaxis(X, 0, 2)
X = np.array(X, dtype=np.float32)
X = Variable(torch.from_numpy(X))
y = batch['y']
y_domain = batch['y_{domain}'.format(domain=self.domain)]
if gpu:
X = X.cuda()
fX = model(X)
# Track embedding norms when the model does not normalize its output.
# NOTE(review): axis=0 computes one norm per output dimension across the
# batch. If fX is (batch, dim), per-embedding norms would need axis=1 — confirm intent.
if not model.normalize:
if gpu:
fX_ = fX.data.cpu().numpy()
else:
fX_ = fX.data.numpy()
norms.append(np.linalg.norm(fX_, axis=0))
triplet_losses = []
# Loop over the distinct domain labels present in this batch.
for d, domain in enumerate(np.unique(y_domain)):
#......... (remainder of the code omitted in the original listing) .........