当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.local_variables_initializer函数代码示例

本文整理汇总了Python中tensorflow.local_variables_initializer函数的典型用法代码示例。如果您正苦于以下问题:Python local_variables_initializer函数的具体用法?Python local_variables_initializer怎么用?Python local_variables_initializer使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了local_variables_initializer函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test

def test(model, config, prompts):
    """Synthesize the given prompts with a restored model and log the results.

    Builds the model in inference mode, restores the latest checkpoint, then
    keeps running the model until the input pipeline raises OutOfRangeError.
    For every sample an audio summary and an attention-plot image summary are
    written to ``log/<save_path>/test``.

    Args:
        model: Callable invoked as ``model(config, batch_inputs, train=False)``;
            the result must expose ``output`` and ``alignments``.
        config: Configuration object. ``data_path`` and ``save_path`` are read;
            ``r``, ``vocab_size`` and ``num_prompts`` are overwritten here.
        prompts: Prompts handed to ``data_input.load_prompts``.
    """
    # The sample rate is picked from the dataset name embedded in the path.
    sr = 24000 if 'blizzard' in config.data_path else 16000
    meta = data_input.load_meta(config.data_path)
    config.r = audio.r
    ivocab = meta['vocab']
    config.vocab_size = len(ivocab)

    with tf.device('/cpu:0'):
        batch_inputs = data_input.load_prompts(prompts, ivocab)
        config.num_prompts = len(prompts)

    with tf.Session() as sess:

        # Normalization statistics; their values come from the checkpoint.
        # Both must share one dtype so that restore and the de-normalization
        # arithmetic below are consistent.
        stft_mean_var = tf.get_variable(
            'stft_mean', shape=(1025*audio.r,), dtype=tf.float32)
        stft_std_var = tf.get_variable(
            'stft_std', shape=(1025*audio.r,), dtype=tf.float32)

        # initialize model
        model = model(config, batch_inputs, train=False)

        train_writer = tf.summary.FileWriter('log/' + config.save_path + '/test', sess.graph)

        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()

        print('restoring weights')
        latest_ckpt = tf.train.latest_checkpoint(
            'weights/' + config.save_path[:config.save_path.rfind('/')]
        )
        saver.restore(sess, latest_ckpt)

        stft_mean, stft_std = sess.run([stft_mean_var, stft_std_var])

        try:
            while True:
                outputs, alignments, inputs = sess.run([
                    model.output,
                    model.alignments,
                    batch_inputs
                ])

                print('saving samples')
                for spectrogram, words, align in zip(outputs, inputs['text'], alignments):
                    # store a sample to listen to
                    text = ''.join([ivocab[w] for w in words])
                    attention_plot = data_input.generate_attention_plot(align)
                    sample = audio.invert_spectrogram(spectrogram*stft_std + stft_mean)
                    merged = sess.run(tf.summary.merge(
                         [tf.summary.audio(text, sample[None, :], sr),
                          tf.summary.image(text, attention_plot)]
                    ))
                    train_writer.add_summary(merged, 0)
        except tf.errors.OutOfRangeError:
            # The input queue is exhausted: every prompt has been processed.
            pass
        finally:
            # Stop the queue-runner threads however the loop ended.
            coord.request_stop()
            coord.join(threads)
开发者ID:yhgon,项目名称:Tacotron-tf-barronalex,代码行数:60,代码来源:test.py

示例2: train

