

Python config.batch_size Method Code Examples

This article collects typical usage examples of the config.config.batch_size method in Python. If you have been wondering how config.batch_size is used in practice, what it does, or what real-world calls look like, the curated examples below may help. You can also explore further usage examples from the containing module, config.config.


The following presents 15 code examples of the config.batch_size method, sorted by popularity by default.
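All of the examples read hyper-parameters from a shared config module and consume its batch_size attribute. For orientation, here is a minimal sketch of what such a module might look like; the attribute names mirror those used in the examples, while the concrete values are illustrative assumptions rather than settings from any of the listed projects.

# config.py -- minimal illustrative sketch (values are assumptions, not project defaults)
class Config:
    batch_size = 32        # samples per training step; split across GPUs when distributed
    num_workers = 4        # DataLoader worker processes
    learning_rate = 1e-3   # optimizer step size
    num_epochs = 10        # full passes over the training data

config = Config()          # imported elsewhere as `from config import config`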

Example 1: main

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def main(_):
    if config.mode == "train":
        train_entry(config)
    elif config.mode == "data":
        preproc(config)
    elif config.mode == "debug":
        config.batch_size = 2
        config.num_steps = 32
        config.val_num_batches = 2
        config.checkpoint = 2
        config.period = 1
        train_entry(config)
    elif config.mode == "test":
        test_entry(config)
    else:
        print("Unknown mode")
        exit(0) 
Author: andy840314 | Project: QANet-pytorch- | Lines: 19 | Source: main.py
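The main(_) signature, with its ignored positional argument, is the convention used by flag-parsing runners such as tf.app.run() or absl.app.run(). A minimal sketch of how this entry point is typically wired at the bottom of the file, assuming absl-style flags (the wiring is an assumption, not copied from the project):

# Hypothetical runner (an assumption; the project may dispatch differently).
if __name__ == "__main__":
    from absl import app
    app.run(main)   # parses flags, then calls main() with the leftover argv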

Example 2: main

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def main():
    data_transformer = DataTransformer(config.dataset_path, use_cuda=config.use_cuda)

    # define our models
    vanilla_encoder = VanillaEncoder(vocab_size=data_transformer.vocab_size,
                                     embedding_size=config.encoder_embedding_size,
                                     output_size=config.encoder_output_size)

    vanilla_decoder = VanillaDecoder(hidden_size=config.decoder_hidden_size,
                                     output_size=data_transformer.vocab_size,
                                     max_length=data_transformer.max_length,
                                     teacher_forcing_ratio=config.teacher_forcing_ratio,
                                     sos_id=data_transformer.SOS_ID,
                                     use_cuda=config.use_cuda)
    if config.use_cuda:
        vanilla_encoder = vanilla_encoder.cuda()
        vanilla_decoder = vanilla_decoder.cuda()


    seq2seq = Seq2Seq(encoder=vanilla_encoder,
                      decoder=vanilla_decoder)

    trainer = Trainer(seq2seq, data_transformer, config.learning_rate, config.use_cuda)
    trainer.train(num_epochs=config.num_epochs, batch_size=config.batch_size, pretrained=False) 
Author: zake7749 | Project: Sequence-to-Sequence-101 | Lines: 26 | Source: train.py

Example 3: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std,
                                config.target_size)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.niters_per_epoch * config.batch_size)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Author: StevenGrove | Project: TreeFilter-Torch | Lines: 32 | Source: dataloader.py
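Note the distributed branch: the global config.batch_size is divided evenly across processes, so each GPU's DataLoader yields config.batch_size // engine.world_size samples per step, and shuffling is delegated to the DistributedSampler. The arithmetic, with illustrative values:

# Illustrative per-GPU batch arithmetic (values are assumptions)
global_batch = 16                 # config.batch_size
world_size = 4                    # engine.world_size (number of processes/GPUs)
per_gpu_batch = global_batch // world_size
assert per_gpu_batch == 4         # each process loads 4 samples per step
# Caveat: integer division silently drops capacity when global_batch % world_size != 0.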

Example 4: __init__

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def __init__(self, sess):
        self.sess = sess
        self.step_size = FLAGS.step_size / 255.0
        self.max_epsilon = FLAGS.max_epsilon / 255.0
        # Prepare graph
        batch_shape = [FLAGS.batch_size, 299, 299, 3]
        self.x_input = tf.placeholder(tf.float32, shape=batch_shape)
        x_max = tf.clip_by_value(self.x_input + self.max_epsilon, 0., 1.0)
        x_min = tf.clip_by_value(self.x_input - self.max_epsilon, 0., 1.0)

        self.y_input = tf.placeholder(tf.int64, shape=batch_shape[0])
        i = tf.constant(0)
        self.x_adv, _, _, _, _ = tf.while_loop(self.stop, self.graph,
                                               [self.x_input, self.y_input, i, x_max, x_min])
        self.restore() 
Author: LiYingwei | Project: Regional-Homogeneity | Lines: 17 | Source: attack.py

Example 5: perturb

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def perturb(self, images, labels):
        batch_size = images.shape[0]
        if batch_size < FLAGS.batch_size:
            pad_num = FLAGS.batch_size - batch_size
            pad_img = np.zeros([pad_num, 299, 299, 3])
            images = np.concatenate([images, pad_img])
            pad_label = np.zeros([pad_num])
            labels = np.concatenate([labels, pad_label])
        # run the attack graph through the session stored in __init__ (see Example 4);
        # the excerpt referenced a bare `sess`, which is undefined in this scope
        adv_images = self.sess.run(self.x_adv, feed_dict={self.x_input: images, self.y_input: labels})
        return adv_images[:batch_size] 
Author: LiYingwei | Project: Regional-Homogeneity | Lines: 12 | Source: attack.py
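The pad-then-trim pattern above lets a TensorFlow graph built for a fixed batch shape handle a final, smaller batch: the inputs are zero-padded up to FLAGS.batch_size and the padding rows are sliced off the result. A standalone sketch of the same idea, with illustrative values:

import numpy as np

graph_batch = 8                           # stand-in for FLAGS.batch_size
images = np.random.rand(5, 299, 299, 3)   # a short final batch
pad_num = graph_batch - images.shape[0]
padded = np.concatenate([images, np.zeros([pad_num, 299, 299, 3])])
assert padded.shape[0] == graph_batch     # the graph always sees a full batch
outputs = padded[:images.shape[0]]        # trim the padded rows afterwards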

Example 6: build_in_eval

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def build_in_eval():
    with tf.Session() as sess:
        model = Evaluator(sess)
        df = PNGDataFlow(FLAGS.result_dir, FLAGS.test_list_filename, FLAGS.ground_truth_file,
                         img_num=FLAGS.img_num)
        df = BatchData(df, FLAGS.batch_size, remainder=True)
        df.reset_state()

        avgMetric = AvgMetric(datashape=[len(FLAGS.test_networks)])
        total_batch = df.ds.img_num / FLAGS.batch_size
        for batch_index, (x_batch, y_batch, name_batch) in tqdm(enumerate(df), total=total_batch):
            acc, pred = model.eval(x_batch, y_batch)
            avgMetric.update(acc)

    return 1 - avgMetric.get_status() 
Author: LiYingwei | Project: Regional-Homogeneity | Lines: 17 | Source: eval.py

Example 7: val

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def val(args):
    list_thresholds = [0.5]
    model = getattr(models, config.model_name)()
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=4)
    for threshold in list_thresholds:
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader, threshold)
        print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' % (threshold, val_loss, val_f1))

# used for generating the submission results
Author: JavisPeng | Project: ecg_pytorch | Lines: 15 | Source: main.py

Example 8: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'train_root': config.train_root_folder,
                    'val_root': config.eval_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.batch_size * config.niters_per_epoch)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Author: JaminFong | Project: FNA | Lines: 34 | Source: dataloader.py

Example 9: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std,
                                config.target_size)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.niters_per_epoch * config.batch_size)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=False,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Author: ycszen | Project: TorchSeg | Lines: 32 | Source: dataloader.py

Example 10: get_train_loader

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def get_train_loader(engine, dataset):
    data_setting = {'img_root': config.img_root_folder,
                    'gt_root': config.gt_root_folder,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    train_preprocess = TrainPre(config.image_mean, config.image_std)

    train_dataset = dataset(data_setting, "train", train_preprocess,
                            config.batch_size * config.niters_per_epoch)

    train_sampler = None
    is_shuffle = True
    batch_size = config.batch_size

    if engine.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        batch_size = config.batch_size // engine.world_size
        is_shuffle = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=config.num_workers,
                                   drop_last=True,
                                   shuffle=is_shuffle,
                                   pin_memory=True,
                                   sampler=train_sampler)

    return train_loader, train_sampler 
Author: ycszen | Project: TorchSeg | Lines: 31 | Source: dataloader.py

Example 11: __init__

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def __init__(self, npz_file, batch_size):
        data = np.load(npz_file)
        self.context_idxs = data["context_idxs"]
        self.context_char_idxs = data["context_char_idxs"]
        self.ques_idxs = data["ques_idxs"]
        self.ques_char_idxs = data["ques_char_idxs"]
        self.y1s = data["y1s"]
        self.y2s = data["y2s"]
        self.ids = data["ids"]
        self.num = len(self.ids) 
Author: andy840314 | Project: QANet-pytorch- | Lines: 12 | Source: main.py
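As shown, the constructor only caches the preprocessed arrays; the batch_size parameter is accepted but unused in this excerpt. For the object to work with torch.utils.data.DataLoader (see the next example), the class also needs __len__ and __getitem__, which the excerpt omits. A plausible sketch inferred from the cached fields (an assumption, not copied from the project):

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        # one (context, context chars, question, question chars, span, id) record
        return (self.context_idxs[idx], self.context_char_idxs[idx],
                self.ques_idxs[idx], self.ques_char_idxs[idx],
                self.y1s[idx], self.y2s[idx], self.ids[idx])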

Example 12: get_loader

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def get_loader(npz_file, batch_size):
    dataset = SQuADDataset(npz_file, batch_size)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=5,
                                              collate_fn=collate)
    return data_loader 
Author: andy840314 | Project: QANet-pytorch- | Lines: 10 | Source: main.py
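get_loader passes a collate_fn named collate that is not shown in this article. A minimal stand-in that batches the tuples returned by the __getitem__ sketch above might look like this; the project's real collate may pad, truncate, or reorder fields differently, so treat this as an assumption:

import torch

def collate(batch):
    # `batch` is a list of per-example tuples; stack each field along dim 0
    fields = zip(*batch)
    return [torch.stack([torch.as_tensor(x) for x in field]) for field in fields]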

Example 13: test_entry

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def test_entry(config):
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)
    fn = os.path.join(config.save_dir, "model.pt")
    model = torch.load(fn)
    test(model, dev_dataset, dev_eval_file, 0) 
Author: andy840314 | Project: QANet-pytorch- | Lines: 9 | Source: main.py
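test_entry reloads the entire pickled model object with torch.load. If the checkpoint was saved on a GPU but evaluation runs on a CPU-only machine, torch.load needs a map_location argument; a hedged variant (the CPU fallback is an assumption about the deployment environment):

model = torch.load(fn, map_location=torch.device("cpu"))  # remap GPU tensors to CPU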

Example 14: train

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def train(self, num_epochs, batch_size, pretrained=False):

        if pretrained:
            self.load_model()

        step = 0

        for epoch in range(0, num_epochs):
            mini_batches = self.data_transformer.mini_batches(batch_size=batch_size)
            for input_batch, target_batch in mini_batches:
                self.optimizer.zero_grad()
                decoder_outputs, decoder_hidden = self.model(input_batch, target_batch)

                # calculate the loss and back prop.
                cur_loss = self.get_loss(decoder_outputs, target_batch[0])

                # logging
                step += 1
                if step % 50 == 0:
                    print("Step:", step, "char-loss: ", cur_loss.data.numpy())
                    self.save_model()
                cur_loss.backward()

                # optimize
                self.optimizer.step()

        self.save_model() 
Author: zake7749 | Project: Sequence-to-Sequence-101 | Lines: 29 | Source: train.py

Example 15: train_entry

# Required import: from config import config [as alias]
# Or: from config.config import batch_size [as alias]
def train_entry(config):
    from models import QANet

    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")

    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    lr = config.learning_rate
    base_lr = 1
    lr_warm_up_num = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)

    ema = EMA(config.decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr, betas=(0.9, 0.999), eps=1e-7, weight_decay=5e-8, params=parameters)
    cr = lr / math.log2(lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1) if ee < lr_warm_up_num else lr)
    best_f1 = 0
    best_em = 0
    patience = 0
    for epoch in range(config.num_epoch):
        train(model, optimizer, scheduler, train_dataset, dev_dataset, dev_eval_file, epoch, ema)
        ema.assign(model)
        metrics = test(model, dev_dataset, dev_eval_file, (epoch + 1) * len(train_dataset))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        if dev_f1 < best_f1 and dev_em < best_em:
            patience += 1
            if patience > config.early_stop:
                break
        else:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)

        fn = os.path.join(config.save_dir, "model.pt")
        torch.save(model, fn)
        ema.resume(model) 
Author: andy840314 | Project: QANet-pytorch- | Lines: 56 | Source: main.py
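The LambdaLR above implements a logarithmic warm-up: because base_lr is 1, the lambda's return value is the effective learning rate, climbing as cr * log2(step + 1) for the first lr_warm_up_num steps and then holding constant at lr. A quick standalone check of the schedule, with assumed config values:

import math

lr, lr_warm_up_num = 0.001, 1000        # assumed config.learning_rate / config.lr_warm_up_num
cr = lr / math.log2(lr_warm_up_num)
for step in (0, 99, 999, 5000):
    mult = cr * math.log2(step + 1) if step < lr_warm_up_num else lr
    print(step, round(mult, 6))         # 0.0 at step 0, ~lr by the end of warm-up, then flat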


Note: The config.config.batch_size examples in this article were compiled by 纯净天空 from open-source code and documentation hosted on platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by the community; copyright remains with the original authors, and any redistribution or use should follow the license of the corresponding project. Please do not republish without permission.