当前位置: 首页>>代码示例>>Python>>正文


Python config.batch_size方法代码示例

本文整理汇总了Python中config.batch_size方法的典型用法代码示例。如果您正苦于以下问题:Python config.batch_size方法的具体用法?Python config.batch_size怎么用?Python config.batch_size使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在config的用法示例。


在下文中一共展示了config.batch_size方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_loader

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def get_loader(train=False, val=False, test=False, trainval=False):
    """Build the data loader for the split selected by the boolean flags.

    The four flags are forwarded unchanged to ``utils.path_for`` to resolve
    the question and answer files. The test split reads
    ``config.preprocessed_test_path``; every other split reads
    ``config.preprocessed_trainval_path``. Shuffling and the
    ``answerable_only`` filter are enabled only for ``train``/``trainval``,
    and ``dummy_answers`` is enabled only for ``test``.
    """
    split_flags = dict(train=train, val=val, test=test, trainval=trainval)
    is_training = train or trainval

    questions_path = utils.path_for(**split_flags, question=True)
    answers_path = utils.path_for(**split_flags, answer=True)
    if test:
        features_path = config.preprocessed_test_path
    else:
        features_path = config.preprocessed_trainval_path

    dataset = VQA(
        questions_path,
        answers_path,
        features_path,
        answerable_only=is_training,
        dummy_answers=test,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=is_training,  # evaluation splits keep their original order
        pin_memory=True,
        num_workers=config.data_workers,
        collate_fn=collate_fn,
    )
开发者ID:KaihuaTang,项目名称:VQA2.0-Recent-Approachs-2018.pytorch,代码行数:20,代码来源:data.py

示例2: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def __init__(self):
        """Set up the vocabulary, training batcher, output directories and summary writer.

        Ensures ``config.log_root`` and its ``train_model`` and ``eval_log``
        sub-directories exist, then opens a TensorFlow ``FileWriter`` on the
        ``eval_log`` directory.
        """
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        # NOTE(review): presumably gives the batcher time to start producing
        # batches before they are consumed -- confirm against Batcher.
        time.sleep(5)

        self.model_dir = os.path.join(config.log_root, 'train_model')
        self.eval_log = os.path.join(config.log_root, 'eval_log')
        # makedirs(exist_ok=True) replaces the exists()/mkdir() pairs: it has no
        # check-then-create race and also creates missing parent directories.
        for directory in (config.log_root, self.model_dir, self.eval_log):
            os.makedirs(directory, exist_ok=True)

        self.summary_writer = tf.compat.v1.summary.FileWriter(self.eval_log)
开发者ID:wyu-du,项目名称:Reinforce-Paraphrase-Generation,代码行数:19,代码来源:train.py

示例3: adv_train_generator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def adv_train_generator(self, g_step):
        """Run ``g_step`` adversarial generator updates, then log the outcome.

        Each update samples from the generator, obtains rewards from
        ``get_mali_reward`` and optimizes the generator's ``adv_loss``.
        The logged ``g_loss`` is the sum of the per-step losses, followed by
        the formatted metrics from ``cal_metrics``.
        """
        loss_sum = 0
        for _ in range(g_step):
            samples = self.gen.sample(cfg.batch_size, cfg.batch_size)
            inp, target = GenDataIter.prepare(samples, gpu=cfg.CUDA)

            # One optimization step on the reward-weighted loss.
            reward = self.get_mali_reward(target)
            step_loss = self.gen.adv_loss(inp, target, reward)
            self.optimize(self.gen_adv_opt, step_loss)
            loss_sum += step_loss.item()

        # Report the accumulated loss together with the evaluation metrics.
        message = '[ADV-GEN]: g_loss = %.4f, %s' % (loss_sum, self.cal_metrics(fmt_str=True))
        self.log.info(message)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:18,代码来源:maligan_instructor.py

