當前位置: 首頁>>代碼示例>>Python>>正文


Python backend.set_session方法代碼示例

本文整理匯總了Python中tensorflow.python.keras.backend.set_session方法的典型用法代碼示例。如果您正苦於以下問題:Python backend.set_session方法的具體用法?Python backend.set_session怎麽用?Python backend.set_session使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow.python.keras.backend的用法示例。


在下文中一共展示了backend.set_session方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: onJsonInput

# 需要導入模塊: from tensorflow.python.keras import backend [as 別名]
# 或者: from tensorflow.python.keras.backend import set_session [as 別名]
def onJsonInput(self, jsonInput):
		"""Classify one 28x28 image received as JSON and return the result dict.

		``jsonInput['pixels']`` is reshaped to a single-sample (1, 28, 28, 1)
		tensor and passed to ``self.model.predict`` while ``self.session`` is
		the active Keras/TF session.

		Returns a dict with:
		  'prediction': index of the largest entry in the first output row,
		                or -1 when ``self.model`` is None.
		  'pixels':     the input pixels echoed back (only on success).
		"""
		response = {'prediction': -1}

		# a single 28x28 image
		image = np.reshape([jsonInput['pixels']], (1, 28, 28))
		ue.log('image shape: ' + str(image.shape))

		# (N_samples, height, width, N_channels) layout for the network input
		net_input = np.reshape(image, (len(image), 28, 28, 1))
		ue.log('input shape: ' + str(net_input.shape))

		# guard: nothing to run until a model has been trained/loaded
		if self.model is None:
			ue.log("Warning! No 'model' found. Did training complete?")
			return response

		# make the stored session the current Keras session before predicting
		K.set_session(self.session)

		with self.session.as_default():
			scores = self.model.predict(net_input)
			ue.log(scores)

			# position of the highest score in the first (only) sample row;
			# presumably one score per class — confirm against the model
			best_index, _ = max(enumerate(scores[0]), key=operator.itemgetter(1))

			response['prediction'] = best_index
			response['pixels'] = jsonInput['pixels']  # echoed back for round tripping

		return response

	#expected api: no params forwarded for training? TBC 
開發者ID:getnamo,項目名稱:tensorflow-ue4-examples,代碼行數:40,代碼來源:mnistKerasCNN.py

示例2: onJsonInput

# 需要導入模塊: from tensorflow.python.keras import backend [as 別名]
# 或者: from tensorflow.python.keras.backend import set_session [as 別名]
def onJsonInput(self, jsonInput):
		"""Run the loaded model on one JSON-supplied 28x28 image.

		The flat ``jsonInput['pixels']`` payload is reshaped into one
		(28, 28) image and then into a (1, 28, 28, 1) batch, which is fed to
		``self.model.predict`` under ``self.session``.

		The returned dict always has ``'prediction'``. It is -1 if no model
		is available; otherwise it is the index of the maximum value in the
		first output row, and ``'pixels'`` echoes the input.
		"""
		result = {'prediction': -1}

		raw_image = np.reshape([jsonInput['pixels']], (1, 28, 28))
		ue.log('image shape: ' + str(raw_image.shape))

		# append the channel axis: (samples, height, width, channels)
		batch = np.reshape(raw_image, (len(raw_image), 28, 28, 1))
		ue.log('input shape: ' + str(batch.shape))

		if self.model is None:
			ue.log("Warning! No 'model' found. Did training complete?")
			return result

		# activate the session the model was built/trained in
		K.set_session(self.session)
		with self.session.as_default():
			output = self.model.predict(batch)
			ue.log(output)

			# argmax over the first row; max() keeps the first index on ties
			first_row = output[0]
			result['prediction'] = max(range(len(first_row)), key=first_row.__getitem__)
			result['pixels'] = jsonInput['pixels']  # echo for round trip testing

		return result

	#expected api: no params forwarded for training? TBC 
開發者ID:getnamo,項目名稱:tensorflow-ue4-examples,代碼行數:40,代碼來源:mnistKerasCNNOpt.py

示例3: keras_reproducible

# 需要導入模塊: from tensorflow.python.keras import backend [as 別名]
# 或者: from tensorflow.python.keras.backend import set_session [as 別名]
def keras_reproducible(seed=1234, verbose=0, TF_CPP_MIN_LOG_LEVEL="3"):
    """Seed Python/NumPy/TensorFlow and install a single-threaded Keras session.

    Follows the Keras FAQ recipe:
    https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and
            TensorFlow's ``set_random_seed``.
        verbose: When 0, the ``TF_CPP_MIN_LOG_LEVEL`` environment variable is
            set and TensorFlow's Python-side deprecation/logging output is
            suppressed.
        TF_CPP_MIN_LOG_LEVEL: Value written to the ``TF_CPP_MIN_LOG_LEVEL``
            environment variable when ``verbose == 0`` ("2" will print warnings).

    Raises:
        ImportError: If tensorflow or its bundled keras backend cannot be imported.
    """
    import os
    import random

    random.seed(seed)
    np.random.seed(seed)

    # might need to do this outside the script: the running interpreter's
    # hash seed is fixed at startup, so this mainly affects child processes
    os.environ["PYTHONHASHSEED"] = "0"

    if verbose == 0:
        # must be set before tensorflow is imported below to take effect
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = TF_CPP_MIN_LOG_LEVEL

    try:
        import tensorflow
    except ImportError as err:
        raise ImportError("Missing required package 'tensorflow'") from err

    # Use the TF 1.x API. The version is read from the imported module instead
    # of pkg_resources.get_distribution("tensorflow"), which raises
    # DistributionNotFound when TF is installed as e.g. tensorflow-gpu.
    if tensorflow.__version__.startswith("1."):
        tf = tensorflow
    else:
        tf = tensorflow.compat.v1

    if verbose == 0:
        # https://github.com/tensorflow/tensorflow/issues/27023
        try:
            from tensorflow.python.util import deprecation

            deprecation._PRINT_DEPRECATION_WARNINGS = False
        except ImportError:
            try:
                from tensorflow.python.util import module_wrapper as deprecation
            except ImportError:
                from tensorflow.python.util import deprecation_wrapper as deprecation
            deprecation._PER_MODULE_WARNING_LIMIT = 0

        # go through the selected API object so this also works on TF 1.x
        # releases that predate tensorflow.compat.v1
        tf.logging.set_verbosity(tf.logging.ERROR)

    # one thread per op and between ops removes a source of nondeterministic
    # scheduling (per the Keras FAQ recipe linked above)
    session_conf = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
    )

    with capture_all():  # doesn't have quiet option
        try:
            from tensorflow.python.keras import backend as K
        except ImportError as err:
            raise ImportError("Missing required module 'keras'") from err

    tf.set_random_seed(seed)
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    K.set_session(sess)
開發者ID:microsoft,項目名稱:SparseSC,代碼行數:62,代碼來源:match_space.py


注:本文中的tensorflow.python.keras.backend.set_session方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。