本文整理汇总了 Python 中 tensorflow.python.keras.backend.set_session 方法的典型用法代码示例。如果您正苦于以下问题：Python backend.set_session 方法的具体用法？Python backend.set_session 怎么用？Python backend.set_session 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.python.keras.backend 的用法示例。
在下文中一共展示了 backend.set_session 方法的 3 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: onJsonInput
# 需要导入模块: from tensorflow.python.keras import backend [as 别名]
# 或者: from tensorflow.python.keras.backend import set_session [as 别名]
def onJsonInput(self, jsonInput):
    """Classify one 28x28 image delivered as JSON.

    Args:
        jsonInput: Mapping with a 'pixels' entry; its contents are reshaped
            to (1, 28, 28), so it must hold exactly 784 values.

    Returns:
        dict: 'prediction' is the index of the highest score in the first
        output row and 'pixels' echoes the input. When ``self.model`` is
        None the dict is just ``{'prediction': -1}``.
    """
    result = {'prediction': -1}

    # Wrap the single image into a batch of one: (1, 28, 28).
    image_batch = np.reshape([jsonInput['pixels']], (1, 28, 28))
    ue.log('image shape: ' + str(image_batch.shape))

    # Add the trailing channel axis: (N_samples, height, width, N_channels).
    sample_count = len(image_batch)
    x = np.reshape(image_batch, (sample_count, 28, 28, 1))
    ue.log('input shape: ' + str(x.shape))

    # Nothing to run the input through if no model is available.
    if self.model is None:
        ue.log("Warning! No 'model' found. Did training complete?")
        return result

    # Make the stored session the active Keras session before predicting.
    K.set_session(self.session)
    with self.session.as_default():
        output = self.model.predict(x)
        ue.log(output)

        # The predicted class is the position of the largest score.
        scores = output[0]
        best_index = max(range(len(scores)), key=lambda i: scores[i])

        result['prediction'] = best_index
        # Not required, but echoing the input is useful for round tripping.
        result['pixels'] = jsonInput['pixels']

    return result
#expected api: no params forwarded for training? TBC
示例2: onJsonInput
# 需要导入模块: from tensorflow.python.keras import backend [as 别名]
# 或者: from tensorflow.python.keras.backend import set_session [as 别名]
def onJsonInput(self, jsonInput):
    """Run a single 28x28 image from a JSON payload through the model.

    Args:
        jsonInput: Mapping whose 'pixels' entry is reshaped to (1, 28, 28),
            i.e. it has to contain exactly 784 values.

    Returns:
        dict: ``{'prediction': -1}`` if ``self.model`` is None; otherwise
        'prediction' holds the index of the largest value in the first
        output row and 'pixels' holds the original input.
    """
    result = {'prediction': -1}

    # One image -> batch of one with shape (1, 28, 28).
    raw_pixels = jsonInput['pixels']
    batch = np.reshape([raw_pixels], (1, 28, 28))
    ue.log('image shape: ' + str(batch.shape))

    # Network input layout: (N_samples, height, width, N_channels).
    x = np.reshape(batch, (len(batch), 28, 28, 1))
    ue.log('input shape: ' + str(x.shape))

    # Bail out with the default result when there is no model to query.
    if self.model is None:
        ue.log("Warning! No 'model' found. Did training complete?")
        return result

    # Restore the saved session so the model runs against it.
    K.set_session(self.session)
    with self.session.as_default():
        output = self.model.predict(x)
        ue.log(output)

        # Convert the score vector into a class index (position of the max).
        winner = max(enumerate(output[0]), key=lambda pair: pair[1])

        result['prediction'] = winner[0]
        # Unnecessary, but useful for round trip testing.
        result['pixels'] = jsonInput['pixels']

    return result
#expected api: no params forwarded for training? TBC
示例3: keras_reproducible
# 需要导入模块: from tensorflow.python.keras import backend [as 别名]
# 或者: from tensorflow.python.keras.backend import set_session [as 别名]
def keras_reproducible(seed=1234, verbose=0, TF_CPP_MIN_LOG_LEVEL="3"):
    """Seed the RNGs Keras depends on and install a single-threaded TF session.

    Based on the Keras FAQ recipe:
    https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and
            TensorFlow's ``set_random_seed``.
        verbose: When ``0``, the ``TF_CPP_MIN_LOG_LEVEL`` environment variable
            is set, TensorFlow deprecation warnings are switched off and the
            TF logging verbosity is set to ERROR. Other values skip all of that.
        TF_CPP_MIN_LOG_LEVEL: Value written to the environment variable of the
            same name when ``verbose == 0``.

    Raises:
        ImportError: If ``tensorflow`` or its Keras backend cannot be imported.
    """
    import os
    import random

    random.seed(seed)
    np.random.seed(seed)
    # NOTE(review): the original author suspected this may need to be set
    # outside the script to take effect -- confirm for your deployment.
    os.environ["PYTHONHASHSEED"] = "0"

    if verbose == 0:
        # Set before TensorFlow is imported below.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = TF_CPP_MIN_LOG_LEVEL

    try:
        import tensorflow
    except ImportError:
        raise ImportError("Missing required package 'tensorflow'")

    # Use the TF 1.x API. The version comes from the imported module itself
    # rather than from package metadata, so the check also works when
    # TensorFlow was installed under a different distribution name
    # (e.g. tensorflow-gpu) and does not require setuptools at runtime.
    if tensorflow.__version__.startswith("1."):
        tf = tensorflow
    else:
        tf = tensorflow.compat.v1

    if verbose == 0:
        # https://github.com/tensorflow/tensorflow/issues/27023
        try:
            from tensorflow.python.util import deprecation

            deprecation._PRINT_DEPRECATION_WARNINGS = False
        except ImportError:
            # Fall back to the wrapper modules, which expose a per-module limit.
            try:
                from tensorflow.python.util import module_wrapper as deprecation
            except ImportError:
                from tensorflow.python.util import deprecation_wrapper as deprecation
            deprecation._PER_MODULE_WARNING_LIMIT = 0

        # this was deprecated in 1.15 (maybe earlier)
        tensorflow.compat.v1.logging.set_verbosity(tensorflow.compat.v1.logging.ERROR)

    # One thread for both intra-op and inter-op parallelism.
    session_conf = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
    )

    with capture_all():  # doesn't have quiet option
        try:
            from tensorflow.python.keras import backend as K
        except ImportError:
            raise ImportError("Missing required module 'keras'")

    tf.set_random_seed(seed)
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    K.set_session(sess)