示例4: adv_train_generator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def adv_train_generator(self, g_step):
        """Train the generator against the discriminator for ``g_step`` steps.

        Each step one-hot encodes a random oracle batch, draws one-hot
        generator samples, scores both with the discriminator and optimizes
        the generator loss returned by ``get_losses``.

        :param g_step: number of generator updates to run
        :return: mean generator loss over the steps, or 0 when ``g_step`` is 0
        """
        loss_sum = 0
        for _ in range(g_step):
            oracle_batch = self.oracle_data.random_batch()['target']
            real_samples = F.one_hot(oracle_batch, cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples = real_samples.cuda()
                gen_samples = gen_samples.cuda()

            # Score both batches; only the generator loss is used here.
            real_scores = self.dis(real_samples)
            fake_scores = self.dis(gen_samples)
            g_loss, _unused_d_loss = get_losses(real_scores, fake_scores, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            loss_sum += g_loss.item()

        if g_step == 0:
            return 0
        return loss_sum / g_step
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:19,代码来源:relgan_instructor.py

示例5: adv_train_discriminator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def adv_train_discriminator(self, d_step):
        """Train the discriminator for ``d_step`` steps on real vs. generated samples.

        Each step one-hot encodes a random oracle batch, draws one-hot
        generator samples, scores both with the discriminator and optimizes
        the discriminator loss returned by ``get_losses``.

        :param d_step: number of discriminator updates to run
        :return: mean discriminator loss over the steps, or 0 when ``d_step`` is 0
        """
        loss_sum = 0
        for _ in range(d_step):
            oracle_batch = self.oracle_data.random_batch()['target']
            real_samples = F.one_hot(oracle_batch, cfg.vocab_size).float()
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_samples = real_samples.cuda()
                gen_samples = gen_samples.cuda()

            # Score both batches; only the discriminator loss is used here.
            real_scores = self.dis(real_samples)
            fake_scores = self.dis(gen_samples)
            _unused_g_loss, d_loss = get_losses(real_scores, fake_scores, cfg.loss_type)

            self.optimize(self.dis_opt, d_loss, self.dis)
            loss_sum += d_loss.item()

        if d_step == 0:
            return 0
        return loss_sum / d_step
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:19,代码来源:relgan_instructor.py

示例6: cal_metrics

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def cal_metrics(self, fmt_str=False):
        """Refresh the NLL metrics on fresh generator samples and return the scores.

        :param fmt_str: when true, return a single ``'name = score, ...'``
            string suitable for logging; otherwise return the list of scores
        :return: formatted string or list of scores, one entry per metric in
            ``self.all_metrics``
        """
        with torch.no_grad():
            # Fresh generator samples used as evaluation data.
            generated = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            gen_data = GenDataIter(generated)

            # Point each metric at its model/data pair.
            self.nll_oracle.reset(self.oracle, gen_data.loader)
            self.nll_gen.reset(self.gen, self.oracle_data.loader)
            self.nll_div.reset(self.gen, gen_data.loader)

        if not fmt_str:
            return [metric.get_score() for metric in self.all_metrics]

        parts = []
        for metric in self.all_metrics:
            parts.append('%s = %s' % (metric.get_name(), metric.get_score()))
        return ', '.join(parts)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:20,代码来源:instructor.py

示例7: train_discriminator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def train_discriminator(self, d_step, d_epoch, phase='MLE'):
        """
        Train the discriminator on real training samples (positive) and
        samples generated by gen (negative).

        Negative samples are redrawn ``d_step`` times; for each draw the
        discriminator is trained for ``d_epoch`` epochs, the last epoch's loss
        and accuracy are logged, and the discriminator weights are saved when
        ``cfg.if_save`` is set and ``cfg.if_test`` is not.

        :param d_step: number of times negative samples are redrawn
        :param d_epoch: number of training epochs per draw
        :param phase: label used as prefix in the log line
        """
        # Plain locals with defaults instead of ``global``: nothing leaks into
        # module state, and the log line below stays valid when d_epoch is 0.
        d_loss, train_acc = 0, 0

        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.train_data.target  # not re-sample the Oracle data
            neg_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            dis_data = DisDataIter(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # ===Train===
                d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion,
                                                         self.dis_opt)

            # ===Test===
            self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % (
                phase, step, d_loss, train_acc))

            if cfg.if_save and not cfg.if_test:
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:27,代码来源:maligan_instructor.py

示例8: adv_train_generator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def adv_train_generator(self, g_step):
        """Train the generator against the discriminator for ``g_step`` steps.

        Each step takes a random batch from the training data, draws one-hot
        generator samples, moves both to the GPU when ``cfg.CUDA`` is set,
        one-hot encodes the real batch, and optimizes the generator loss
        returned by ``get_losses``.

        :param g_step: number of generator updates to run
        :return: mean generator loss over the steps, or 0 when ``g_step`` is 0
        """
        loss_sum = 0
        for _ in range(g_step):
            real_batch = self.train_data.random_batch()['target']
            gen_samples = self.gen.sample(cfg.batch_size, cfg.batch_size, one_hot=True)
            if cfg.CUDA:
                real_batch = real_batch.cuda()
                gen_samples = gen_samples.cuda()
            # One-hot encoding happens after the optional device transfer.
            real_samples = F.one_hot(real_batch, cfg.vocab_size).float()

            # Score both batches; only the generator loss is used here.
            real_scores = self.dis(real_samples)
            fake_scores = self.dis(gen_samples)
            g_loss, _unused_d_loss = get_losses(real_scores, fake_scores, cfg.loss_type)

            self.optimize(self.gen_adv_opt, g_loss, self.gen)
            loss_sum += g_loss.item()

        if g_step == 0:
            return 0
        return loss_sum / g_step
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:20,代码来源:relgan_instructor.py

示例9: adv_train_generator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def adv_train_generator(self, g_step):
        """Run ``g_step`` reward-driven generator updates, then log the outcome.

        Each update samples from the generator, obtains rewards from a
        ``rollout.ROLLOUT`` helper evaluated with the discriminator, and
        optimizes the generator's ``batchPGLoss``. The logged ``g_loss`` is
        the sum of the per-step losses, followed by the formatted metrics
        from ``cal_metrics``.
        """
        roller = rollout.ROLLOUT(self.gen, cfg.CUDA)
        loss_sum = 0
        for _ in range(g_step):
            samples = self.gen.sample(cfg.batch_size, cfg.batch_size)
            inp, target = GenDataIter.prepare(samples, gpu=cfg.CUDA)

            # One optimization step on the reward-weighted loss.
            reward = roller.get_reward(target, cfg.rollout_num, self.dis)
            step_loss = self.gen.batchPGLoss(inp, target, reward)
            self.optimize(self.gen_adv_opt, step_loss)
            loss_sum += step_loss.item()

        # Report the accumulated loss together with the evaluation metrics.
        message = '[ADV-GEN]: g_loss = %.4f, %s' % (loss_sum, self.cal_metrics(fmt_str=True))
        self.log.info(message)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:20,代码来源:seqgan_instructor.py

示例10: train_discriminator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def train_discriminator(self, d_step, d_epoch, phase='MLE'):
        """
        Train the discriminator on real training samples (positive) and
        samples generated by gen (negative).

        Negative samples are redrawn ``d_step`` times; for each draw the
        discriminator is trained for ``d_epoch`` epochs, the last epoch's loss
        and accuracy are logged, and the discriminator weights are saved when
        ``cfg.if_save`` is set and ``cfg.if_test`` is not.

        :param d_step: number of times negative samples are redrawn
        :param d_epoch: number of training epochs per draw
        :param phase: label used as prefix in the log line
        """
        # Plain locals with defaults instead of ``global``: nothing leaks into
        # module state, and the log line below stays valid when d_epoch is 0.
        d_loss, train_acc = 0, 0

        for step in range(d_step):
            # prepare loader for training
            pos_samples = self.train_data.target
            neg_samples = self.gen.sample(cfg.samples_num, 4 * cfg.batch_size)
            dis_data = DisDataIter(pos_samples, neg_samples)

            for epoch in range(d_epoch):
                # ===Train===
                d_loss, train_acc = self.train_dis_epoch(self.dis, dis_data.loader, self.dis_criterion,
                                                         self.dis_opt)

            # ===Test===
            self.log.info('[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % (
                phase, step, d_loss, train_acc))

            if cfg.if_save and not cfg.if_test:
                torch.save(self.dis.state_dict(), cfg.pretrained_dis_path)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:26,代码来源:seqgan_instructor.py