def train(model, data, gen, params):
    """Alternate discriminator and generator updates, then plot or animate.

    Runs ``params.num_steps + 1`` iterations. Losses are printed every
    ``params.log_every`` steps. When ``params.anim_path`` is set, a frame is
    collected every ``params.anim_every`` steps and an animation is saved at
    the end; otherwise the final distributions are plotted.

    Args:
        model: Object exposing ``x``, ``z``, ``loss_d``, ``opt_d``, ``loss_g``
            and ``opt_g``.
        data: Real-data source with a ``sample(batch_size)`` method.
        gen: Noise source with ``sample(batch_size)`` and ``range``.
        params: Run settings (``num_steps``, ``batch_size``, ``log_every``,
            ``anim_path``, ``anim_every``).
    """
    frames = []

    with tf.Session() as session:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()

        step = 0
        while step <= params.num_steps:
            # Discriminator step: one real batch and one noise batch.
            real_batch = data.sample(params.batch_size)
            noise_batch = gen.sample(params.batch_size)
            d_feed = {
                model.x: np.reshape(real_batch, (params.batch_size, 1)),
                model.z: np.reshape(noise_batch, (params.batch_size, 1)),
            }
            loss_d, _ = session.run([model.loss_d, model.opt_d], d_feed)

            # Generator step: a fresh noise batch.
            noise_batch = gen.sample(params.batch_size)
            g_feed = {
                model.z: np.reshape(noise_batch, (params.batch_size, 1)),
            }
            loss_g, _ = session.run([model.loss_g, model.opt_g], g_feed)

            if step % params.log_every == 0:
                print('{}: {:.4f}\t{:.4f}'.format(step, loss_d, loss_g))

            if params.anim_path:
                if step % params.anim_every == 0:
                    frame = samples(
                        model, session, data, gen.range, params.batch_size)
                    frames.append(frame)

            step += 1

        if not params.anim_path:
            final = samples(model, session, data, gen.range, params.batch_size)
            plot_distributions(final, gen.range)
        else:
            save_animation(frames, params.anim_path, gen.range)
开发者ID:yashvardhan90,项目名称:gan-intro,代码行数:35,代码来源:gan.py

示例3: test_empty_labels_and_scores_gives_nan_auc

 def test_empty_labels_and_scores_gives_nan_auc(self):
   """The histogram AUC evaluates to NaN after an update with no records."""
   with self.test_session():
     empty_labels = tf.constant([], shape=[0], dtype=tf.bool)
     empty_scores = tf.constant([], shape=[0], dtype=tf.float32)
     unit_range = [0, 1.]
     auc, update_op = tf.contrib.metrics.auc_using_histogram(
         empty_labels, empty_scores, unit_range)
     # The metric keeps its histograms in local variables.
     tf.local_variables_initializer().run()
     update_op.run()
     auc_value = auc.eval()
     self.assertTrue(np.isnan(auc_value))
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:10,代码来源:histogram_ops_test.py

示例4: _check_auc

  def _check_auc(self,
                 nbins=100,
                 desired_auc=0.75,
                 score_range=None,
                 num_records=50,
                 frac_true=0.5,
                 atol=0.05,
                 num_updates=10):
    """Check auc accuracy against synthetic data.

    Args:
      nbins:  nbins arg from contrib.metrics.auc_using_histogram.
      desired_auc:  Number in [0, 1].  The desired auc for synthetic data.
      score_range:  2-tuple, (low, high), giving the range of the resultant
        scores.  Defaults to [0, 1.] when None.
      num_records:  Positive integer.  The number of records to return.
      frac_true:  Number in (0, 1).  Expected fraction of resultant labels that
        will be True.  This is just in expectation...more or less may actually
        be True.
      atol:  Absolute tolerance for final AUC estimate.
      num_updates:  Update internal histograms this many times, each with a new
        batch of synthetic data, before computing final AUC.

    Raises:
      AssertionError: If resultant AUC is not within atol of theoretical AUC
        from synthetic data.
    """
    # Only fall back to the default when the caller gave no range; the former
    # `[0, 1.] or score_range` always evaluated to [0, 1.].
    score_range = [0, 1.] if score_range is None else score_range
    with self.test_session():
      labels = tf.placeholder(tf.bool, shape=[num_records])
      scores = tf.placeholder(tf.float32, shape=[num_records])
      auc, update_op = tf.contrib.metrics.auc_using_histogram(labels,
                                                              scores,
                                                              score_range,
                                                              nbins=nbins)
      tf.local_variables_initializer().run()
      # Updates, then extract auc.
      for _ in range(num_updates):
        labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                            num_records, self.rng, frac_true)
        update_op.run(feed_dict={labels: labels_a, scores: scores_a})
      # NOTE(review): this extra batch is never fed to the graph; the call is
      # kept because it receives self.rng -- confirm before removing it.
      labels_a, scores_a = synthetic_data(desired_auc, score_range, num_records,
                                          self.rng, frac_true)
      # Fetch current auc, and verify that fetching again doesn't change it.
      auc_eval = auc.eval()
      self.assertAlmostEqual(auc_eval, auc.eval(), places=5)

    msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
           'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
                                                                 desired_auc,
                                                                 score_range,
                                                                 num_records,
                                                                 frac_true,
                                                                 num_updates)
    np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:55,代码来源:histogram_ops_test.py

