

Python utils.DataLoader Examples

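This page collects typical usage examples of utils.DataLoader in Python. utils is not one shared library here. Each example comes from a different project that defines its own DataLoader class in a local utils module, so the constructor arguments and methods differ from one example to the next.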


Three code examples of utils.DataLoader are shown below, sorted by popularity by default.

Example 1: __init__

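# Required import: import utils [as alias]
# or: from utils import DataLoader [as alias]
# Other imports used below (assuming standalone Keras):
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam
from utils import DataLoader


def __init__(self):
    # Input image size: 256 x 256 RGB
    self.nH = 256
    self.nW = 256
    self.nC = 3
    self.data_loader = DataLoader()
    self.image_shape = (self.nH, self.nW, self.nC)
    # image_A: real target image, image_B: conditioning image
    self.image_A = Input(shape=self.image_shape)
    self.image_B = Input(shape=self.image_shape)
    # Build and compile the discriminator on its own
    self.discriminator = self.creat_discriminator()
    self.discriminator.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
    # The generator translates image_B into a fake image_A
    self.generator = self.creat_generator()
    self.fake_A = self.generator(self.image_B)
    # Freeze the discriminator inside the combined model so only the generator is updated there
    self.discriminator.trainable = False
    self.valid = self.discriminator([self.fake_A, self.image_B])
    self.combined = Model(inputs=[self.image_A, self.image_B], outputs=[self.valid, self.fake_A])
    # Adversarial (MSE) loss plus an L1 (MAE) reconstruction loss weighted 100x
    self.combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=Adam(0.0002, 0.5))
    # Calculate output shape of D (PatchGAN): the input is halved four times
    self.disc_patch = (self.nH // 2**4, self.nW // 2**4, 1)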
Developer ID: wmylxmj, Project: Pix2Pix-Keras, Lines of code: 21, Source: model.py
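Example 1 sets the PatchGAN output shape to nH/2**4 by nW/2**4 by 1. It also weights the combined model's two losses as 1 x MSE (adversarial) plus 100 x MAE (L1 reconstruction). The sketch below reproduces both calculations in plain Python. It assumes the discriminator downsamples with four stride-2 convolutions using 'same' padding, which is not confirmed because the discriminator's layers are not shown. The helper names patch_shape and combined_loss are hypothetical, and the losses are computed on flattened values.

```python
import math


def patch_shape(height, width, n_downsamples=4, stride=2):
    """Hypothetical helper: output shape of a PatchGAN discriminator.

    Assumes every downsampling layer is a stride-2 convolution with 'same'
    padding, so each layer maps a size n to ceil(n / stride).
    """
    h, w = height, width
    for _ in range(n_downsamples):
        h = math.ceil(h / stride)
        w = math.ceil(w / stride)
    return (h, w, 1)


def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def mae(a, b):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def combined_loss(d_pred, d_target, fake, real, weights=(1, 100)):
    """Weighted sum mirroring loss=['mse', 'mae'], loss_weights=[1, 100]."""
    return weights[0] * mse(d_pred, d_target) + weights[1] * mae(fake, real)


print(patch_shape(256, 256))  # (16, 16, 1)
print(combined_loss([0.8, 0.6], [1.0, 1.0], [0.5, 0.2], [0.4, 0.2]))
```

For 256 x 256 inputs both approaches give (16, 16, 1). The integer division in __init__ floors the result, so it only matches the ceil-based shape when nH and nW are multiples of 16. For a 250-pixel side, the division gives 15 but four 'same' stride-2 layers give 16. The 100x weight on the MAE term keeps the generator close to the target image, while the MSE term supplies the adversarial signal.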

Example 2: __init__

# Required import: import utils [as alias]
# or: from utils import DataLoader [as alias]
from utils import DataLoader
# wdsr_b and AdamWithWeightsNormalization are defined elsewhere in the
# Anime-Super-Resolution project (their modules are not shown in this excerpt).


def __init__(self, scale=4, num_res_blocks=32, pretrained_weights=None, name=None):
    self.scale = scale
    self.num_res_blocks = num_res_blocks
    # Build the WDSR-B super-resolution network
    self.model = wdsr_b(scale=scale, num_res_blocks=num_res_blocks)
    self.model.compile(optimizer=AdamWithWeightsNormalization(lr=0.001),
                       loss=self.mae, metrics=[self.psnr])
    # Optionally resume from previously saved weights
    if pretrained_weights is not None:
        self.model.load_weights(pretrained_weights)
        print("[OK] weights loaded.")
    self.data_loader = DataLoader(scale=scale, crop_size=256)
    self.pretrained_weights = pretrained_weights
    # e.g. 'weights/wdsr-b-32-x4.h5' for the default arguments
    self.default_weights_save_path = (
        f'weights/wdsr-b-{self.num_res_blocks}-x{self.scale}.h5')
    self.name = name
Developer ID: wmylxmj, Project: Anime-Super-Resolution, Lines of code: 18, Source: train.py
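Example 2 compiles the model with loss=self.mae and metrics=[self.psnr], but the excerpt does not include those two methods. The sketch below computes their standard definitions in plain Python:

- MAE is the mean absolute pixel difference.
- PSNR = 10 * log10(MAX^2 / MSE), in dB.

It is not the project's implementation. The function names and the assumption that pixels lie in [0, 255] (max_val=255) are illustrative choices. The sketch also rebuilds the default weights path string that the constructor stores.

```python
import math


def mae(y_true, y_pred):
    """Mean absolute error over flattened pixel values."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def psnr(y_true, y_pred, max_val=255.0):
    """Peak signal-to-noise ratio in dB; assumes pixel values in [0, max_val]."""
    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)
    if mse == 0:
        return math.inf  # identical images have unbounded PSNR
    return 10.0 * math.log10(max_val ** 2 / mse)


def default_weights_path(num_res_blocks=32, scale=4):
    """Same string the constructor builds for default_weights_save_path."""
    return f"weights/wdsr-b-{num_res_blocks}-x{scale}.h5"


hr = [10, 20, 30, 40]   # "ground-truth" high-resolution pixels
sr = [12, 18, 30, 44]   # "super-resolved" prediction
print(mae(hr, sr))               # 2.0
print(psnr(hr, sr))
print(default_weights_path())    # weights/wdsr-b-32-x4.h5
```

A lower MAE means a closer reconstruction, while a higher PSNR means less distortion. MAE therefore suits minimization as the loss, and PSNR is tracked only as a metric.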

Example 3: train

# Required import: import utils [as alias]
# or: from utils import DataLoader [as alias]
import os
import pickle
import time

import tensorflow as tf  # TensorFlow 1.x graph/session API

from utils import DataLoader
from model import Model  # the project's Vanilla LSTM model (module name assumed)


def train(args):
    # range() has no remove() in Python 3, so build a list first
    datasets = list(range(4))
    # Remove the leaveDataset from datasets
    datasets.remove(args.leaveDataset)

    # Create the data loader object. It preprocesses the data into batches of
    # args.batch_size sequences, each args.seq_length frames long
    data_loader = DataLoader(args.batch_size, args.seq_length, datasets, forcePreProcess=True)

    # Save the arguments in the config file (create the directory if needed)
    os.makedirs('save_lstm', exist_ok=True)
    with open(os.path.join('save_lstm', 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Create a Vanilla LSTM model with the arguments
    model = Model(args)

    # Initialize a TensorFlow session
    with tf.Session() as sess:
        # Initialize all the variables in the graph
        sess.run(tf.global_variables_initializer())
        # Add all the variables to the list of variables to be saved
        saver = tf.train.Saver(tf.global_variables())

        # For each epoch
        for e in range(args.num_epochs):
            # Assign the learning rate (decayed according to the epoch number)
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            # Reset the pointers in the data loader object
            data_loader.reset_batch_pointer()
            # Get the initial cell state of the LSTM
            state = sess.run(model.initial_state)

            # For each batch in this epoch
            for b in range(data_loader.num_batches):
                # Global step across all epochs
                step = e * data_loader.num_batches + b
                # Tic
                start = time.time()
                # Get the source and target data of the current batch
                # x has the source data, y has the target data
                x, y = data_loader.next_batch()

                # Feed the source and target data plus the current LSTM state
                # (carried over from the previous batch) to the model
                feed = {model.input_data: x, model.target_data: y, model.initial_state: state}
                # Run one training step; fetch the batch loss and the final LSTM state
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                # Toc
                end = time.time()
                # Print global step, epoch, loss and time taken
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(
                        step,
                        args.num_epochs * data_loader.num_batches,
                        e,
                        train_loss, end - start))

                # Save the model every args.save_every steps (skipping step 0)
                if step > 0 and step % args.save_every == 0:
                    checkpoint_path = os.path.join('save_lstm', 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
                    print("model saved to {}".format(checkpoint_path))
Developer ID: SZamboni, Project: Social_lstm_pedestrian_prediction, Lines of code: 62, Source: train.py
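Example 3 relies on only a small part of the project's DataLoader: the num_batches attribute, reset_batch_pointer(), and next_batch() returning source and target batches.

The class below is a hypothetical in-memory sketch of that interface, with these assumptions:

- Each dataset is a flat list of frames.
- Each target sequence is its source sequence shifted by one step, a common setup for next-step prediction that the excerpt does not confirm.
- The preprocessing that forcePreProcess=True triggers in the real project is skipped.

A TensorFlow-free loop then mirrors the bookkeeping of train(): the global step, the per-epoch decayed learning rate learning_rate * decay_rate ** e, and the checkpoint condition.

```python
class SequenceDataLoader:
    """Hypothetical stand-in for the project's utils.DataLoader.

    Assumes every dataset is a list of frames; each sample is a window of
    seq_length frames (x) and the same window shifted by one frame (y).
    """

    def __init__(self, batch_size, seq_length, data):
        self.batch_size = batch_size
        self.seq_length = seq_length
        # Cut every dataset into non-overlapping windows of seq_length + 1 frames
        self.samples = []
        window_len = seq_length + 1
        for frames in data:
            for start in range(0, len(frames) - window_len + 1, window_len):
                window = frames[start:start + window_len]
                self.samples.append((window[:-1], window[1:]))
        # Incomplete trailing batches are dropped
        self.num_batches = len(self.samples) // batch_size
        self.reset_batch_pointer()

    def reset_batch_pointer(self):
        """Start reading from the first batch again (called once per epoch)."""
        self.pointer = 0

    def next_batch(self):
        """Return (x, y): lists of batch_size source and target sequences."""
        batch = self.samples[self.pointer:self.pointer + self.batch_size]
        self.pointer += self.batch_size
        x = [src for src, _ in batch]
        y = [tgt for _, tgt in batch]
        return x, y


def run_schedule(loader, num_epochs, learning_rate, decay_rate, save_every):
    """Mirror the epoch/batch bookkeeping of train() without TensorFlow."""
    lrs, saved_steps = [], []
    for e in range(num_epochs):
        lrs.append(learning_rate * decay_rate ** e)  # per-epoch decay
        loader.reset_batch_pointer()
        for b in range(loader.num_batches):
            x, y = loader.next_batch()  # would be fed to the model in train()
            step = e * loader.num_batches + b  # global step across epochs
            if step > 0 and step % save_every == 0:
                saved_steps.append(step)
    return lrs, saved_steps


# Four toy datasets of 20 frames each; dataset 2 is held out (leaveDataset = 2)
all_data = [list(range(100 * d, 100 * d + 20)) for d in range(4)]
datasets = list(range(4))
datasets.remove(2)
loader = SequenceDataLoader(batch_size=2, seq_length=4,
                            data=[all_data[d] for d in datasets])
x, y = loader.next_batch()
print(x[0], y[0])  # [0, 1, 2, 3] [1, 2, 3, 4]
print(run_schedule(loader, num_epochs=2, learning_rate=0.003,
                   decay_rate=0.95, save_every=4))
```

Keeping a batch pointer in the loader is what makes reset_batch_pointer() necessary at the start of every epoch. Without it, next_batch() would run past the end of the samples on the second epoch.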


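Note: The utils.DataLoader examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by various developers, and copyright of the source code belongs to the original authors. Refer to the license of the corresponding project before distributing or using the code. Do not reproduce this article without permission.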