示例11: train_discriminator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def train_discriminator(self, d_step, d_epoch, phase='MLE'):
        """
        Train the discriminator on real training samples (positive) and
        samples generated by gen (negative).

        Negative samples are redrawn ``d_step`` times (the discriminator is
        passed to the generator's ``sample`` call); for each draw the
        discriminator is trained for ``d_epoch`` epochs and the last epoch's
        loss and accuracy are logged under the ``phase`` prefix.
        """
        last_loss = 0
        last_acc = 0
        for step_idx in range(d_step):
            # Build the positive/negative dataset for this draw.
            positives = self.train_data.target
            negatives = self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis)
            batches = DisDataIter(positives, negatives)

            for _ in range(d_epoch):
                last_loss, last_acc = self.train_dis_epoch(
                    self.dis, batches.loader, self.dis_criterion, self.dis_opt)

            # Log the result of the final epoch for this draw.
            message = '[%s-DIS] d_step %d: d_loss = %.4f, train_acc = %.4f,' % (
                phase, step_idx, last_loss, last_acc)
            self.log.info(message)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:22,代码来源:leakgan_instructor.py

示例12: cal_metrics

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def cal_metrics(self, fmt_str=False):
        """Refresh all metrics on fresh generator samples and return the scores.

        Two sample sets are drawn with the discriminator passed to
        ``gen.sample``: ``cfg.samples_num`` samples for BLEU, NLL-div and PPL,
        and 200 samples used as test text for self-BLEU.

        :param fmt_str: when true, return a single ``'name = score, ...'``
            string suitable for logging; otherwise return the list of scores
        :return: formatted string or list of scores, one entry per metric in
            ``self.all_metrics``
        """
        with torch.no_grad():
            # Main evaluation samples, as tensors and as tokens.
            eval_samples = self.gen.sample(cfg.samples_num, cfg.batch_size, self.dis)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)

            # Separate, smaller draw used as the self-BLEU test text.
            self_bleu_samples = self.gen.sample(200, cfg.batch_size, self.dis)
            gen_tokens_s = tensor_to_tokens(self_bleu_samples, self.idx2word_dict)

            # Point each metric at its inputs.
            self.bleu.reset(test_text=gen_tokens, real_text=self.test_data.tokens)
            self.nll_gen.reset(self.gen, self.train_data.loader, leak_dis=self.dis)
            self.nll_div.reset(self.gen, gen_data.loader, leak_dis=self.dis)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            self.ppl.reset(gen_tokens)

        if not fmt_str:
            return [metric.get_score() for metric in self.all_metrics]

        parts = []
        for metric in self.all_metrics:
            parts.append('%s = %s' % (metric.get_name(), metric.get_score()))
        return ', '.join(parts)
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:21,代码来源:leakgan_instructor.py