示例5: train

    def train(self, DGTrain, DGTest, saver=True):
        """Run ``self.STEPS`` training steps with periodic logging and validation.

        Sets up the learning-rate schedule, regularization, optimizer and the
        exponential moving average, initializes all variables, then trains.
        Every ``self.N_PRINT`` steps the summary is written, metrics for the
        current mini-batch are printed and ``self.Validation`` is called.

        Args:
            DGTrain: Training data generator; ``length`` and
                ``Batch(0, batch_size)`` are used.
            DGTest: Test data generator, forwarded to ``self.Validation``.
            saver: Not used by this method; kept so existing callers still work.
        """
        epoch = DGTrain.length

        self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch)

        trainable_var = tf.trainable_variables()

        self.regularize_model()
        self.optimization(trainable_var)
        self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA)

        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test',
                                            graph=self.sess.graph)

        self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph)
        merged_summary = tf.summary.merge_all()
        steps = self.STEPS

        for step in range(steps):
            batch_data, batch_labels = DGTrain.Batch(0, self.BATCH_SIZE)
            feed_dict = {self.input_node: batch_data,
                         self.train_labels_node: batch_labels}

            # self.optimizer is replaced by self.training_op for the exponential moving decay
            _, l, lr, predictions, s = self.sess.run(
                        [self.training_op, self.loss, self.learning_rate,
                         self.train_prediction, merged_summary],
                        feed_dict=feed_dict)

            if step % self.N_PRINT == 0:
                i = datetime.now()
                # Single-argument print(...) behaves the same as a statement
                # on Python 2 and as a function on Python 3.
                print(i.strftime('%Y/%m/%d %H:%M:%S: \n '))
                self.summary_writer.add_summary(s, step)
                error, acc, acc1, recall, prec, f1 = self.error_rate(predictions, batch_labels, step)
                print('  Step %d of %d' % (step, steps))
                # Format inside the call; `print(...) % lr` applies `%` to
                # print's return value on Python 3.
                print('  Learning rate: %.5f \n' % lr)
                print('  Mini-batch loss: %.5f \n       Accuracy: %.1f%% \n       acc1: %.1f%% \n       recall: %1.f%% \n       prec: %1.f%% \n       f1 : %1.f%% \n' %
                      (l, acc, acc1, recall, prec, f1))
                self.Validation(DGTest, step)
开发者ID:PeterJackNaylor,项目名称:PhD_Fabien,代码行数:50,代码来源:ObjectOriented.py

示例6: main

