当前位置: 首页>>代码示例>>Python>>正文


Python data_generator.DataGenerator方法代码示例

本文整理汇总了Python中data_generator.DataGenerator（一个数据生成器类）的典型用法代码示例。如果您想了解data_generator.DataGenerator的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块data_generator的其他用法示例。


在下文中一共展示了data_generator.DataGenerator方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: init_data_gen

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def init_data_gen(self,
                      data: DataTuple,
                      batch_size: int=64,
                      augmenter: ImageDataAugmenter=None,
                      shuffle: bool=False,
                      debug: bool=False):
        """
        Create a DataGenerator for ``data`` wired to this experiment's hooks.

        The generator receives the normalized-space ("norm") preprocessing hooks
        (preprocess_input_data_norm / preprocess_input_label_norm). Experiments
        that do not work in normalized space override this method.
        :param data: DataTuple providing ``x``, ``y`` and ``feats``
        :param batch_size: number of samples per batch
        :param augmenter: optional ImageDataAugmenter applied to the inputs
        :param shuffle: True to shuffle the input data
        :param debug: True to show augmentation and normalization image results
        :return: the configured DataGenerator
        """
        generator = DataGenerator(data.x, data.y, data.feats,
                                  batch_size, augmenter, shuffle, debug)
        # The hooks are handed over positionally, so this order is significant.
        hooks = (
            self.arrange_arrays,
            self.arrange_label_array,
            self.look_back_range,
            self.get_preprocess_info,
            self.load_image,
            self.preprocess_input_data_norm,
            self.preprocess_input_label_norm,
            self.resize_input_data,
            self.prepare_tensor_dims,
            self.normalize_input_data,
            self.arrange_final_data,
            self.decide_input_label,
        )
        generator.set_methods(*hooks)
        return generator
开发者ID:crisie,项目名称:RecurrentGaze,代码行数:24,代码来源:experiment_helper.py

示例2: __init__

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def __init__(self,
                 args,
                 backbone):
        """Set up IIC unsupervised clustering via mutual information maximization.

        Holds the encoder model, the loss function, dataset loading and the
        train/evaluation routines.

        Arguments:
            args : command line arguments selecting batch size, number of
                heads, folder and file name for saved weights, etc.
            backbone (Model): IIC encoder backbone (e.g. VGG)
        """
        self.args, self.backbone = args, backbone
        self._model = None
        self.train_gen = DataGenerator(args, siamese=True)
        # n_labels is read from the training generator before build_model() runs.
        self.n_labels = self.train_gen.n_labels
        self.build_model()
        self.load_eval_dataset()
        self.accuracy = 0
开发者ID:PacktPublishing,项目名称:Advanced-Deep-Learning-with-Keras,代码行数:24,代码来源:iic-13.5.1.py

示例3: resume_train

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def resume_train(self, category, pretrainModel, modelName, initEpoch, batchSize=8, epochs=20):
        """Continue training a pretrained model starting at epoch ``initEpoch``.

        Loads the weights in ``pretrainModel``, forces a 2-stack network and
        trains on ``train_split.csv`` for ``category``. Checkpoints go to the
        folder of the pretrained model (``resumeFolder``) and a CSV log is
        written alongside them.

        :param category: category passed to DataGenerator and the callback
        :param pretrainModel: path of the model file to resume from
        :param modelName: name stored on ``self.modelName`` and used in the log name
        :param initEpoch: epoch index training resumes from
        :param batchSize: samples per batch
        :param epochs: index of the final epoch (Keras ``epochs`` semantics)
        """
        self.modelName = modelName
        self.load_model(pretrainModel)
        refine_net = True  # passed to the callback as its refine-net flag
        self.nStackNum = 2

        resume_folder = os.path.dirname(pretrainModel)

        annotation_csv = os.path.join("../../data/train/Annotations", "train_split.csv")
        train_data = DataGenerator(category, annotation_csv)
        train_gen = train_data.generator_with_mask_ohem(
            graph=tf.get_default_graph(),
            kerasModel=self.model,
            batchSize=batchSize,
            inputSize=(self.inputHeight, self.inputWidth),
            nStackNum=self.nStackNum,
            flipFlag=False,
            cropFlag=False)

        error_callback = NormalizedErrorCallBack("../../trained_models/", category,
                                                 refine_net, resumeFolder=resume_folder)

        log_dir = error_callback.get_folder_path()
        # NOTE(review): '%H:%M' puts a ':' in the file name, which is not a valid
        # path character on Windows.
        stamp = datetime.datetime.now().strftime('%H:%M')
        csv_logger = CSVLogger(os.path.join(
            log_dir, "csv_train_" + self.modelName + "_" + stamp + ".csv"))

        self.model.fit_generator(initial_epoch=initEpoch,
                                 generator=train_gen,
                                 steps_per_epoch=train_data.get_dataset_size() // batchSize,
                                 epochs=epochs,
                                 callbacks=[error_callback, csv_logger])
