当前位置: 首页>>代码示例>>Python>>正文


Python metrics.Metrics方法代码示例

本文整理汇总了Python中metrics.Metrics方法的典型用法代码示例。如果您正苦于以下问题:Python metrics.Metrics方法的具体用法?Python metrics.Metrics怎么用?Python metrics.Metrics使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在metrics的用法示例。


在下文中一共展示了metrics.Metrics方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: valid

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def valid(valid_loader, running_acc, model, device):
    """Run one validation pass and append per-batch accuracies to ``running_acc``.

    Args:
        valid_loader: Iterable of batches that also supports ``len``; each
            batch is a mapping with a ``'data'`` entry (a tensor, or a
            sequence whose items are passed positionally to ``model``) and an
            ``'annots'`` entry handed to the accuracy metric.
        running_acc: List that one accuracy value per batch is appended to.
        model: Model to evaluate; it is switched to eval mode and left there.
        device: Device that tensor inputs are moved to before the forward pass.

    Returns:
        The same ``running_acc`` list, extended with one entry per batch.

    Note:
        ``Metrics`` is constructed from a name ``args`` that is not a
        parameter; it must be resolvable in the enclosing module scope.
    """
    acc_metric = Metrics(**args)
    model.eval()

    with torch.no_grad():
        for step, data in enumerate(valid_loader):
            x_input = data['data']
            annotations = data['annots']

            if isinstance(x_input, torch.Tensor):
                outputs = model(x_input.to(device))
            else:
                # Build a fresh list instead of assigning into x_input: the
                # collated batch may be a tuple (immutable), and writing into
                # it would also silently mutate the caller's batch object.
                inputs = [item.to(device) if isinstance(item, torch.Tensor) else item
                          for item in x_input]
                outputs = model(*inputs)

            running_acc.append(acc_metric.get_accuracy(outputs, annotations))

            if step % 100 == 0:
                print('Step: {}/{} | validation acc: {:.4f}'.format(step, len(valid_loader), running_acc[-1]))

        # END FOR: Validation Accuracy

    return running_acc
开发者ID:MichiganCOG,项目名称:ViP,代码行数:27,代码来源:train.py

示例2: debug_mixture_classifier

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def debug_mixture_classifier(opts, step, probs, points, num_plot=320, real=True):
    """Log and plot the extreme outputs of the mixture classifier.

    Points are ranked by their classifier probability (ties broken by index).
    For real points, the ``num_plot`` highest-probability points are treated
    as correctly classified and the ``num_plot`` lowest as misclassified; for
    fake points the two roles are swapped. The first element of each group's
    probabilities is logged at DEBUG level, and each group of points is
    plotted through ``metrics_lib.Metrics().make_plots``.

    Does nothing when ``probs`` and ``points`` differ in length or when there
    are fewer than ``2 * num_plot`` points.
    """
    total = len(points)
    if len(probs) != total or total < 2 * num_plot:
        return

    ranked = sorted(zip(probs, range(total)))
    lowest, highest = ranked[:num_plot], ranked[-num_plot:]
    correct, wrong = (highest, lowest) if real else (lowest, highest)

    correct_ids = [idx for _, idx in correct]
    wrong_ids = [idx for _, idx in wrong]
    kind = 'real' if real else 'fake'

    logging.debug('Correctly classified %s points probs:' % kind)
    logging.debug([prob[0] for prob, _ in correct])
    logging.debug('Incorrectly classified %s points probs:' % kind)
    logging.debug([prob[0] for prob, _ in wrong])

    plotter = metrics_lib.Metrics()
    for ids, prefix in ((correct_ids, 'c_%s_correct_' % kind),
                        (wrong_ids, 'c_%s_wrong_' % kind)):
        plotter.make_plots(opts, step, None, points[ids], prefix=prefix)
开发者ID:tolstikhin,项目名称:adagan,代码行数:34,代码来源:utils.py