def main(model_config, train_config, track_config):
  """Build the inference model and write an initial checkpoint for it.

  If `train_config['train_dir']` already holds a checkpoint, it is restored
  and re-saved at the next global step. Otherwise the variables are freshly
  initialized (optionally loading a pretrained embedding) and saved at step 0.

  Args:
    model_config: Model configuration dict; `embed_config` is read here.
    train_config: Training configuration dict; `train_dir`, `seed` and
      `max_checkpoints_to_keep` are read here.
    track_config: Tracking configuration, only passed on to `save_cfgs`.
  """
  # Create training directory
  train_dir = train_config['train_dir']
  if not tf.gfile.IsDirectory(train_dir):
    tf.logging.info('Creating training directory: %s', train_dir)
    tf.gfile.MakeDirs(train_dir)

  # Build the Tensorflow graph
  g = tf.Graph()
  with g.as_default():
    # Set fixed seed
    np.random.seed(train_config['seed'])
    tf.set_random_seed(train_config['seed'])

    # Build the model
    model = siamese_model.SiameseModel(model_config, train_config, mode='inference')
    model.build()

    # Save configurations for future reference
    save_cfgs(train_dir, model_config, train_config, track_config)

    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=train_config['max_checkpoints_to_keep'])

    # Dynamically allocate GPU memory
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_options)

    # The context manager closes the session even if restore or save raises.
    with tf.Session(config=sess_config) as sess:
      model_path = tf.train.latest_checkpoint(train_config['train_dir'])

      if not model_path:
        # Initialize all variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        start_step = 0

        # Load pretrained embedding model if needed
        if model_config['embed_config']['embedding_checkpoint_file']:
          model.init_fn(sess)

      else:
        logging.info('Restore from last checkpoint: {}'.format(model_path))
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, model_path)
        start_step = tf.train.global_step(sess, model.global_step.name) + 1

      checkpoint_path = osp.join(train_config['train_dir'], 'model.ckpt')
      saver.save(sess, checkpoint_path, global_step=start_step)
开发者ID:fossabot,项目名称:SiamFC-TensorFlow,代码行数:49,代码来源:convert_pretrained_model.py

示例7: test_batch_text_lines

  def test_batch_text_lines(self):
    """Reads a five-line text file in batches of three for a single epoch."""
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    name = "my_batch"
    queue_capacity = 10
    batch_size = 3

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = tf.contrib.learn.io.read_batch_examples(
          [filename],
          batch_size,
          reader=tf.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          read_batch_size=10,
          name=name)
      static_shape = inputs.get_shape().as_list()
      self.assertAllEqual((None,), static_shape)
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      # One full batch, then the remainder, then the queue is exhausted.
      for expected_batch in ([b"A", b"B", b"C"], [b"D", b"E"]):
        self.assertAllEqual(session.run(inputs), expected_batch)
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
开发者ID:moolighty,项目名称:tensorflow,代码行数:26,代码来源:graph_io_test.py

示例8: testRoundtrip

    def testRoundtrip(self, rate=0.25, count=5, n=500):
        """Tests `resample(x, weights)` and resample(resample(x, rate), 1/rate)`."""

        foo = self.get_values(count)
        bar = self.get_values(count)
        weights = self.get_weights(count)

        resampled_in, rates = tf.contrib.training.weighted_resample(
            [foo, bar], tf.constant(weights), rate, seed=123)

        resampled_back_out = tf.contrib.training.resample_at_rate(
            resampled_in, 1.0 / rates, seed=456)

        init = tf.local_variables_initializer()
        with self.test_session() as s:
            s.run(init)  # initialize

            # Tallies of every value seen after the first and second resampling.
            first_pass_counts = collections.Counter()
            second_pass_counts = collections.Counter()
            fetches = [resampled_in, resampled_back_out]
            for _ in range(n):
                first_pass, second_pass = s.run(fetches)

                # Both input tensors must be resampled identically.
                self.assertAllEqual(first_pass[0], first_pass[1])
                self.assertAllEqual(second_pass[0], second_pass[1])

                first_pass_counts.update(first_pass[0])
                second_pass_counts.update(second_pass[0])

            # assert that resampling worked as expected
            self.assert_expected(weights, rate, first_pass_counts, n)

            # and that re-resampling gives the approx identity.
            unit_weights = [1.0 for _ in weights]
            self.assert_expected(
                unit_weights, 1.0, second_pass_counts, n,
                abs_delta=0.1 * n * count)
开发者ID:brchiu,项目名称:tensorflow,代码行数:34,代码来源:resample_test.py

示例9: blend_images

