本文整理汇总了Python中metrics.Metrics方法的典型用法代码示例。如果您正苦于以下问题：Python metrics.Metrics方法的具体用法？Python metrics.Metrics怎么用？Python metrics.Metrics使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块metrics的用法示例。
在下文中一共展示了metrics.Metrics方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: valid
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def valid(valid_loader, running_acc, model, device):
    """Run the model over ``valid_loader`` and record per-batch accuracy.

    Args:
        valid_loader: Iterable of batches that also supports ``len()``. Each
            batch is a mapping with a ``'data'`` entry (a tensor, or a
            sequence of model inputs) and an ``'annots'`` entry that is
            passed to the accuracy metric.
        running_acc: List that per-batch accuracies are appended to.
        model: Model to evaluate; ``model.eval()`` is called before the loop.
        device: Target device for tensor inputs.

    Returns:
        ``running_acc``, with one accuracy value appended per batch.
    """
    # NOTE(review): ``args`` is not a parameter; it is resolved as a global
    # at call time -- confirm it is defined at module level before calling.
    acc_metric = Metrics(**args)
    model.eval()
    with torch.no_grad():
        for step, data in enumerate(valid_loader):
            x_input = data['data']
            annotations = data['annots']
            if isinstance(x_input, torch.Tensor):
                outputs = model(x_input.to(device))
            else:
                # Build a new list instead of assigning into ``x_input``:
                # item assignment fails for tuple batches and would mutate
                # the caller's batch object.
                inputs = [
                    item.to(device) if isinstance(item, torch.Tensor) else item
                    for item in x_input
                ]
                outputs = model(*inputs)
            running_acc.append(acc_metric.get_accuracy(outputs, annotations))
            if step % 100 == 0:
                print('Step: {}/{} | validation acc: {:.4f}'.format(
                    step, len(valid_loader), running_acc[-1]))
    # END FOR: Validation Accuracy
    return running_acc
示例2: debug_mixture_classifier
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def debug_mixture_classifier(opts, step, probs, points, num_plot=320, real=True):
    """Log and plot the extreme outputs of the mixture classifier.

    Points are ranked by ``probs``. For real points (``real=True``), the
    ``num_plot`` highest-prob points count as correctly classified and the
    ``num_plot`` lowest-prob points as wrong. For fake points the roles are
    swapped. Both groups' probabilities are logged at DEBUG level, and both
    groups are plotted via ``metrics_lib.Metrics().make_plots``.

    Args:
        opts: Options passed through to ``make_plots``.
        step: Step counter passed through to ``make_plots``.
        probs: Classifier outputs, one per point. Each entry is indexed with
            ``[0]`` when logged -- presumably shape-(1,) rows; TODO confirm.
        points: Container indexable by a list of ids (e.g. a NumPy array).
        num_plot: Number of points shown in each group.
        real: True if ``points`` are real samples, False if generated.

    Returns:
        None. Does nothing when ``probs`` and ``points`` differ in length,
        or when there are fewer than ``2 * num_plot`` points.
    """
    num = len(points)
    if len(probs) != num or num < 2 * num_plot:
        return
    # Sort (prob, index) pairs by probability; ties are broken by index.
    ranked = sorted(zip(probs, range(num)))
    lowest, highest = ranked[:num_plot], ranked[-num_plot:]
    correct, wrong = (highest, lowest) if real else (lowest, highest)
    idstring = 'real' if real else 'fake'
    # Let logging do the formatting lazily instead of pre-formatting with %.
    logging.debug('Correctly classified %s points probs:', idstring)
    logging.debug([val[0] for val, _ in correct])
    logging.debug('Incorrectly classified %s points probs:', idstring)
    logging.debug([val[0] for val, _ in wrong])
    metrics = metrics_lib.Metrics()
    for group, label in ((correct, 'correct'), (wrong, 'wrong')):
        ids = [_id for _, _id in group]
        metrics.make_plots(opts, step,
                           None, points[ids],
                           prefix='c_%s_%s_' % (idstring, label))