开发者ID:yuanyuanli85,项目名称:FashionAI_KeyPoint_Detection_Challenge_Keras,代码行数:24,代码来源:fashion_net.py

示例4: main

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def main(test_desc_file, train_desc_file, load_dir):
    """Load a trained speech model, evaluate it on a test set and print the loss.

    Args:
        test_desc_file: JSON description file of the test set
        train_desc_file: JSON description file of the training set; 100 of its
            samples are used to compute feature means and variances
        load_dir: location passed to ``load_model``
    """
    datagen = DataGenerator()
    datagen.load_test_data(test_desc_file)
    datagen.load_train_data(train_desc_file)
    # Feature statistics from 100 training samples let the inputs be centred.
    datagen.fit_train(100)

    # Load the already-trained model and build its evaluation function.
    model = load_model(load_dir)
    test_fn = compile_test_fn(model)

    print("Test loss: {}".format(test(model, test_fn, datagen)))
开发者ID:baidu-research,项目名称:ba-dls-deepspeech,代码行数:22,代码来源:test.py

示例5: __init__

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def __init__(self, args):
        """Keep the user-defined configuration and set up the FCN.

        Creates the training data generator from ``args``, then calls
        build_model() (backbone and fcn network models) and eval_init().

        Arguments:
            args: user-defined configuration
        """
        self.args, self.fcn = args, None
        self.train_generator = DataGenerator(args)
        self.build_model()
        self.eval_init()
开发者ID:PacktPublishing,项目名称:Advanced-Deep-Learning-with-Keras,代码行数:11,代码来源:fcn-12.3.1.py

示例6: __init__

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def __init__(self,
                 args,
                 backbone):
        """Set up MINE unsupervised clustering via mutual information maximization.

        Holds the encoder, the SimpleMINE and linear classifier models, the
        loss function, dataset loading and the train/evaluation routines.

        Arguments:
            args : command line arguments selecting batch size, latent
                dimension (``args.latent_dim``), folder and file name for
                saved weights, etc.
            backbone (Model): MINE encoder backbone (e.g. VGG)
        """
        self.args = args
        self.latent_dim = args.latent_dim
        self.backbone = backbone
        self._model, self._encoder = None, None
        self.train_gen = DataGenerator(args, siamese=True, mine=True)
        # n_labels is read from the training generator before build_model() runs.
        self.n_labels = self.train_gen.n_labels
        self.build_model()
        self.accuracy = 0
开发者ID:PacktPublishing,项目名称:Advanced-Deep-Learning-with-Keras,代码行数:28,代码来源:mine-13.8.1.py

示例7: train

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def train(self, category, batchSize=8, epochs=20, lrschedule=False):
        """Train ``self.model`` on the ``train_split.csv`` annotations.

        Batches come from DataGenerator.generator_with_mask_ohem for
        ``category``. Progress is tracked by NormalizedErrorCallBack and a
        timestamped CSV log written into the callback's folder.

        :param category: category passed to DataGenerator and the callback
        :param batchSize: samples per batch; steps per epoch is
            dataset size // batchSize
        :param epochs: number of epochs to train
        :param lrschedule: accepted but not used by this method
        """
        annotation_csv = os.path.join("../../data/train/Annotations", "train_split.csv")
        train_data = DataGenerator(category, annotation_csv)
        train_gen = train_data.generator_with_mask_ohem(
            graph=tf.get_default_graph(),
            kerasModel=self.model,
            batchSize=batchSize,
            inputSize=(self.inputHeight, self.inputWidth),
            nStackNum=self.nStackNum,
            flipFlag=False,
            cropFlag=False)

        error_callback = NormalizedErrorCallBack("../../trained_models/", category, True)

        log_dir = error_callback.get_folder_path()
        # NOTE(review): '%H:%M' puts a ':' in the file name (invalid on Windows).
        stamp = datetime.datetime.now().strftime('%H:%M')
        csv_logger = CSVLogger(os.path.join(
            log_dir, "csv_train_" + self.modelName + "_" + stamp + ".csv"))

        self.model.fit_generator(generator=train_gen,
                                 steps_per_epoch=train_data.get_dataset_size() // batchSize,
                                 epochs=epochs,
                                 callbacks=[error_callback, csv_logger])
开发者ID:yuanyuanli85,项目名称:FashionAI_KeyPoint_Detection_Challenge_Keras,代码行数:17,代码来源:fashion_net.py