示例3: debug_updated_weights

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def debug_updated_weights(opts, steps, weights, data):
    """Produce debug plots for the updated per-point training weights.

    Outputs, in order:
      * plots of the ``num_plot`` points with the smallest weights
        (prefix ``'d_least_'``) and with the largest weights
        (prefix ``'d_most_'``), via ``metrics_lib.Metrics().make_plots``;
      * a figure with the sorted weight curve and, when ``data.labels`` is
        not None, the summed weight per distinct label, saved to
        ``opts['work_dir']/data_wNN.png``.

    Args:
        opts: Options dict; ``opts['work_dir']`` is the output directory and
            the dict is forwarded to ``make_plots``.
        steps: Integer step index used in the output file name.
        weights: Per-point weights. They are indexed with an index array
            below, so presumably a NumPy array -- confirm with callers.
        data: Dataset exposing ``num_points``, ``data`` and ``labels``.

    Raises:
        AssertionError: If ``len(weights)`` differs from ``data.num_points``.

    Returns early, without plotting anything, when there are fewer than
    ``num_plot`` weights.
    """
    # Explicit raise (same exception type as the original ``assert``) so the
    # sanity check is not stripped under ``python -O``.
    if data.num_points != len(weights):
        raise AssertionError('Length mismatch')
    ws_and_ids = sorted(zip(weights, range(len(weights))))
    num_plot = 20 * 16
    if num_plot > len(weights):
        return

    # Plot the lowest- and highest-weighted points, each with a fresh plotter
    # as before.
    for chosen, prefix in ((ws_and_ids[:num_plot], 'd_least_'),
                           (ws_and_ids[-num_plot:], 'd_most_')):
        plot_points = data.data[[_id for _, _id in chosen]]
        metrics = metrics_lib.Metrics()
        metrics.make_plots(opts, steps, None, plot_points, prefix=prefix)

    plt.clf()
    ax1 = plt.subplot(211)
    ax1.set_title('Weights over data points')
    plt.plot(range(len(weights)), sorted(weights))
    plt.axis([0, len(weights), 0., 2. * np.max(weights)])
    if data.labels is not None:
        all_labels = np.unique(data.labels)
        w_per_label = -1. * np.ones(len(all_labels))
        for _id, y in enumerate(all_labels):
            w_per_label[_id] = np.sum(
                    weights[np.where(data.labels == y)[0]])
        ax2 = plt.subplot(212)
        ax2.set_title('Weights over labels')
        plt.scatter(range(len(all_labels)), w_per_label, s=30)
    filename = 'data_w{:02d}.png'.format(steps)
    create_dir(opts['work_dir'])
    # Close the output handle explicitly so the PNG is flushed to disk now,
    # not whenever the file object happens to be garbage-collected.
    fout = o_gfile((opts["work_dir"], filename), 'wb')
    try:
        plt.savefig(fout)
    finally:
        fout.close()
开发者ID:tolstikhin,项目名称:adagan,代码行数:41,代码来源:utils.py

示例4: _train_internal

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def _train_internal(self, opts):
        """Train a GAN model.

        Runs ``opts['gan_epoch_num']`` epochs of ``num_points // batch_size``
        minibatch iterations. Each iteration samples ``opts['batch_size']``
        real points without replacement according to ``self._data_weights``,
        then runs ``opts['d_steps']`` discriminator updates followed by
        ``opts['g_steps']`` generator updates on the same noise batch.
        When ``opts['verbose']`` is set, every ``opts['plot_every']``
        iterations 320 generated points are plotted alongside 320 sampled
        real points.
        """

        # Floor division: the number of full batches must be an int; in
        # Python 3 '/' yields a float, which range() rejects.
        batches_num = self._data.num_points // opts['batch_size']
        train_size = self._data.num_points

        counter = 0
        logging.debug('Training GAN')
        # range() instead of xrange(): identical iteration, and also valid
        # on Python 3 where xrange no longer exists.
        for _epoch in range(opts["gan_epoch_num"]):
            for _idx in range(batches_num):
                data_ids = np.random.choice(train_size, opts['batch_size'],
                                            replace=False, p=self._data_weights)
                # ``np.float`` was only an alias of builtin ``float`` (float64)
                # and was removed in NumPy 1.24.
                batch_images = self._data.data[data_ids].astype(float)
                batch_noise = utils.generate_noise(opts, opts['batch_size'])
                # Update discriminator parameters
                for _ in range(opts['d_steps']):
                    self._session.run(
                        self._d_optim,
                        feed_dict={self._real_points_ph: batch_images,
                                   self._noise_ph: batch_noise})
                # Update generator parameters
                for _ in range(opts['g_steps']):
                    self._session.run(
                        self._g_optim, feed_dict={self._noise_ph: batch_noise})
                counter += 1
                if opts['verbose'] and counter % opts['plot_every'] == 0:
                    metrics = Metrics()
                    points_to_plot = self._run_batch(
                        opts, self._G, self._noise_ph,
                        self._noise_for_plots[0:320])
                    data_ids = np.random.choice(train_size, 320,
                                                replace=False,
                                                p=self._data_weights)
                    metrics.make_plots(
                        opts, counter,
                        self._data.data[data_ids],
                        points_to_plot,
                        prefix='sample_e%04d_mb%05d_' % (_epoch, _idx))
开发者ID:tolstikhin,项目名称:adagan,代码行数:42,代码来源:gan.py

示例5: __init__

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def __init__(self, metric_names):
        """Set up output folders, logging, and train/validation metric trackers.

        Args:
            metric_names: Metric names forwarded to both the ``'train'`` and
                the ``'validation'`` ``metrics.Metrics`` trackers.
        """
        self.model_name = P.MODEL_ID
        # presumably sets self.model_folder, which is used below -- confirm
        self.setup_folders()

        # The log path has no '{}' placeholder, so the former
        # ``.format(self.model_name)`` was a no-op, and it would have raised if
        # model_folder ever contained a literal '{' or '}'.
        initialize_logger(os.path.join(self.model_folder, 'log.txt'))
        P.write_to_file(os.path.join(self.model_folder, 'config.ini'))
        logging.info(P.to_string())

        self.train_metrics = metrics.Metrics('train', metric_names, P.N_CLASSES)
        self.val_metrics = metrics.Metrics('validation', metric_names, P.N_CLASSES)
        # -1 means no epoch has run yet (presumably advanced by the training loop)
        self.epoch = -1