def blend_images(data_folder1, data_folder2, out_folder, alpha=.5):
    """Blend every .jpg under `data_folder1` with its counterpart in `data_folder2`.

    The counterpart is loaded through a TensorFlow pipeline that scales its
    longer side to 224, rotates it by the number of quarter turns parsed from
    the file name, and crops/pads it to 224x224. The blend is written to the
    mirrored location under `out_folder`. File names not containing '.jpg'
    are printed and skipped.

    Args:
        data_folder1: Root folder that is walked for images named with
            '-'-separated fields; field 1 is parsed as the rotation and the
            last field is the counterpart's file name.
        data_folder2: Root folder holding the counterpart images.
        out_folder: Root folder for blended output; sub-folders are created.
        alpha: Blend factor passed to `Image.blend`.
    """
    filename_queue = tf.placeholder(dtype=tf.string)
    label = tf.placeholder(dtype=tf.int32)
    tensor_image = tf.read_file(filename_queue)

    image = tf.image.decode_jpeg(tensor_image, channels=3)

    # Scale factor that brings the longer side to 224.
    multiplier = tf.div(tf.constant(224, tf.float32),
                        tf.cast(tf.maximum(tf.shape(image)[0], tf.shape(image)[1]), tf.float32))
    x = tf.cast(tf.round(tf.mul(tf.cast(tf.shape(image)[0], tf.float32), multiplier)), tf.int32)
    y = tf.cast(tf.round(tf.mul(tf.cast(tf.shape(image)[1], tf.float32), multiplier)), tf.int32)
    image = tf.image.resize_images(image, [x, y])

    image = tf.image.rot90(image, k=label)

    image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)

    # The processed counterpart is round-tripped through this scratch file.
    temp_path = os.path.join(os.getcwd(), "temp", "temp.jpg")

    # The context manager closes the session even if a file fails mid-walk.
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        for root, folders, files in os.walk(data_folder1):
            for each in files:
                if '.jpg' not in each:
                    print(each)
                    continue
                img1 = Image.open(os.path.join(root, each))
                name_fields = each.split("-")
                img2_path = os.path.join(root.replace(data_folder1, data_folder2), name_fields[-1])
                rotation = int(name_fields[1])
                img2 = sess.run(image, feed_dict={filename_queue: img2_path, label: rotation})
                imsave(temp_path, img2)
                img2 = Image.open(temp_path)
                out_image = Image.blend(img1, img2, alpha)
                outfile = os.path.join(root.replace(data_folder1, out_folder), each)
                out_dir = os.path.split(outfile)[0]
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
                out_image.save(outfile)
开发者ID:Sabrewarrior,项目名称:PhotoOrientation,代码行数:35,代码来源:misc.py

示例10: test

    def test(self, p1, p2, steps):
        """Evaluate the network for `steps` batches and return averaged metrics.

        Args:
            p1: Forwarded to `ComputeMetrics`.
            p2: Forwarded to `ComputeMetrics`.
            steps: Number of batches to evaluate; every metric is divided by it.

        Returns:
            Tuple `(loss, acc, F1, recall, precision, roc, jac, AJI)`.
        """
        loss = 0.
        roc = 0.
        acc = 0.
        F1 = 0.
        recall = 0.
        precision = 0.
        jac = 0.
        AJI = 0.

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        self.sess.run(init_op)
        self.Saver()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        fetches = [self.loss, self.train_prediction, self.train_labels_node]
        for _ in range(steps):
            batch_loss, prob, batch_labels = self.sess.run(
                fetches, feed_dict={self.is_training: False})
            loss += batch_loss
            # Metrics are computed on the first element of the batch only.
            metrics = ComputeMetrics(prob[0, :, :, 1], batch_labels[0, :, :, 0], p1, p2)
            acc += metrics[0]
            roc += metrics[1]
            jac += metrics[2]
            recall += metrics[3]
            precision += metrics[4]
            F1 += metrics[5]
            AJI += metrics[6]

        coord.request_stop()
        coord.join(threads)

        loss, acc, F1 = np.array([loss, acc, F1]) / steps
        recall, precision, roc = np.array([recall, precision, roc]) / steps
        jac, AJI = np.array([jac, AJI]) / steps
        return loss, acc, F1, recall, precision, roc, jac, AJI
