本文整理汇总了Python中util.load_data方法的典型用法代码示例。如果您正苦于以下问题：Python util.load_data方法的具体用法？Python util.load_data怎么用？Python util.load_data使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块util的用法示例。
在下文中一共展示了util.load_data方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# Required module import: import util [as alias]
# or: from util import load_data [as alias]
def main():
    """
    Build the CNN and train it on the images in ``train_data_dir``.

    The character set (and its char -> index mapping), the maximum number of
    characters per label and the label set are all derived from
    ``train_data_dir``. The model returned by ``build_shallow`` is trained
    with ``train`` on the training set, using the data in ``val_data_dir`` as
    validation data. Models are saved under ``save_model/<YYYY-MM-DD>/``,
    i.e. a directory named after today's date.

    The commented-out lines at the end are alternative run modes that call
    ``test`` on the train/validation/test sets; presumably they are meant to
    be used together with the commented-out ``model.load_weights`` call.
    """
    # Alternative configuration values kept from earlier experiments:
    # img_width, img_height = 48, 48
    # batch_size = 1024
    # train_data_dir = 'train_data/single_1000000/'
    img_width, img_height = 200, 60
    img_channels = 1
    batch_size = 32
    nb_epoch = 1000
    post_correction = False  # only used by the commented-out test() calls below
    # The model is saved into a directory named after the current date,
    # e.g. 'save_model/2016-10-27/'. This is the same string as the date part
    # of str(datetime.now()).
    save_dir = 'save_model/' + datetime.now().strftime('%Y-%m-%d') + '/'
    train_data_dir = 'train_data/ip_train/'
    val_data_dir = 'train_data/ip_val/'
    test_data_dir = 'test_data/'  # only used by the commented-out test() calls below
    weights_file_path = 'save_model/2016-10-27/weights.11-1.58.hdf5'  # for model.load_weights below

    char_set, char2idx = get_char_set(train_data_dir)
    nb_classes = len(char_set)
    max_nb_char = get_maxnb_char(train_data_dir)
    label_set = get_label_set(train_data_dir)

    # Single-argument print(...) produces the same output on Python 2 and 3.
    # print('char_set: {}'.format(char_set))
    print('nb_classes: {}'.format(nb_classes))
    print('max_nb_char: {}'.format(max_nb_char))
    print('size_label_set: {}'.format(len(label_set)))

    model = build_shallow(img_channels, img_width, img_height, max_nb_char, nb_classes)  # build CNN architecture
    # model.load_weights(weights_file_path)  # load trained model

    val_data = load_data(val_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    # val_data = None
    train_data = load_data(train_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set)

    # Evaluation run modes (replace the training call above with one of these):
    # train_data = load_data(train_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    # test(model, train_data, char_set, label_set, post_correction)
    # val_data = load_data(val_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    # test(model, val_data, char_set, label_set, post_correction)
    # test_data = load_data(test_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    # test(model, test_data, char_set, label_set, post_correction)
示例2: data_prepare
# Required module import: import util [as alias]
# or: from util import load_data [as alias]
def data_prepare(print_image_shape=False, print_input_shape=False,
                 num_det_classes=2, num_cls_classes=5):
    """
    Prepare data for the model.

    Loads the 'train', 'validation' and 'test' splits from ``DATA_DIR`` with
    ``load_data`` and one-hot encodes (categorizes) their masks with
    ``np_utils.to_categorical``. Each encoded mask is reshaped to
    ``(mask.shape[0], mask.shape[1], mask.shape[2], num_classes)``.

    :param print_image_shape: print the raw image/mask shapes if set true.
    :param print_input_shape: print the image and categorized mask shapes if set true.
    :param num_det_classes: number of classes for the ``*_det_masks`` (default 2).
    :param num_cls_classes: number of classes for the ``*_cls_masks`` (default 5).
    :return: list of inputs to the model, in the order
        [train_imgs, train_det, train_cls,
         valid_imgs, valid_det, valid_cls,
         test_imgs, test_det, test_cls]
    """
    def categorize(masks, num_class):
        # One-hot encode `masks`, then restore its first three dimensions and
        # append the class axis. The explicit reshape is presumably needed
        # because some Keras versions return a flattened 2-D matrix from
        # to_categorical -- TODO confirm for the Keras version in use.
        cate = np_utils.to_categorical(masks, num_class)
        return cate.reshape((masks.shape[0], masks.shape[1], masks.shape[2], num_class))

    train_imgs, train_det_masks, train_cls_masks = load_data(data_path=DATA_DIR, type='train')
    valid_imgs, valid_det_masks, valid_cls_masks = load_data(data_path=DATA_DIR, type='validation')
    test_imgs, test_det_masks, test_cls_masks = load_data(data_path=DATA_DIR, type='test')

    if print_image_shape:
        print('Image shape print below: ')
        print('train_imgs: {}, train_det_masks: {}, train_cls_masks: {}'.format(
            train_imgs.shape, train_det_masks.shape, train_cls_masks.shape))
        print('valid_imgs: {}, valid_det_masks: {}, valid_cls_masks: {}'.format(
            valid_imgs.shape, valid_det_masks.shape, valid_cls_masks.shape))
        print('test_imgs: {}, test_det_masks: {}, test_cls_masks: {}'.format(
            test_imgs.shape, test_det_masks.shape, test_cls_masks.shape))
        print()

    train_det = categorize(train_det_masks, num_det_classes)
    train_cls = categorize(train_cls_masks, num_cls_classes)
    valid_det = categorize(valid_det_masks, num_det_classes)
    valid_cls = categorize(valid_cls_masks, num_cls_classes)
    test_det = categorize(test_det_masks, num_det_classes)
    test_cls = categorize(test_cls_masks, num_cls_classes)

    if print_input_shape:
        print('input shape print below: ')
        print('train_imgs: {}, train_det: {}, train_cls: {}'.format(
            train_imgs.shape, train_det.shape, train_cls.shape))
        print('valid_imgs: {}, valid_det: {}, valid_cls: {}'.format(
            valid_imgs.shape, valid_det.shape, valid_cls.shape))
        print('test_imgs: {}, test_det: {}, test_cls: {}'.format(
            test_imgs.shape, test_det.shape, test_cls.shape))
        print()

    return [train_imgs, train_det, train_cls,
            valid_imgs, valid_det, valid_cls,
            test_imgs, test_det, test_cls]