示例13: adv_train_generator

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def adv_train_generator(self, g_step):
        """
        Run ``g_step`` reward-driven updates on each of the ``cfg.k_label`` generators.

        For every generator in ``self.gen_list`` a ``rollout.ROLLOUT`` helper
        is built; each update samples from that generator, obtains rewards
        from the discriminator for label index ``current_k`` and optimizes the
        generator's ``batchPGLoss`` with its own optimizer. The combined
        metrics are logged once after all generators have been updated.

        :param g_step: number of updates per generator
        """
        for i in range(cfg.k_label):
            generator = self.gen_list[i]
            rollout_func = rollout.ROLLOUT(generator, cfg.CUDA)
            for _ in range(g_step):
                inp, target = GenDataIter.prepare(generator.sample(cfg.batch_size, cfg.batch_size), gpu=cfg.CUDA)

                # ===Train===
                rewards = rollout_func.get_reward(target, cfg.rollout_num, self.dis, current_k=i)
                adv_loss = generator.batchPGLoss(inp, target, rewards)
                # The loss is not accumulated: the log line below reports
                # metrics only, so a running total would never be read.
                self.optimize(self.gen_opt_list[i], adv_loss)

        # ===Test===
        self.log.info('[ADV-GEN]: %s', self.comb_metrics(fmt_str=True))
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:21,代码来源:sentigan_instructor.py

