

Python persist.save_obj Function Code Examples

This article collects typical usage examples of the save_obj function from Python's neon.util.persist module. If you are unsure what save_obj does, how to call it, or where it is typically used, the hand-picked examples below should help.


The following 15 code examples of save_obj are shown, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps the site surface better Python code samples.
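Before diving into the examples, here is a minimal sketch of the basic round trip: save_obj pickles a Python object to disk and load_obj reads it back. This assumes neon is installed; the file name and dict contents are illustrative, not taken from any example below.

from neon.util.persist import save_obj, load_obj

# pickle a plain dict to disk ('meta.pkl' is a hypothetical path
# used only for this sketch), then restore it
meta = {'nclass': 10, 'macro_size': 3072}
save_obj(meta, 'meta.pkl')

restored = load_obj('meta.pkl')
assert restored['nclass'] == 10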

Example 1: save_history

    def save_history(self, epoch, model):
        # if history > 1, this function will save the last N checkpoints
        # where N is equal to self.history.  The files will have the form
        # of save_path with the epoch added to the filename before the ext

        if len(self.checkpoint_files) > self.history:
            # remove the oldest checkpoint file once the max count has been reached
            fn = self.checkpoint_files.popleft()
            try:
                os.remove(fn)
                logger.info("removed old checkpoint %s" % fn)
            except OSError:
                logger.warn("Could not delete old checkpoint file %s" % fn)

        path_split = os.path.splitext(self.save_path)
        save_path = "%s_%d%s" % (path_split[0], epoch, path_split[1])
        # add the current file to the deque
        self.checkpoint_files.append(save_path)
        save_obj(model.serialize(keep_states=True), save_path)

        # maintain a symlink pointing to the latest model params
        try:
            if os.path.islink(self.save_path):
                os.remove(self.save_path)
            os.symlink(os.path.split(save_path)[-1], self.save_path)
        except OSError:
            logger.warn("Could not create latest model symlink %s -> %s" % (self.save_path, save_path))
Author: yapjiaqing, Project: neon, Lines: 27, Source: callbacks.py

Example 2: on_epoch_end

    def on_epoch_end(self, callback_data, model, epoch):
        _eil = self._get_cached_epoch_loss(callback_data, model, epoch, "loss")
        if _eil:
            if self.best_cost is None or _eil["cost"] < self.best_cost:
                # TODO: switch this to a general serialization op
                save_obj(model.serialize(keep_states=True), self.best_path)
                self.best_cost = _eil["cost"]
Author: yapjiaqing, Project: neon, Lines: 7, Source: callbacks.py

Example 3: serialize

    def serialize(self, fn=None, keep_states=True):
        """
        Creates a dictionary storing the layer parameters and epochs complete.

        Arguments:
            fn (str): file path to save the pickle-formatted model dictionary to
            keep_states (bool): Whether to save optimizer states.

        Returns:
            dict: Model data including layer parameters and epochs complete.
        """

        # get the model dict with the weights
        pdict = self.get_description(get_weights=True, keep_states=keep_states)
        pdict['epoch_index'] = self.epoch_index + 1
        if self.initialized:
            if not hasattr(self.layers, 'decoder'):
                pdict['train_input_shape'] = self.layers.in_shape
            else:
                # serialize shapes both for encoder and decoder
                pdict['train_input_shape'] = (self.layers.encoder.in_shape +
                                              self.layers.decoder.in_shape)
        if fn is not None:
            save_obj(pdict, fn)
            return
        return pdict
Author: rlugojr, Project: neon, Lines: 26, Source: model.py

Example 4: save_params

    def save_params(self, param_path, keep_states=True):
        """
        Serializes and saves model parameters to the path specified.

        Arguments:
            param_path (str): File to write serialized parameter dict to.
            keep_states (bool): Whether to save optimizer states too.
                                Defaults to True.
        """
        save_obj(self.serialize(keep_states=keep_states), param_path)
Author: bin2000, Project: neon, Lines: 10, Source: model.py

Example 5: on_epoch_end

    def on_epoch_end(self, epoch):

        if 'cost/validation' in self.callback_data:
            val_freq = self.callback_data['cost/validation'].attrs['epoch_freq']
            if (epoch + 1) % val_freq == 0:
                validation_cost = self.callback_data['cost/validation'][epoch // val_freq]

                if self.best_cost is None or validation_cost < self.best_cost:
                    save_obj(self.model.serialize(keep_states=True), self.best_path)
                    self.best_cost = validation_cost
Author: rupertsmall, Project: neon, Lines: 10, Source: callbacks.py

Example 6: save_meta

    def save_meta(self):
        save_obj({'ntrain': self.ntrain,
                  'nval': self.nval,
                  'train_start': self.train_start,
                  'val_start': self.val_start,
                  'macro_size': self.macro_size,
                  'batch_prefix': self.batch_prefix,
                  'global_mean': self.global_mean,
                  'label_dict': self.label_dict,
                  'label_names': self.label_names,
                  'val_nrec': self.val_nrec,
                  'train_nrec': self.train_nrec,
                  'img_size': self.target_size,
                  'nclass': self.nclass}, self.meta_file)
Author: GerritKlaschke, Project: neon, Lines: 14, Source: batch_writer.py
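As a usage note, the metadata written above can be read back with load_obj; a hedged sketch, where 'macro_meta.pkl' stands in for whatever self.meta_file pointed to:

from neon.util.persist import load_obj

# restore the metadata dict written by save_meta and inspect a few keys
meta = load_obj('macro_meta.pkl')
print(meta['nclass'], meta['train_nrec'], meta['img_size'])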

Example 7: on_sigint_catch

    def on_sigint_catch(self, epoch, minibatch):
        """
        Callback to handle SIGINT events

        Arguments:
            epoch (int): index of current epoch
            minibatch (int): index of minibatch that is ending
        """
        # restore the original handler
        signal.signal(signal.SIGINT, signal.SIG_DFL)

        # save the model
        if self.save_path is not None:
            save_obj(self.model().serialize(keep_states=True), self.save_path)
            raise KeyboardInterrupt("Checkpoint file saved to {0}".format(self.save_path))
        else:
            raise KeyboardInterrupt
Author: yapjiaqing, Project: neon, Lines: 17, Source: callbacks.py
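The same save-on-interrupt pattern can be reproduced outside a callback class; a self-contained sketch, where the state dict and file name are illustrative and not part of neon:

import signal
from neon.util.persist import save_obj

state = {'epoch': 3, 'minibatch': 120}  # hypothetical training state

def on_sigint(signum, frame):
    # restore the default handler so a second Ctrl-C kills the process
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    save_obj(state, 'interrupt_checkpoint.pkl')
    raise KeyboardInterrupt("Checkpoint file saved to interrupt_checkpoint.pkl")

signal.signal(signal.SIGINT, on_sigint)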

Example 8: serialize

    def serialize(self, fn=None, keep_states=True):
        """
        Creates a dictionary storing the layer parameters and epochs complete.

        Arguments:
            fn (str): file path to save the pickle-formatted model dictionary to
            keep_states (bool): Whether to save optimizer states.

        Returns:
            dict: Model data including layer parameters and epochs complete.
        """

        # get the model dict with the weights
        pdict = self.get_description(get_weights=True, keep_states=keep_states)
        pdict['epoch_index'] = self.epoch_index + 1
        if fn is not None:
            save_obj(pdict, fn)
            return
        return pdict
Author: maony, Project: neon, Lines: 19, Source: model.py

Example 9: save_meta

    def save_meta(self):
        save_obj(
            {
                "ntrain": self.ntrain,
                "nval": self.nval,
                "train_start": self.train_start,
                "val_start": self.val_start,
                "macro_size": self.macro_size,
                "batch_prefix": self.batch_prefix,
                "global_mean": self.global_mean,
                "label_dict": self.label_dict,
                "label_names": self.label_names,
                "val_nrec": self.val_nrec,
                "train_nrec": self.train_nrec,
                "img_size": self.target_size,
                "nclass": self.nclass,
            },
            self.meta_file,
        )
Author: hgl888, Project: neon, Lines: 19, Source: batch_writer.py

Example 10: save_history

    def save_history(self, epoch):
        # if history > 1, this function will save the last N checkpoints
        # where N is equal to self.history.  The files will have the form
        # of save_path with the epoch added to the filename before the ext

        if len(self.checkpoint_files) > self.history:
            # remove the oldest checkpoint file once the max count has been reached
            fn = self.checkpoint_files.popleft()
            try:
                os.remove(fn)
                logger.info('removed old checkpoint %s' % fn)
            except OSError:
                logger.warn('Could not delete old checkpoint file %s' % fn)

        path_split = os.path.splitext(self.save_path)
        save_path = '%s_%d%s' % (path_split[0], epoch, path_split[1])
        # add the current file to the deque
        self.checkpoint_files.append(save_path)
        save_obj(self.model.serialize(keep_states=True), save_path)
Author: rupertsmall, Project: neon, Lines: 19, Source: callbacks.py

Example 11: get_w2v_vocab

def get_w2v_vocab(fname, max_vocab_size, cache=True):
    """
    Get ordered dict of vocab from google word2vec
    """
    if cache:
        cache_fname = fname.split('.')[0] + ".vocab"

        if os.path.isfile(cache_fname):
            vocab, vocab_size = load_obj(cache_fname)
            neon_logger.display("Word2Vec vocab cached, size is: {}".format(vocab_size))
            return vocab, vocab_size

    with open(fname, 'rb') as f:
        header = f.readline()
        vocab_size, embed_dim = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * embed_dim

        neon_logger.display("Word2Vec vocab size is: {}".format(vocab_size))
        vocab_size = min(max_vocab_size, vocab_size)
        neon_logger.display("Reducing vocab size to: {}".format(vocab_size))

        vocab = OrderedDict()

        for i in range(vocab_size):
            word = []
            while True:
                ch = f.read(1)
                if ch == b' ':
                    word = (b''.join(word)).decode('utf-8')
                    break
                if ch != b'\n':
                    word.append(ch)
            f.read(binary_len)
            vocab[word] = i

    if cache:
        save_obj((vocab, vocab_size), cache_fname)

    return vocab, vocab_size
Author: rlugojr, Project: neon, Lines: 39, Source: util.py
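A hypothetical call illustrating the caching behavior: the first invocation parses the binary file and writes a .vocab cache next to it via save_obj; later invocations hit the load_obj branch instead. The file name below is illustrative.

# first call: parses the .bin and caches (vocab, vocab_size) to
# 'GoogleNews-vectors-negative300.vocab'; subsequent calls load the cache
vocab, vocab_size = get_w2v_vocab('GoogleNews-vectors-negative300.bin',
                                  max_vocab_size=20000)
print(vocab_size, list(vocab.items())[:3])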

Example 12: PolySchedule

lr_sched = PolySchedule(total_epochs=10, power=0.5)
opt_gdm = GradientDescentMomentum(0.01, 0.9, wdecay=0.0002, schedule=lr_sched)
opt_biases = GradientDescentMomentum(0.02, 0.9, schedule=lr_sched)

opt = MultiOptimizer({'default': opt_gdm, 'Bias': opt_biases})
if not args.resume:
    # fit the model for 3 epochs
    model.fit(train, optimizer=opt, num_epochs=3, cost=cost, callbacks=callbacks)

train.reset()
# grab one minibatch of images and labels
for im, l in train:
    break
train.exit_batch_provider()
save_obj((im.get(), l.get()), 'im1.pkl')
im_save = im.get().copy()
if args.resume:
    (im2, l2) = load_obj('im1.pkl')
    im.set(im2)
    l.set(l2)

# run fprop on this minibatch and save the results
out_fprop = model.fprop(im)

out_fprop_save = [x.get() for x in out_fprop]
im.set(im_save)
out_fprop = model.fprop(im)
out_fprop_save2 = [x.get() for x in out_fprop]
for x, y in zip(out_fprop_save, out_fprop_save2):
    assert np.max(np.abs(x - y)) == 0.0, '2 fprop iterations do not match'
Author: JediKoder, Project: neon, Lines: 30, Source: inception.py

Example 13: IOError

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("cache_file", help="path to data cache file")
    args = parser.parse_args()

    cache_file = args.cache_file

    # check for RW access to file
    assert os.path.exists(cache_file), "file does not exist %s" % cache_file
    if not os.access(os.path.abspath(cache_file), os.R_OK | os.W_OK):
        raise IOError("Need to add read and/or write permissions on file %s" % cache_file)

    dc = load_obj(cache_file)

    if "global_mean" not in dc or "img_size" not in dc:
        raise ValueError("data cache file missing global_mean key")

    sz = dc["img_size"]
    gm = dc["global_mean"]

    if len(gm.shape) != 2 or (gm.shape[0] != sz * sz * 3 or gm.shape[1] != 1):
        raise ValueError("global mean shape {} does not match format expected".format(gm.shape))

    # Collapse the full tensor mean into channel means and correct the order (RGB <-> BGR)
    dc["global_mean"] = np.mean(gm.reshape(3, -1), axis=1).reshape(3, 1)[::-1]

    save_obj(dc, cache_file)

    neon_logger.display("%s updated to new format" % cache_file)
Author: Jokeren, Project: neon, Lines: 29, Source: update_dataset_cache.py
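The core transformation in this script, collapsing a full-tensor mean of shape (sz*sz*3, 1) into a per-channel mean of shape (3, 1) with the channel order flipped, can be checked in isolation; a self-contained sketch with made-up numbers:

import numpy as np

sz = 4  # toy image side length
gm = np.arange(sz * sz * 3, dtype=np.float32).reshape(-1, 1)

# average each of the 3 channel planes, then flip RGB <-> BGR
channel_mean = np.mean(gm.reshape(3, -1), axis=1).reshape(3, 1)[::-1]
print(channel_mean.shape)  # (3, 1)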

Example 14: load_data


#......... some code omitted here .........
        neon_logger.display("open existing vocab file: {}".format(vocab_file_name))
        vocab, rev_vocab, word_count = load_obj(vocab_file_name)
    else:
        neon_logger.display("Building  vocab file")

        # build vocab
        word_count = defaultdict(int)
        for sent in all_sent:
            sent_words = tokenize(sent)

            if len(sent_words) > max_len_w or len(sent_words) == 0:
                continue

            for word in sent_words:
                word_count[word] += 1

        # sort word_count and re-assign ids by frequency; useful for downstream
        # tasks (only done for the train vocab)
        vocab_sorted = sorted(word_count.items(), key=lambda kv: kv[1], reverse=True)

        vocab = OrderedDict()

        # get word count as array in same ordering as vocab (but with maximum length)
        word_count_ = np.zeros((len(word_count), ), dtype=np.int64)
        for i, t in enumerate(list(zip(*vocab_sorted))[0][:max_vocab_size]):
            word_count_[i] = word_count[t]
            vocab[t] = i
        word_count = word_count_

        # generate the reverse vocab
        rev_vocab = dict((wrd_id, wrd) for wrd, wrd_id in vocab.items())

        neon_logger.display("vocabulary from {} is saved into {}".format(path, vocab_file_name))
        save_obj((vocab, rev_vocab, word_count), vocab_file_name)

    vocab_size = len(vocab)
    neon_logger.display("\nVocab size from the dataset is: {}".format(vocab_size))

    neon_logger.display("\nProcessing and saving training data into {}".format(h5_file_name))

    # now process and save the train/valid data
    h5f = h5py.File(h5_file_name, 'w', libver='latest')
    shape, maxshape = (len(train_sent),), (None,)
    dt = np.dtype([('text', h5py.special_dtype(vlen=str)),
                   ('num_words', np.uint16)])
    report_text_train = h5f.create_dataset('report_train', shape=shape,
                                           maxshape=maxshape, dtype=dt,
                                           compression='gzip')
    report_train = h5f.create_dataset('train', shape=shape, maxshape=maxshape,
                                      dtype=h5py.special_dtype(vlen=np.int32),
                                      compression='gzip')

    # map text to integers
    wdata = np.zeros((1, ), dtype=dt)
    ntrain = 0
    for sent in train_sent:
        text_int = [-1 if t not in vocab else vocab[t] for t in tokenize(sent)]

        # enforce maximum sentence length
        if len(text_int) > max_len_w or len(text_int) == 0:
            continue

        report_train[ntrain] = text_int

        wdata['text'] = clean_string(sent)
        wdata['num_words'] = len(text_int)
Author: NervanaSystems, Project: neon, Lines: 67, Source: data_loader.py

Example 15: int

        (im_shape, im_scale, gt_boxes, gt_classes,
            num_gt_boxes, difficult) = valid_set.get_metadata_buffers()

        num_gt_boxes = int(num_gt_boxes.get())
        im_scale = float(im_scale.get())

        # retrieve region proposals generated by the model
        (proposals, num_proposals) = proposalLayer.get_proposals()

        # convert outputs to bounding boxes
        boxes = faster_rcnn.get_bboxes(outputs, proposals, num_proposals, num_classes,
                                       im_shape.get(), im_scale, max_per_image, thresh, nms_thresh)

        all_boxes[mb_idx] = boxes

        # retrieve gt boxes
        # we add an extra column to track detections during the AP calculation
        detected = np.array([False] * num_gt_boxes)
        gt_boxes = np.hstack([gt_boxes.get()[:num_gt_boxes] / im_scale,
                              gt_classes.get()[:num_gt_boxes],
                              difficult.get()[:num_gt_boxes], detected[:, np.newaxis]])

        all_gt_boxes[mb_idx] = gt_boxes

neon_logger.display('Evaluating detections')
avg_precision = voc_eval(all_boxes, all_gt_boxes, valid_set.CLASSES, use_07_metric=True)

if args.output is not None:
    neon_logger.display('Saving inference results to {}'.format(args.output))
    save_obj([all_boxes, avg_precision], args.output)
Author: NervanaSystems, Project: neon, Lines: 30, Source: inference.py


Note: The neon.util.persist.save_obj examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from projects contributed by open-source developers, and copyright remains with the original authors; consult each project's license before redistributing or using the code. Do not reproduce without permission.