本文整理汇总了Python中utils.DataLoader方法的典型用法代码示例。如果您正苦于以下问题：Python utils.DataLoader方法的具体用法？Python utils.DataLoader怎么用？Python utils.DataLoader使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils的用法示例。
在下文中一共展示了utils.DataLoader方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: import utils [as 别名]
# 或者: from utils import DataLoader [as 别名]
def __init__(self):
    """Set up data loading and build/compile the GAN models.

    Builds two image inputs of shape (256, 256, 3). The generator maps
    ``image_B`` to ``fake_A``, and the discriminator scores the pair
    ``[fake_A, image_B]``. A ``combined`` model, whose outputs are
    ``[valid, fake_A]``, is used to train the generator. Also records
    ``disc_patch``, the discriminator's output shape.
    """
    # Image geometry: height, width, channels.
    self.nH, self.nW, self.nC = 256, 256, 3
    self.data_loader = DataLoader()
    self.image_shape = (self.nH, self.nW, self.nC)

    # Two image inputs; only image_B is fed to the generator below.
    self.image_A = Input(shape=self.image_shape)
    self.image_B = Input(shape=self.image_shape)

    # Stand-alone discriminator, compiled while it is still trainable.
    self.discriminator = self.creat_discriminator()
    self.discriminator.compile(
        loss='mse',
        optimizer=Adam(0.0002, 0.5),
        metrics=['accuracy'],
    )

    # Generator output conditioned on image_B.
    self.generator = self.creat_generator()
    self.fake_A = self.generator(self.image_B)

    # Freeze the discriminator only after its own compile() above, so the
    # combined model built next treats its weights as non-trainable
    # (Keras applies `trainable` when a model is compiled).
    self.discriminator.trainable = False
    self.valid = self.discriminator([self.fake_A, self.image_B])

    self.combined = Model(
        inputs=[self.image_A, self.image_B],
        outputs=[self.valid, self.fake_A],
    )
    # 'mse' on the discriminator's verdict; 'mae' on fake_A, weighted 100x.
    self.combined.compile(
        loss=['mse', 'mae'],
        loss_weights=[1, 100],
        optimizer=Adam(0.0002, 0.5),
    )

    # Output shape of D (PatchGAN): each spatial dim divided by 2**4.
    patch_divisor = 2 ** 4
    self.disc_patch = (int(self.nH / patch_divisor),
                       int(self.nW / patch_divisor),
                       1)
示例2: __init__
# 需要导入模块: import utils [as 别名]
# 或者: from utils import DataLoader [as 别名]
def __init__(self, scale=4, num_res_blocks=32, pretrained_weights=None, name=None):
    """Build and compile a WDSR-B model, optionally loading saved weights.

    Args:
        scale: Upscaling factor; passed to ``wdsr_b`` and ``DataLoader``.
        num_res_blocks: Number of residual blocks; passed to ``wdsr_b``.
        pretrained_weights: Optional weights file passed to
            ``model.load_weights``. ``None`` skips loading.
        name: Optional name stored on the instance.
    """
    self.scale = scale
    self.num_res_blocks = num_res_blocks
    self.model = wdsr_b(scale=scale, num_res_blocks=num_res_blocks)
    self.model.compile(optimizer=AdamWithWeightsNormalization(lr=0.001),
                       loss=self.mae, metrics=[self.psnr])

    # Test for "argument omitted" by identity. `!= None` can be overridden
    # by the operand's __ne__ (e.g. array-like objects) and is not the
    # idiomatic None check.
    if pretrained_weights is not None:
        self.model.load_weights(pretrained_weights)
        print("[OK] weights loaded.")

    # crop_size semantics are defined by DataLoader (not visible here).
    self.data_loader = DataLoader(scale=scale, crop_size=256)
    self.pretrained_weights = pretrained_weights
    # E.g. 'weights/wdsr-b-32-x4.h5' with the default arguments.
    self.default_weights_save_path = (
        'weights/wdsr-b-' + str(self.num_res_blocks)
        + '-x' + str(self.scale) + '.h5')
    self.name = name
示例3: train
# 需要导入模块: import utils [as 别名]
# 或者: from utils import DataLoader [as 别名]
def train(args):
    """Train the vanilla LSTM model, holding out one of datasets 0..3.

    Args:
        args: Parsed arguments. Attributes read here are ``leaveDataset``,
            ``batch_size``, ``seq_length``, ``num_epochs``,
            ``learning_rate``, ``decay_rate`` and ``save_every``. The whole
            object is also passed to ``Model`` and pickled to
            ``save_lstm/config.pkl``.

    Raises:
        ValueError: if ``args.leaveDataset`` is not one of 0, 1, 2, 3.
    """
    save_dir = 'save_lstm'

    # Datasets are indexed 0..3; the one named by leaveDataset is left out.
    # `range` is immutable on Python 3 (no .remove), so build a list first.
    datasets = list(range(4))
    datasets.remove(args.leaveDataset)

    # Create the data loader object. This object would preprocess the data in terms of
    # batches each of size args.batch_size, of length args.seq_length
    data_loader = DataLoader(args.batch_size, args.seq_length, datasets,
                             forcePreProcess=True)

    # Save the arguments in the config file; create the directory on first
    # run instead of failing with a missing-path error.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    with open(os.path.join(save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # Create a Vanilla LSTM model with the arguments
    model = Model(args)

    # Build the learning-rate update op once. Calling tf.assign inside the
    # epoch loop would add a new op to the graph every epoch.
    new_lr = tf.placeholder(model.lr.dtype.base_dtype, shape=[])
    update_lr = tf.assign(model.lr, new_lr)

    checkpoint_path = os.path.join(save_dir, 'model.ckpt')

    with tf.Session() as sess:
        # Initialize all the variables in the graph
        sess.run(tf.global_variables_initializer())
        # Add all the variables to the list of variables to be saved
        saver = tf.train.Saver(tf.global_variables())

        for e in range(args.num_epochs):
            # Learning rate decayed exponentially with the epoch number.
            sess.run(update_lr,
                     {new_lr: args.learning_rate * (args.decay_rate ** e)})
            # Reset the pointers in the data loader object
            data_loader.reset_batch_pointer()
            # Get the initial cell state of the LSTM
            state = sess.run(model.initial_state)

            num_batches = data_loader.num_batches
            total_steps = args.num_epochs * num_batches
            for b in range(num_batches):
                step = e * num_batches + b
                start = time.time()
                # x has the source data, y has the target data
                x, y = data_loader.next_batch()
                # Feed source/target data and the current LSTM state.
                feed = {model.input_data: x, model.target_data: y,
                        model.initial_state: state}
                train_loss, state, _ = sess.run(
                    [model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                print(
                    "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(step, total_steps, e, train_loss, end - start))
                # Checkpoint every `save_every` global steps (skipping step 0).
                if step > 0 and step % args.save_every == 0:
                    saver.save(sess, checkpoint_path, global_step=step)
                    print("model saved to {}".format(checkpoint_path))