示例14: cal_metrics_with_label

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def cal_metrics_with_label(self, label_i):
        """Refresh all metrics for the generator of one label and return the scores.

        :param label_i: label index; must be exactly an ``int``, and selects
            the generator and the train/test data for that label
        :return: list of scores, one entry per metric in ``self.all_metrics``
        """
        assert type(label_i) is int, 'missing label'

        generator = self.gen_list[label_i]
        with torch.no_grad():
            # Main evaluation samples, as tensors and as tokens.
            eval_samples = generator.sample(cfg.samples_num, 8 * cfg.batch_size)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)

            # Separate, smaller draw used as the self-BLEU test text.
            self_bleu_samples = generator.sample(200, 200)
            gen_tokens_s = tensor_to_tokens(self_bleu_samples, self.idx2word_dict)
            clas_data = CatClasDataIter([eval_samples], label_i)

            # Point each metric at its inputs.
            self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens)
            self.nll_gen.reset(generator, self.train_data_list[label_i].loader)
            self.nll_div.reset(generator, gen_data.loader)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            self.clas_acc.reset(self.clas, clas_data.loader)
            self.ppl.reset(gen_tokens)

        scores = []
        for metric in self.all_metrics:
            scores.append(metric.get_score())
        return scores
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:22,代码来源:sentigan_instructor.py

示例15: cal_metrics_with_label

# 需要导入模块: import config [as 别名]
# 或者: from config import batch_size [as 别名]
def cal_metrics_with_label(self, label_i):
        """Refresh all metrics for one label of the shared generator and return the scores.

        :param label_i: label index; must be exactly an ``int``. It is passed
            to the generator's ``sample`` call and to the NLL metrics, and
            selects the train/test data for that label
        :return: list of scores, one entry per metric in ``self.all_metrics``
        """
        assert type(label_i) is int, 'missing label'

        with torch.no_grad():
            # Main evaluation samples, as tensors and as tokens.
            eval_samples = self.gen.sample(cfg.samples_num, 8 * cfg.batch_size, label_i=label_i)
            gen_data = GenDataIter(eval_samples)
            gen_tokens = tensor_to_tokens(eval_samples, self.idx2word_dict)

            # Separate, smaller draw used as the self-BLEU test text.
            self_bleu_samples = self.gen.sample(200, 200, label_i=label_i)
            gen_tokens_s = tensor_to_tokens(self_bleu_samples, self.idx2word_dict)
            clas_data = CatClasDataIter([eval_samples], label_i)

            # Point each metric at its inputs.
            self.bleu.reset(test_text=gen_tokens, real_text=self.test_data_list[label_i].tokens)
            self.nll_gen.reset(self.gen, self.train_data_list[label_i].loader, label_i)
            self.nll_div.reset(self.gen, gen_data.loader, label_i)
            self.self_bleu.reset(test_text=gen_tokens_s, real_text=gen_tokens)
            self.clas_acc.reset(self.clas, clas_data.loader)
            self.ppl.reset(gen_tokens)

        scores = []
        for metric in self.all_metrics:
            scores.append(metric.get_score())
        return scores
开发者ID:williamSYSU,项目名称:TextGAN-PyTorch,代码行数:22,代码来源:instructor.py


注:本文中的config.batch_size方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。