开发者ID:PeterJackNaylor,项目名称:PhD_Fabien,代码行数:30,代码来源:UNet.py

示例11: initialize_variables

def initialize_variables(sess, saver, logdir, checkpoint=None, resume=None):
  """Run the variable initializers, then restore a checkpoint when one applies.

  Args:
    sess: Session in which variables are initialized and restored.
    saver: Saver used to restore the checkpoint.
    logdir: Directory searched for checkpoints; nothing is restored when it
      is empty or None.
    checkpoint: Checkpoint name inside `logdir`; the most recent checkpoint
      recorded in `logdir` is used when omitted.
    resume: True to require a resumable location, False to forbid an existing
      checkpoint, None to accept either situation.

  Raises:
    ValueError: If resume expected but no log directory specified.
    RuntimeError: If no resume expected but a checkpoint was found.
  """
  init_op = tf.group(
      tf.local_variables_initializer(),
      tf.global_variables_initializer())
  sess.run(init_op)

  if resume and not logdir and not checkpoint:
    raise ValueError('Need to specify logdir to resume a checkpoint.')
  if not logdir:
    return

  state = tf.train.get_checkpoint_state(logdir)
  if checkpoint:
    # An explicit name is resolved relative to the log directory.
    checkpoint = os.path.join(logdir, checkpoint)
  elif state and state.model_checkpoint_path:
    checkpoint = state.model_checkpoint_path

  if not checkpoint:
    return
  if resume is False:
    message = 'Found unexpected checkpoint when starting a new run.'
    raise RuntimeError(message)
  saver.restore(sess, checkpoint)
开发者ID:shamanez,项目名称:agents,代码行数:30,代码来源:utility.py

示例12: testSummariesAreFlushedToDiskWithoutGlobalStep

  def testSummariesAreFlushedToDiskWithoutGlobalStep(self):
    """Evaluation summaries are written to disk when no global step exists."""
    output_dir = os.path.join(self.get_temp_dir(), 'flush_test_no_global_step')
    if tf.gfile.Exists(output_dir):  # For running on jenkins.
      tf.gfile.DeleteRecursively(output_dir)

    names_to_metrics, names_to_updates = self._create_names_to_metrics(
        self._predictions, self._labels)

    # One scalar summary per metric.
    for metric_name, metric_tensor in names_to_metrics.items():
      tf.summary.scalar(metric_name, metric_tensor)

    summary_writer = tf.train.SummaryWriter(output_dir)

    initial_op = tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer())
    eval_op = tf.group(*names_to_updates.values())

    with self.test_session() as sess:
      slim.evaluation.evaluation(
          sess,
          initial_op=initial_op,
          eval_op=eval_op,
          summary_op=tf.summary.merge_all(),
          summary_writer=summary_writer)

      # Read the metric values while the session is still the default one.
      names_to_values = {}
      for metric_name, metric_tensor in names_to_metrics.items():
        names_to_values[metric_name] = metric_tensor.eval()
    self._verify_summaries(output_dir, names_to_values)
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:29,代码来源:evaluation_test.py

示例13: run

def run(num_examples=1000):
    """Read (image, label) pairs from the TFRecord files listed in `data_path`.

    Args:
        num_examples: Number of records to read. Defaults to 1000, the count
            that used to be hard-coded.

    Returns:
        List of `(example, label)` tuples in the order they were read.
    """
    with tf.Session() as sess:
        print("start")
        feature = {'image': tf.FixedLenFeature([], tf.string),
                   'label': tf.FixedLenFeature([], tf.int64)}
        # Create a list of filenames and pass it to a queue
        print(data_path)
        filename_queue = tf.train.string_input_producer(data_path, num_epochs=1)
        # Define a reader and read the next record
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        # Decode the record read by the reader
        features = tf.parse_single_example(serialized_example, features=feature)
        # Convert the image data from string back to the numbers
        image = tf.decode_raw(features['image'], tf.uint8)

        # Cast label data into int32
        label = tf.cast(features['label'], tf.int32)
        # Initialize global and local variables before starting the queues.
        init_op = [tf.global_variables_initializer(), tf.local_variables_initializer()]
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        train_list = []
        try:
            for _ in range(num_examples):
                example, l = sess.run([image, label])
                train_list.append((example, l))
        finally:
            # Stop the reader threads even when a read fails, e.g. when the
            # single-epoch queue holds fewer records than requested.
            coord.request_stop()
            coord.join(threads)
        return train_list