示例8: test

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def test(model, test_fn, datagen, mb_size=16, conv_context=11,
         conv_border_mode='valid', conv_stride=2):
    """ Testing routine for speech-models
    Prints the ground truth and greedy (argmax) decoding of every utterance
    and returns the mean CTC cost per minibatch.
    Params:
        model (keras.model): Constructed keras model (not used by this routine;
            kept for interface compatibility)
        test_fn (theano.function): A theano function that calculates the cost
            over a test set
        datagen (DataGenerator)
        mb_size (int): Size of each minibatch
        conv_context (int): Convolution context
        conv_border_mode (str): Convolution border mode
        conv_stride (int): Convolution stride
    Returns:
        test_cost (float): Average cost per minibatch over the whole test set,
            or 0.0 if the test set yields no minibatches
    """
    total_cost = 0.0
    # Dedicated batch counter. Reusing a generic name such as ``i`` for the
    # per-utterance loop below would overwrite it and corrupt the average.
    num_batches = 0
    for batch in datagen.iterate_test(mb_size):
        inputs = batch['x']
        labels = batch['y']
        input_lengths = batch['input_lengths']
        label_lengths = batch['label_lengths']
        ground_truth = batch['texts']
        # Due to convolution, the number of timesteps of the output
        # is different from the input length. Calculate the resulting
        # timesteps
        output_lengths = [conv_output_length(length, conv_context,
                                             conv_border_mode, conv_stride)
                          for length in input_lengths]
        predictions, ctc_cost = test_fn([inputs, output_lengths, labels,
                                         label_lengths, True])
        # Swap the first two axes so iteration yields one prediction per
        # utterance (assumes test_fn returns time-major output — TODO confirm).
        predictions = np.swapaxes(predictions, 0, 1)
        for utt_idx, prediction in enumerate(predictions):
            print("Truth: {}, Prediction: {}"
                  .format(ground_truth[utt_idx], argmax_decode(prediction)))
        total_cost += ctc_cost
        num_batches += 1
    # Same empty-set convention as validation(): no batches -> 0.0, not
    # ZeroDivisionError.
    if num_batches == 0:
        return 0.0
    return total_cost / num_batches
开发者ID:baidu-research,项目名称:ba-dls-deepspeech,代码行数:40,代码来源:test.py

示例9: visualize

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def visualize(model, test_file, train_desc_file):
    """ Run the model on one clip, print its greedy decoding and save outputs.

    Writes the raw network output to ``softmax.npy`` and a grey-scale
    heat-map of the softmax probabilities to ``softmax.png``. Both are in
    the current working directory, and the names do not depend on
    ``test_file``.
    Params:
        model (keras.models.Model): Trained speech model
        test_file (str): Path to an audio clip
        train_desc_file(str): Path to the training file used to train this
                              model (100 samples are used for feature
                              normalization statistics)
    """
    datagen = DataGenerator()
    datagen.load_train_data(train_desc_file)
    datagen.fit_train(100)

    print("Compiling test function...")
    output_fn = compile_output_fn(model)

    features = [datagen.featurize(test_file)]
    prediction = np.squeeze(output_fn([features, True]))

    npy_path = "softmax.npy"
    png_path = "softmax.png"
    print("Prediction: {}".format(argmax_decode(prediction)))
    print("Saving network output to: {}".format(npy_path))
    print("As image: {}".format(png_path))
    np.save(npy_path, prediction)

    probs = softmax(prediction.T)
    # Stack row 0, row 2, then rows 3.. in reverse order (row 1 is left out).
    probs = np.vstack((probs[0], probs[2], probs[3:][::-1]))
    _fig, ax = plt.subplots()
    ax.pcolor(probs, cmap=plt.cm.Greys_r)
    # Tick labels: 'a'..'z', then 'space' and 'blank'; drawn bottom-up, hence
    # the reversal below.
    tick_labels = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    tick_labels += ['space', 'blank']
    ax.set_yticks(np.arange(probs.shape[0]) + 0.5, minor=False)
    ax.set_yticklabels(tick_labels[::-1], minor=False)
    plt.savefig(png_path)
开发者ID:baidu-research,项目名称:ba-dls-deepspeech,代码行数:35,代码来源:visualize.py

示例10: validation

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def validation(model, val_fn, datagen, mb_size=16):
    """ Compute the average validation cost of a speech model.
    Params:
        model (keras.model): Constructed keras model; its
            ``conv_output_length`` maps input lengths to output lengths
        val_fn (theano.function): A theano function that calculates the cost
            over a validation set
        datagen (DataGenerator)
        mb_size (int): Size of each minibatch
    Returns:
        val_cost (float): Mean of the per-minibatch costs, or 0.0 when the
            validation set yields no minibatches
    """
    batch_costs = []
    for batch in datagen.iterate_validation(mb_size):
        inputs, labels = batch['x'], batch['y']
        input_lengths, label_lengths = batch['input_lengths'], batch['label_lengths']
        # Convolution changes the number of output timesteps, so translate
        # each input length into the model's output length.
        output_lengths = [model.conv_output_length(length)
                          for length in input_lengths]
        _, ctc_cost = val_fn([inputs, output_lengths, labels,
                              label_lengths, True])
        batch_costs.append(ctc_cost)
    if not batch_costs:
        return 0.0
    return sum(batch_costs, 0.0) / len(batch_costs)
