

Python input_data.read_data_sets Method Code Examples

This article collects typical usage examples of the Python method tensorflow.examples.tutorials.mnist.input_data.read_data_sets. If you are unsure how input_data.read_data_sets is used in practice, or what it looks like in real code, the curated examples below should help. You can also explore further usage examples of the containing module, tensorflow.examples.tutorials.mnist.input_data.


The sections below present 15 code examples of the input_data.read_data_sets method, sorted by popularity by default.
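
Before the individual examples, here is a minimal usage sketch of the method (assuming TensorFlow 1.x, where the tutorials module still ships; it is deprecated there and removed in TensorFlow 2.x, and the data directory below is just an illustrative path):

# Minimal sketch: download MNIST into a cache directory (if not already present)
# and inspect the returned Datasets object.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/tmp/MNIST_data', one_hot=True)

print(mnist.train.images.shape)       # (55000, 784)
print(mnist.validation.images.shape)  # (5000, 784)
print(mnist.test.images.shape)        # (10000, 784)

# Draw a mini-batch of 100 flattened images and their one-hot labels for training.
batch_images, batch_labels = mnist.train.next_batch(100)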

Example 1: __init__

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def __init__(
            self,
            seed=0,
            episode_len=None,
            no_images=None
    ):
        from tensorflow.examples.tutorials.mnist import input_data
        # We could use a temporary directory for this with a context manager and
        # TemporaryDirectory, but then each test that uses MNIST would re-download the data.
        # This way the data is not cleaned up, but we only download it once per machine.
        mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data')
        with filelock.FileLock(mnist_path + '.lock'):
            self.mnist = input_data.read_data_sets(mnist_path)

        self.np_random = np.random.RandomState()
        self.np_random.seed(seed)

        self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
        self.action_space = Discrete(10)
        self.episode_len = episode_len
        self.time = 0
        self.no_images = no_images

        self.train_mode()
        self.reset() 
Developer: MaxSobolMark, Project: HardRLWithYoutube, Lines of code: 27, Source: mnist_env.py

Example 2: mlp_mnist

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def mlp_mnist():
    """test MLP with MNIST data and Sequential

    """
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('/tmp/data', one_hot=True)
    training_data = np.array([image.flatten() for image in mnist.train.images])
    training_label = mnist.train.labels
    valid_data = np.array([image.flatten() for image in mnist.validation.images])
    valid_label = mnist.validation.labels
    input_dim = training_data.shape[1]
    label_size = training_label.shape[1]

    model = Sequential()
    model.add(Input(input_shape=(input_dim, )))
    model.add(Dense(300, activator='selu'))
    model.add(Dropout(0.2))
    model.add(Softmax(label_size))
    model.compile('CCE', optimizer=SGD())
    model.fit(training_data, training_label, validation_data=(valid_data, valid_label)) 
Developer: l11x0m7, Project: lightnn, Lines of code: 22, Source: mnist.py

Example 3: cnn_mnist

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def cnn_mnist():
    """test CNN with MNIST data and Sequential

    """
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('/tmp/data', one_hot=True)
    training_data = np.array([image.reshape(28, 28, 1) for image in mnist.train.images])
    training_label = mnist.train.labels
    valid_data = np.array([image.reshape(28, 28, 1) for image in mnist.validation.images])
    valid_label = mnist.validation.labels
    label_size = training_label.shape[1]

    model = Sequential()
    model.add(Input(batch_input_shape=(None, 28, 28, 1)))
    model.add(Conv2d((3, 3), 1, activator='selu'))
    model.add(AvgPooling((2, 2), stride=2))
    model.add(Conv2d((4, 4), 2, activator='selu'))
    model.add(AvgPooling((2, 2), stride=2))
    model.add(Flatten())
    model.add(Softmax(label_size))
    model.compile('CCE', optimizer=SGD(lr=1e-2))
    model.fit(training_data, training_label, validation_data=(valid_data, valid_label), verbose=2) 
Developer: l11x0m7, Project: lightnn, Lines of code: 24, Source: mnist.py

Example 4: model_mlp_mnist

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def model_mlp_mnist():
    """test MLP with MNIST data and Model

    """
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('/tmp/data', one_hot=True)
    training_data = np.array([image.flatten() for image in mnist.train.images])
    training_label = mnist.train.labels
    valid_data = np.array([image.flatten() for image in mnist.validation.images])
    valid_label = mnist.validation.labels
    input_dim = training_data.shape[1]
    label_size = training_label.shape[1]

    dense_1 = Dense(300, input_dim=input_dim, activator=None)
    dense_2 = Activation('selu')(dense_1)
    dropout_1 = Dropout(0.2)(dense_2)
    softmax_1 = Softmax(label_size)(dropout_1)
    model = Model(dense_1, softmax_1)
    model.compile('CCE', optimizer=Adadelta())
    model.fit(training_data, training_label, validation_data=(valid_data, valid_label)) 
Developer: l11x0m7, Project: lightnn, Lines of code: 22, Source: mnist.py

Example 5: main

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def main():
    mnist = input_data.read_data_sets(train_dir='mnist')

    train = {'X': resize_images(mnist.train.images.reshape(-1, 28, 28)),
             'y': mnist.train.labels}
    
    test = {'X': resize_images(mnist.test.images.reshape(-1, 28, 28)),
            'y': mnist.test.labels}
    #~ train = {'X': mnist.train.images,
             #~ 'y': mnist.train.labels}
    
    #~ test = {'X': mnist.test.images,
            #~ 'y': mnist.test.labels}
        
    save_pickle(train, 'mnist/train.pkl')
    save_pickle(test, 'mnist/test.pkl') 
Developer: pmorerio, Project: minimal-entropy-correlation-alignment, Lines of code: 18, Source: prepro.py

Example 6: do_eval

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one evaluation against the full epoch of data.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = true_count / num_examples
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision)) 
Developer: GoogleCloudPlatform, Project: cloudml-samples, Lines of code: 27, Source: task.py

Example 7: get_dataset

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def get_dataset(data_dir, preprocess_fcn=None, dtype=tf.float32, reshape=True):
  """Construct a DataSet.
  `dtype` can be either
  `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
  `[0, 1]`.
   `reshape` Convert shape from [num examples, rows, columns, depth]
    to [num examples, rows*columns] (assuming depth == 1)    
  """
  from tensorflow.examples.tutorials.mnist import input_data

  datasets = input_data.read_data_sets(data_dir, dtype=dtype, reshape=reshape)
  
  if preprocess_fcn is not None:
    train = _preprocess_dataset(datasets.train, preprocess_fcn, dtype, reshape)
    validation = _preprocess_dataset(datasets.validation, preprocess_fcn, dtype, reshape)
    test = _preprocess_dataset(datasets.test, preprocess_fcn, dtype, reshape)
  else:
    train = datasets.train
    validation = datasets.validation
    test = datasets.test

  height, width, channels = 28, 28, 1 
  return Datasets(train, validation, test, height, width, channels) 
Developer: jakebelew, Project: gated-pixel-cnn, Lines of code: 25, Source: mnist_data.py

Example 8: fill_feed_dict

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def fill_feed_dict(data_set, images_pl, labels_pl, batch_size):
  """Fills the feed_dict for training the given step.

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().
    batch_size: Batch size of data to feed.

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch_size` examples.
  images_feed, labels_feed = data_set.next_batch(batch_size, FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict 
Developer: tobegit3hub, Project: deep_image_model, Lines of code: 22, Source: mnist.py

Example 9: download_and_process_mnist

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def download_and_process_mnist():

    if not os.path.exists('./data/mnist'):
        os.makedirs('./data/mnist')

    mnist = input_data.read_data_sets(train_dir='./data/mnist')

    train = {'X': resize_images(mnist.train.images.reshape(-1, 28, 28)),
             'y': mnist.train.labels}

    test = {'X': resize_images(mnist.test.images.reshape(-1, 28, 28)),
            'y': mnist.test.labels}

    with open('./data/mnist/train.pkl', 'w') as f:
        cPickle.dump(train, f, cPickle.HIGHEST_PROTOCOL)

    with open('./data/mnist/test.pkl', 'w') as f:
        cPickle.dump(test, f, cPickle.HIGHEST_PROTOCOL)
Developer: ricvolpi, Project: adversarial-feature-augmentation, Lines of code: 21, Source: download_and_process_mnist.py

Example 10: load

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def load(config, **unused_kwargs):

    del unused_kwargs

    if not os.path.exists(config.data_folder):
        os.makedirs(config.data_folder)

    dataset = input_data.read_data_sets(config.data_folder)

    train_data = {'imgs': dataset.train.images, 'labels': dataset.train.labels}
    valid_data = {'imgs': dataset.validation.images, 'labels': dataset.validation.labels}

    # This function turns a dictionary of numpy.ndarrays into tensors.
    train_tensors = tensors_from_data(train_data, config.batch_size, shuffle=True)
    valid_tensors = tensors_from_data(valid_data, config.batch_size, shuffle=False)

    data_dict = AttrDict(
        train_img=train_tensors['imgs'],
        valid_img=valid_tensors['imgs'],
        train_label=train_tensors['labels'],
        valid_label=valid_tensors['labels'],
    )

    return data_dict 
Developer: akosiorek, Project: forge, Lines of code: 26, Source: mnist_data.py

Example 11: load

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def load(config, **unused_kwargs):

    del unused_kwargs

    if not os.path.exists(config.data_folder):
        os.makedirs(config.data_folder)

    dataset = input_data.read_data_sets(config.data_folder)

    train_data = {'imgs': dataset.train.images, 'labels': dataset.train.labels}
    valid_data = {'imgs': dataset.validation.images, 'labels': dataset.validation.labels}

    train_tensors = tensors_from_data(train_data, config.batch_size, shuffle=True)
    valid_tensors = tensors_from_data(valid_data, config.batch_size, shuffle=False)

    data_dict = AttrDict(
        train_img=train_tensors['imgs'],
        valid_img=valid_tensors['imgs'],
        train_label=train_tensors['labels'],
        valid_label=valid_tensors['labels'],
    )

    return data_dict 
Developer: akosiorek, Project: forge, Lines of code: 25, Source: mnist_data.py

Example 12: load_model

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def load_model(self):
        tf.train.Saver().restore(self._sess, tf.train.latest_checkpoint("/home/ilmare/Desktop/FaceReplace/model/"))
        mnist = input_data.read_data_sets("/home/ilmare/dataSet/mnist", one_hot=True)
        source = np.reshape(mnist.train.images[0], [1, 784])
        dest = self.reconstrct(source)
        source = np.reshape(source, [28, 28])
        dest = np.reshape(dest, [28, 28])
        print(source.shape, dest.shape)
        # fig = plt.figure("test")
        # ax = fig.add_subplot(121)
        # ax.imshow(source)
        # bx = fig.add_subplot(122)
        # bx.imshow(dest)
        # plt.show()
        cv2.imshow("test", dest)
        cv2.waitKey(0) 
Developer: yhswjtuILMARE, Project: Machine-Learning-Study-Notes, Lines of code: 18, Source: AGNModel.py

Example 13: generate_metadata_file

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def generate_metadata_file():
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir,
                                      one_hot=True)
    # The ".tsv" file will contain one number per row to point to the good label
    # for each test example in the dataset.
    # For example, labels could be saved as plain text on those lines if needed.
    # In our case we have only 10 possible different labels, so their
    # "uniqueness" is recognised to later associate colors automatically in
    # TensorBoard.
    def save_metadata(file):
        with open(file, 'w') as f:
            for i in range(NB_TEST_DATA):
                c = np.nonzero(mnist.test.labels[::1])[1:][0][i]
                f.write('{}\n'.format(c))

    save_metadata(FLAGS.log_dir + '/projector/metadata.tsv') 
Developer: Vooban, Project: Autoencoder-TensorBoard-t-SNE, Lines of code: 19, Source: autoencoder_t-sne.py

Example 14: fill_feed_dict

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.
  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }
  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().
  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict 
Developer: eyalzk, Project: telegrad, Lines of code: 25, Source: tf_mnist_example.py

Example 15: download_mnist_retry

# Module to import: from tensorflow.examples.tutorials.mnist import input_data [as alias]
# Or: from tensorflow.examples.tutorials.mnist.input_data import read_data_sets [as alias]
def download_mnist_retry(data_dir, max_num_retries=20):
    """Try to download mnist dataset and avoid errors"""
    for _ in range(max_num_retries):
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            time.sleep(1)
    raise Exception("Failed to download MNIST.") 
Developer: wdxtub, Project: deep-learning-note, Lines of code: 10, Source: 2_mnist.py
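
A hypothetical call site for the helper above (the data directory and batch size are assumptions for illustration, not taken from the original project):

# Sketch: fetch MNIST with retries, then draw a training mini-batch.
mnist = download_mnist_retry('/tmp/mnist_data')
batch_xs, batch_ys = mnist.train.next_batch(64)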


Note: The tensorflow.examples.tutorials.mnist.input_data.read_data_sets examples in this article were collected by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective authors, and copyright of the source code remains with those authors; consult each project's License before redistributing or reusing the code. Do not reproduce without permission.