示例3: debug_updated_weights
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def debug_updated_weights(opts, steps, weights, data):
    """Plot the lowest- and highest-weighted training points and the
    distribution of weights.

    Plots the ``20 * 16`` lowest-weighted points (prefix ``'d_least_'``) and
    the ``20 * 16`` highest-weighted points (prefix ``'d_most_'``). It then
    saves a figure of the sorted weights, plus the weight mass per label when
    ``data.labels`` is set, to ``<work_dir>/data_w<steps>.png``.

    Args:
        opts: Options dict; ``opts['work_dir']`` is the output directory and
            the dict is also passed through to ``make_plots``.
        steps: Integer step number used in the output file name.
        weights: One weight per training point. It is indexed with a NumPy
            index array when labels are present -- assumes a NumPy array.
        data: Dataset object exposing ``num_points``, ``data`` and ``labels``.

    Returns:
        None. Returns early, without plotting, when there are fewer weights
        than points to plot.

    Raises:
        AssertionError: If ``data.num_points != len(weights)``.
    """
    assert data.num_points == len(weights), 'Length mismatch'
    num_plot = 20 * 16
    if num_plot > len(weights):
        return
    ws_and_ids = sorted(zip(weights, range(len(weights))))
    metrics = metrics_lib.Metrics()
    # Lowest-weighted points first, then the highest-weighted ones.
    for chunk, prefix in ((ws_and_ids[:num_plot], 'd_least_'),
                          (ws_and_ids[-num_plot:], 'd_most_')):
        ids = [_id for _, _id in chunk]
        metrics.make_plots(opts, steps,
                           None, data.data[ids],
                           prefix=prefix)
    plt.clf()
    ax1 = plt.subplot(211)
    ax1.set_title('Weights over data points')
    plt.plot(range(len(weights)), sorted(weights))
    plt.axis([0, len(weights), 0., 2. * np.max(weights)])
    if data.labels is not None:
        all_labels = np.unique(data.labels)
        w_per_label = -1. * np.ones(len(all_labels))
        for _id, y in enumerate(all_labels):
            w_per_label[_id] = np.sum(
                weights[np.where(data.labels == y)[0]])
        ax2 = plt.subplot(212)
        ax2.set_title('Weights over labels')
        plt.scatter(range(len(all_labels)), w_per_label, s=30)
    filename = 'data_w{:02d}.png'.format(steps)
    create_dir(opts['work_dir'])
    # savefig does not close a file object it is given, so close the handle
    # explicitly once the figure has been written.
    out_file = o_gfile((opts["work_dir"], filename), 'wb')
    try:
        plt.savefig(out_file)
    finally:
        out_file.close()
示例4: _train_internal
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def _train_internal(self, opts):
"""Train a GAN model.
"""
batches_num = self._data.num_points / opts['batch_size']
train_size = self._data.num_points
counter = 0
logging.debug('Training GAN')
for _epoch in xrange(opts["gan_epoch_num"]):
for _idx in xrange(batches_num):
data_ids = np.random.choice(train_size, opts['batch_size'],
replace=False, p=self._data_weights)
batch_images = self._data.data[data_ids].astype(np.float)
batch_noise = utils.generate_noise(opts, opts['batch_size'])
# Update discriminator parameters
for _iter in xrange(opts['d_steps']):
_ = self._session.run(
self._d_optim,
feed_dict={self._real_points_ph: batch_images,
self._noise_ph: batch_noise})
# Update generator parameters
for _iter in xrange(opts['g_steps']):
_ = self._session.run(
self._g_optim, feed_dict={self._noise_ph: batch_noise})
counter += 1
if opts['verbose'] and counter % opts['plot_every'] == 0:
metrics = Metrics()
points_to_plot = self._run_batch(
opts, self._G, self._noise_ph,
self._noise_for_plots[0:320])
data_ids = np.random.choice(train_size, 320,
replace=False,
p=self._data_weights)
metrics.make_plots(
opts, counter,
self._data.data[data_ids],
points_to_plot,
prefix='sample_e%04d_mb%05d_' % (_epoch, _idx))
示例5: __init__
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def __init__(self, metric_names):
    """Set up output folders, logging, the config dump and metric trackers.

    Args:
        metric_names: Names of the metrics to track. Passed to both the
            training and the validation ``metrics.Metrics`` instances.
    """
    self.model_name = P.MODEL_ID
    # Presumably sets ``self.model_folder``, which is used below -- confirm.
    self.setup_folders()
    # The log path contains no ``{}`` placeholder, so the former
    # ``.format(self.model_name)`` was a no-op. It could only misfire
    # (raise, or substitute text) for folder names containing braces.
    initialize_logger(os.path.join(self.model_folder, 'log.txt'))
    P.write_to_file(os.path.join(self.model_folder, 'config.ini'))
    logging.info(P.to_string())
    self.train_metrics = metrics.Metrics('train', metric_names, P.N_CLASSES)
    self.val_metrics = metrics.Metrics('validation', metric_names, P.N_CLASSES)
    # -1 means no epoch has run yet; presumably advanced elsewhere -- confirm.
    self.epoch = -1