开发者ID:baidu-research,项目名称:ba-dls-deepspeech,代码行数:32,代码来源:train.py

示例11: __init__

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def __init__(self,
                 name: str=None,
                 description: str=None,
                 weights: str=None,
                 train: bool=False,
                 base_model: BaseModel=None,
                 model=None,
                 fc_dimensions: int=4096,
                 label_pos: int=-1,
                 look_back: int=1,
                 n_output: int=2,
                 recurrent_type: str="lstm",
                 num_recurrent_layers: int=1,
                 num_recurrent_units: int=128,
                 train_data_generator: DataGenerator=None,
                 val_data_generator: DataGenerator=None):
        """
        Initialize an ExperimentHelper, storing every option as an attribute of the same name.
        :param name: experiment name
        :param description: experiment description
        :param weights: weights of the model (when it has already been trained)
        :param train: True if training is activated
        :param base_model: base model used (for instance, VGGFace)
        :param model: model architecture type
        :param fc_dimensions: dimensions of the FC layers
        :param label_pos: label position
        :param look_back: sequence length
        :param n_output: number of model outputs
        :param recurrent_type: type of recurrent network ("gru" or "lstm")
        :param num_recurrent_layers: number of recurrent layers
        :param num_recurrent_units: number of recurrent units
        :param train_data_generator: DataGenerator for training
        :param val_data_generator: DataGenerator for validation/test, if any
        """
        # Experiment identity and model setup.
        self.name, self.description = name, description
        self.weights, self.train = weights, train
        self.base_model, self.model = base_model, model
        self.fc_dimensions, self.label_pos = fc_dimensions, label_pos
        self.n_output = n_output
        # Temporal (recurrent) options.
        self.look_back = look_back
        self.recurrent_type = recurrent_type
        self.num_recurrent_layers, self.num_recurrent_units = num_recurrent_layers, num_recurrent_units
        # Data generators.
        self.train_data_generator, self.val_data_generator = train_data_generator, val_data_generator
开发者ID:crisie,项目名称:RecurrentGaze,代码行数:53,代码来源:experiment_helper.py

示例12: train

# 需要导入模块: import data_generator [as 别名]
# 或者: from data_generator import DataGenerator [as 别名]
def train(model, train_fn, val_fn, datagen, save_dir, epochs=10, mb_size=16,
          do_sortagrad=True):
    """ Main training routine for speech-models
    Validates and saves the model every 500 iterations, and once more at the
    end of any epoch that did not end exactly on such a checkpoint.
    Params:
        model (keras.model): Constructed keras model
        train_fn (theano.function): A theano function that takes in acoustic
            inputs and updates the model
        val_fn (theano.function): A theano function that calculates the cost
            over a validation set
        datagen (DataGenerator)
        save_dir (str): Path where model and costs are saved
        epochs (int): Total epochs to continue training
        mb_size (int): Size of each minibatch
        do_sortagrad (bool): If true, the first epoch visits utterances sorted
            by length without shuffling; all other epochs are shuffled
    """
    train_costs = []
    val_costs = []
    iters = 0

    def checkpoint():
        # Validate, record the cost and persist model + cost history at the
        # current iteration count.
        val_costs.append(validation(model, val_fn, datagen, mb_size))
        save_model(save_dir, model, train_costs, val_costs, iters)

    for epoch in range(epochs):
        sort_this_epoch = (epoch == 0) if do_sortagrad else False
        batches = datagen.iterate_train(mb_size,
                                        shuffle=not sort_this_epoch,
                                        sort_by_duration=sort_this_epoch)
        for i, batch in enumerate(batches):
            inputs = batch['x']
            labels = batch['y']
            input_lengths = batch['input_lengths']
            label_lengths = batch['label_lengths']
            # Convolution changes the number of output timesteps, so map
            # each input length to the model's output length.
            output_lengths = [model.conv_output_length(length)
                              for length in input_lengths]
            _, ctc_cost = train_fn([inputs, output_lengths, labels,
                                    label_lengths, True])
            train_costs.append(ctc_cost)
            if i % 10 == 0:
                logger.info("Epoch: {}, Iteration: {}, Loss: {}"
                            .format(epoch, i, ctc_cost))
            iters += 1
            if iters % 500 == 0:
                checkpoint()
        # End of epoch: checkpoint unless the last iteration already did.
        if iters % 500 != 0:
            checkpoint()
开发者ID:baidu-research,项目名称:ba-dls-deepspeech,代码行数:55,代码来源:train.py


注:本文中的data_generator.DataGenerator方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。