开发者ID:gzuidhof,项目名称:luna16,代码行数:13,代码来源:trainer.py

示例6: main

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def main(args):
    """Evaluate a saved checkpoint on the test split and write the results.

    Builds the config named ``config_<args.model>``, seeds every RNG with
    ``args.seed``, loads the test set and both vocabularies from
    ``args.data_path``, restores the checkpoint of epoch ``args.reload_from``,
    and runs ``evaluate``. Output goes to
    ``./output/<model>/<expname>/results.txt``.

    Args:
        args: Parsed command-line namespace; the fields read here are
            ``model``, ``seed``, ``data_path``, ``expname``, ``timestamp``,
            ``reload_from``, ``n_samples`` and ``decode_mode``.
    """
    conf = getattr(configs, 'config_'+args.model)()
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    else:
        print("Note that our pre-trained models require CUDA to evaluate.")

    # Load data (data_path is concatenated directly, so it needs a trailing separator)
    test_set = APIDataset(args.data_path+'test.desc.h5', args.data_path+'test.apiseq.h5', conf['max_sent_len'])
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=1)
    vocab_api = load_dict(args.data_path+'vocab.apiseq.json')
    vocab_desc = load_dict(args.data_path+'vocab.desc.json')
    metrics = Metrics()

    # Load model checkpoints
    model = getattr(models, args.model)(conf)
    ckpt = f'./output/{args.model}/{args.expname}/{args.timestamp}/models/model_epo{args.reload_from}.pkl'
    model.load_state_dict(torch.load(ckpt))

    # The f-string is already fully interpolated; the old trailing ``.format()``
    # was redundant and would raise if a model/experiment name contained braces.
    # ``with`` guarantees results.txt is flushed and closed after evaluation.
    with open(f"./output/{args.model}/{args.expname}/results.txt", "w") as f_eval:
        evaluate(model, metrics, test_loader, vocab_desc, vocab_api, args.n_samples, args.decode_mode, f_eval)
开发者ID:guxd,项目名称:deepAPI,代码行数:28,代码来源:sample.py

示例7: __init__

# 需要导入模块: import metrics [as 别名]
# 或者: from metrics import Metrics [as 别名]
def __init__(self, db, config):
        """Initialise the bot's runtime state.

        Stores ``db`` and ``config``, loads ``pokemon.json`` and
        ``items.json`` from ``<_base_dir>/data``, resets per-session counters
        and caches, and starts a background thread running
        ``update_web_location_worker``. It then loads the client id persisted
        in the shelve file ``<_base_dir>/data/mqtt_client_id`` into
        ``config.client_id``. On first run it generates and saves a new UUID4
        string.

        Args:
            db: Database handle, stored as ``self.database``.
            config: Bot configuration. ``forts_max_circle_size``,
                ``gps_default_altitude`` and ``heartbeat_threshold`` are read
                here, and ``client_id`` is written.
        """

        self.database = db

        self.config = config
        super(PokemonGoBot, self).__init__()

        self.fort_timeouts = dict()
        # Use ``with`` so the JSON files are closed right after parsing
        # instead of leaking handles until garbage collection.
        with open(os.path.join(_base_dir, 'data', 'pokemon.json')) as pokemon_file:
            self.pokemon_list = json.load(pokemon_file)
        with open(os.path.join(_base_dir, 'data', 'items.json')) as items_file:
            self.item_list = json.load(items_file)
        # @var Metrics
        self.metrics = Metrics(self)
        self.latest_inventory = None
        self.cell = None
        self.recent_forts = [None] * config.forts_max_circle_size
        self.tick_count = 0
        self.softban = False
        self.start_position = None
        self.last_map_object = None
        self.last_time_map_object = 0
        self.logger = logging.getLogger(type(self).__name__)
        self.alt = self.config.gps_default_altitude

        # Make our own copy of the workers for this instance
        self.workers = []

        # Threading setup for file writing; maxsize=1 allows at most one
        # pending update in the queue at a time.
        self.web_update_queue = Queue.Queue(maxsize=1)
        self.web_update_thread = threading.Thread(target=self.update_web_location_worker)
        self.web_update_thread.start()

        # Heartbeat limiting
        self.heartbeat_threshold = self.config.heartbeat_threshold
        self.heartbeat_counter = 0
        self.last_heartbeat = time.time()

        self.capture_locked = False  # lock catching while moving to VIP pokemon

        # Reuse the persisted client id across runs. try/finally closes the
        # shelf even if reading or writing it raises.
        client_id_file_path = os.path.join(_base_dir, 'data', 'mqtt_client_id')
        saved_info = shelve.open(client_id_file_path)
        try:
            key = 'client_id'.encode('utf-8')
            if key in saved_info:
                self.config.client_id = saved_info[key]
            else:
                self.config.client_id = str(uuid.uuid4())
                saved_info[key] = self.config.client_id
        finally:
            saved_info.close()
开发者ID:PokemonGoF,项目名称:PokemonGo-Bot-Backup,代码行数:51,代码来源:__init__.py


注:本文中的metrics.Metrics方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。