当前位置: 首页>>代码示例>>Python>>正文


Python util.load_data方法代码示例

本文整理汇总了Python中util.load_data方法的典型用法代码示例。如果您正苦于以下问题:Python util.load_data方法的具体用法?Python util.load_data怎么用?Python util.load_data使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在util的用法示例。


在下文中一共展示了util.load_data方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import util [as 别名]
# 或者: from util import load_data [as 别名]
def main():
    """Build the model with ``build_shallow`` and train it.

    Training samples are read from ``train_data/ip_train/`` and validation
    samples from ``train_data/ip_val/``. The character set, the maximum
    number of characters per label and the label set are all derived from
    the training directory. The model is handed to ``train`` together with
    a save directory named after today's date.
    """
    # Input geometry and training hyper-parameters.
    # (An earlier configuration used 48x48 images and batch_size = 1024.)
    img_width, img_height = 200, 60
    img_channels = 1
    batch_size = 32
    nb_epoch = 1000
    # Only referenced by the optional evaluation calls listed at the bottom.
    post_correction = False

    # One save directory per calendar day, e.g. 'save_model/2016-10-27/'.
    save_dir = 'save_model/' + datetime.now().date().isoformat() + '/'
    train_data_dir = 'train_data/ip_train/'
    # Alternative training set: 'train_data/single_1000000/'
    val_data_dir = 'train_data/ip_val/'
    # The next two paths are only used by the optional steps noted below.
    test_data_dir = 'test_data//'
    weights_file_path = 'save_model/2016-10-27/weights.11-1.58.hdf5'

    char_set, char2idx = get_char_set(train_data_dir)
    nb_classes = len(char_set)
    max_nb_char = get_maxnb_char(train_data_dir)
    label_set = get_label_set(train_data_dir)

    # Single-argument print calls produce the same text on Python 2 and 3.
    print('nb_classes: {}'.format(nb_classes))
    print('max_nb_char: {}'.format(max_nb_char))
    print('size_label_set: {}'.format(len(label_set)))

    # Build the CNN architecture.
    model = build_shallow(img_channels, img_width, img_height, max_nb_char, nb_classes)
    # To resume from a trained model instead of starting fresh:
    #     model.load_weights(weights_file_path)

    # Validation data is loaded first, then the training data.
    # Set val_data = None to train without a validation set.
    val_data = load_data(val_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    train_data = load_data(train_data_dir, max_nb_char, img_width, img_height, img_channels, char_set, char2idx)
    train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set)

    # Optional evaluation: load a split with load_data(<dir>, ...) as above
    # (train_data_dir, val_data_dir or test_data_dir) and run
    #     test(model, <data>, char_set, label_set, post_correction)
开发者ID:xingjian-f,项目名称:DeepLearning-OCR,代码行数:39,代码来源:train.py

示例2: data_prepare

# 需要导入模块: import util [as 别名]
# 或者: from util import load_data [as 别名]
def data_prepare(print_image_shape=False, print_input_shape=False):
    """
    Load the train/validation/test splits and one-hot encode their masks.

    Each split is read with ``load_data`` from ``DATA_DIR``. Detection masks
    are encoded with 2 classes and classification masks with 5 classes via
    ``np_utils.to_categorical``; the result is reshaped so the class axis
    follows the first three axes of the original mask.

    :param print_image_shape: print the raw image/mask shapes if set true.
    :param print_input_shape: print the shapes after categorizing if set true.
    :return: list of model inputs in the order
             [train_imgs, train_det, train_cls,
              valid_imgs, valid_det, valid_cls,
              test_imgs, test_det, test_cls]
    """
    num_det_classes = 2
    num_cls_classes = 5
    # (label used in printed output, ``type`` value passed to load_data)
    splits = (('train', 'train'), ('valid', 'validation'), ('test', 'test'))

    def one_hot_mask(mask, num_class):
        # Encode the mask, then restore its first three axes with the class
        # axis appended last.
        encoded = np_utils.to_categorical(mask, num_class)
        return encoded.reshape((mask.shape[0], mask.shape[1], mask.shape[2], num_class))

    # Load every split before doing anything else.
    raw = []
    for label, split_type in splits:
        imgs, det_masks, cls_masks = load_data(data_path=DATA_DIR, type=split_type)
        raw.append((label, imgs, det_masks, cls_masks))

    if print_image_shape:
        print('Image shape print below: ')
        for label, imgs, det_masks, cls_masks in raw:
            print('{0}_imgs: {1}, {0}_det_masks: {2}, {0}_cls_masks: {3}'.format(
                label, imgs.shape, det_masks.shape, cls_masks.shape))
        print()

    # Encode detection then classification masks, split by split.
    prepared = []
    for label, imgs, det_masks, cls_masks in raw:
        det_onehot = one_hot_mask(det_masks, num_det_classes)
        cls_onehot = one_hot_mask(cls_masks, num_cls_classes)
        prepared.append((label, imgs, det_onehot, cls_onehot))

    if print_input_shape:
        print('input shape print below: ')
        for label, imgs, det_onehot, cls_onehot in prepared:
            print('{0}_imgs: {1}, {0}_det: {2}, {0}_cls: {3}'.format(
                label, imgs.shape, det_onehot.shape, cls_onehot.shape))
        print()

    # Flatten to the nine-element list callers expect.
    outputs = []
    for _, imgs, det_onehot, cls_onehot in prepared:
        outputs.extend([imgs, det_onehot, cls_onehot])
    return outputs
开发者ID:zhuyiche,项目名称:sfcn-opi,代码行数:46,代码来源:model.py


注:本文中的util.load_data方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。