# run()
开发者ID:ykakde,项目名称:trash-classifier,代码行数:34,代码来源:tf_file_reader.py

示例14: main

def main(argv):
  """Smoke-test TFRecord writing and reading against a GCS bucket URL.

  Writes example records to `FLAGS.gcs_bucket_url`, reads them back with
  `tf_record_iterator`, reads them again through a `TFRecordReader` op in a
  session, and finally runs the directory and object tests. Exits with
  status 1 on the first failed check.

  Args:
    argv: Unused.
  """
  del argv  # Unused.
  bucket_url = FLAGS.gcs_bucket_url
  expected_count = FLAGS.num_examples

  # Sanity check on the GCS bucket URL.
  if not (bucket_url and bucket_url.startswith("gs://")):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % bucket_url)
    sys.exit(1)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(expected_count, 5)
  with tf.python_io.TFRecordWriter(bucket_url) as hf:
    for example in example_data:
      hf.write(example.SerializeToString())

    print("Data written to: %s" % bucket_url)

  # Verify that reading from the tfrecord file works and that
  # tf_record_iterator works.
  record_iter = tf.python_io.tf_record_iterator(bucket_url)
  read_count = sum(1 for _ in record_iter)
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != expected_count:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     expected_count))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default():
    filename_queue = tf.train.string_input_producer([bucket_url],
                                                    num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()
      for index in range(expected_count):
        print("Read record: %d" % index)
        sess.run(serialized_example)

      # Reading one more record should trigger an exception.
      try:
        sess.run(serialized_example)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")
      else:
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)

  create_dir_test()
  create_object_test()
开发者ID:DILASSS,项目名称:tensorflow,代码行数:60,代码来源:gcs_smoke.py

示例15: predict

	def predict(self):
		"""Run the restored network on every .tif image and save the class maps.

		Images are read in grayscale from ORIGIN_PREDICT_DIRECTORY, fed through
		the network with the weights at CHECK_POINT_PATH, and the per-pixel
		argmax (scaled by 255) is written to PREDICT_SAVED_DIRECTORY as
		``<index>.jpg``, where the index follows the glob order.
		"""
		import cv2
		import glob
		import numpy as np
		# TODO: this should predict directly from the image files rather than
		# reading from tfrecord, because the order changes and outputs can no
		# longer be matched to their inputs.
		predict_file_path = glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif'))
		print(len(predict_file_path))
		ckpt_path = CHECK_POINT_PATH
		all_parameters_saver = tf.train.Saver()
		# Build the argmax op once; creating it inside the loop would add a new
		# op to the graph for every image.
		predicted_classes = tf.argmax(input=self.prediction, axis=3)
		with tf.Session() as sess:  # start a session
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			for index, image_path in enumerate(predict_file_path):
				# flags=0 loads the image as single-channel grayscale.
				image = np.reshape(a=cv2.imread(image_path, flags=0), newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
				predict_image = sess.run(
					predicted_classes,
					feed_dict={
						self.input_image: image,
						self.keep_prob: 1.0, self.lamb: 0.004
					}
				)
				cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), predict_image[0] * 255)
		print('Done prediction')
开发者ID:USTCzxm,项目名称:U-net,代码行数:27,代码来源:unet-TF.py


注:本文中的tensorflow.local_variables_initializer函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。