示例6: main
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def main(args):
    """Load a trained checkpoint and evaluate it on the test split.

    Results are written to ``./output/<model>/<expname>/results.txt``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``seed``, ``data_path``, ``expname``, ``timestamp``,
            ``reload_from``, ``n_samples`` and ``decode_mode``.
    """
    conf = getattr(configs, 'config_' + args.model)()
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    else:
        print("Note that our pre-trained models require CUDA to evaluate.")
    # Load data
    test_set = APIDataset(args.data_path + 'test.desc.h5',
                          args.data_path + 'test.apiseq.h5',
                          conf['max_sent_len'])
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=1, shuffle=False, num_workers=1)
    vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
    vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
    metrics = Metrics()
    # Load model checkpoints
    model = getattr(models, args.model)(conf)
    ckpt = (f'./output/{args.model}/{args.expname}/{args.timestamp}'
            f'/models/model_epo{args.reload_from}.pkl')
    model.load_state_dict(torch.load(ckpt))
    # The path is already an f-string, so the former trailing .format(...)
    # was redundant. ``with`` guarantees the results file is flushed and
    # closed even if evaluation raises.
    results_path = f"./output/{args.model}/{args.expname}/results.txt"
    with open(results_path, "w") as f_eval:
        evaluate(model, metrics, test_loader, vocab_desc, vocab_api,
                 args.n_samples, args.decode_mode, f_eval)
示例7: __init__
# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def __init__(self, db, config):
    """Initialise bot state, the static data tables, the web-update worker
    thread and the persisted MQTT client id.

    Args:
        db: Database handle, stored as ``self.database``.
        config: Bot configuration. This method reads
            ``forts_max_circle_size``, ``gps_default_altitude`` and
            ``heartbeat_threshold``, and sets ``client_id``.
    """
    self.database = db
    self.config = config
    super(PokemonGoBot, self).__init__()
    self.fort_timeouts = dict()
    # Close the JSON data files after loading instead of leaking handles.
    with open(os.path.join(_base_dir, 'data', 'pokemon.json')) as data_file:
        self.pokemon_list = json.load(data_file)
    with open(os.path.join(_base_dir, 'data', 'items.json')) as data_file:
        self.item_list = json.load(data_file)
    # @var Metrics
    self.metrics = Metrics(self)
    self.latest_inventory = None
    self.cell = None
    self.recent_forts = [None] * config.forts_max_circle_size
    self.tick_count = 0
    self.softban = False
    self.start_position = None
    self.last_map_object = None
    self.last_time_map_object = 0
    self.logger = logging.getLogger(type(self).__name__)
    self.alt = self.config.gps_default_altitude
    # Make our own copy of the workers for this instance
    self.workers = []
    # Threading setup for file writing
    # NOTE(review): the worker thread starts before the attributes assigned
    # below exist; confirm update_web_location_worker does not read them.
    self.web_update_queue = Queue.Queue(maxsize=1)
    self.web_update_thread = threading.Thread(target=self.update_web_location_worker)
    self.web_update_thread.start()
    # Heartbeat limiting
    self.heartbeat_threshold = self.config.heartbeat_threshold
    self.heartbeat_counter = 0
    self.last_heartbeat = time.time()
    self.capture_locked = False  # lock catching while moving to VIP pokemon
    client_id_file_path = os.path.join(_base_dir, 'data', 'mqtt_client_id')
    saved_info = shelve.open(client_id_file_path)
    try:
        # NOTE(review): Python 3's shelve requires str keys; this bytes key
        # (together with ``Queue.Queue``) suggests Python 2 -- confirm.
        key = 'client_id'.encode('utf-8')
        if key in saved_info:
            self.config.client_id = saved_info[key]
        else:
            self.config.client_id = str(uuid.uuid4())
            saved_info[key] = self.config.client_id
    finally:
        # Always release the shelf, even if reading or writing the key fails.
        saved_info.close()