本文整理汇总了Python中tensorflow.contrib.tensor_forest.python.tensor_forest.RandomForestDeviceAssigner类的典型用法代码示例。如果您正苦于以下问题：Python tensor_forest.RandomForestDeviceAssigner类的具体用法？Python tensor_forest.RandomForestDeviceAssigner怎么用？Python tensor_forest.RandomForestDeviceAssigner使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorflow.contrib.tensor_forest.python.tensor_forest的用法示例。
在下文中一共展示了tensor_forest.RandomForestDeviceAssigner类的7个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __init__
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def __init__(self,
             params,
             device_assigner=None,
             optimizer_class=adagrad.AdagradOptimizer,
             **kwargs):
    """Set up device placement, the optimizer and an optional regularizer.

    Args:
      params: Hyperparameter object. This constructor reads `learning_rate`,
        `regression`, `regularization` and `regularization_strength` from it.
      device_assigner: Object that controls tree-to-device placement. Any
        falsy value (including the default `None`) is replaced by a new
        `tensor_forest.RandomForestDeviceAssigner()`.
      optimizer_class: Optimizer constructor, called with
        `params.learning_rate`. Defaults to `adagrad.AdagradOptimizer`.
      **kwargs: Accepted for signature compatibility; not used here.
    """
    if device_assigner:
        self.device_assigner = device_assigner
    else:
        self.device_assigner = tensor_forest.RandomForestDeviceAssigner()
    self.params = params
    self.optimizer = optimizer_class(self.params.learning_rate)
    self.is_regression = params.regression

    # Only "l1" and "l2" select a regularizer; any other value leaves it None.
    self.regularizer = None
    regularization_kind = params.regularization
    if regularization_kind == "l1":
        build_regularizer = layers.l1_regularizer
    elif regularization_kind == "l2":
        build_regularizer = layers.l2_regularizer
    else:
        build_regularizer = None
    if build_regularizer is not None:
        self.regularizer = build_regularizer(
            self.params.regularization_strength)
示例2: __init__
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def __init__(self, params, device_assigner=None, model_dir=None,
             graph_builder_class=tensor_forest.RandomForestGraphs,
             master='', accuracy_metric=None,
             tf_random_seed=None, config=None):
    """Store random-forest settings and initialise the base estimator.

    Args:
      params: Hyperparameter object; `params.fill()` is stored as
        `self.params`.
      device_assigner: Device-placement object. A falsy value is replaced by
        a new `tensor_forest.RandomForestDeviceAssigner()`.
      model_dir: Forwarded to the parent constructor.
      graph_builder_class: Stored as `self.graph_builder_class`.
      master: Accepted but not used by this constructor.
      accuracy_metric: Metric name. When falsy it defaults to 'r2' for
        regression forests and 'accuracy' otherwise.
      tf_random_seed: Accepted but not used by this constructor.
      config: Forwarded to the parent constructor.
    """
    self.params = params.fill()
    if accuracy_metric:
        self.accuracy_metric = accuracy_metric
    elif self.params.regression:
        self.accuracy_metric = 'r2'
    else:
        self.accuracy_metric = 'accuracy'
    self.data_feeder = None
    self.device_assigner = (device_assigner if device_assigner
                            else tensor_forest.RandomForestDeviceAssigner())
    self.graph_builder_class = graph_builder_class
    # Both start empty; filled in elsewhere.
    self.training_args, self.construction_args = {}, {}
    super(TensorForestEstimator, self).__init__(
        model_dir=model_dir, config=config)
示例3: __init__
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def __init__(self, params, layer_num, device_assigner, *args, **kwargs):
    """Record the layer settings, then define this layer's variables.

    Args:
      params: Hyperparameters; stored on the instance and passed to
        `self._define_vars`.
      layer_num: Index of this layer, stored as `self.layer_num`.
      device_assigner: Device-placement object. A falsy value is replaced by
        a new `tensor_forest.RandomForestDeviceAssigner()`.
      *args: Accepted but ignored.
      **kwargs: Forwarded unchanged to `self._define_vars`.
    """
    self.layer_num = layer_num
    assigner = device_assigner
    if not assigner:
        assigner = tensor_forest.RandomForestDeviceAssigner()
    self.device_assigner = assigner
    self.params = params
    self._define_vars(params, **kwargs)
示例4: export
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def export(self,
           export_dir,
           input_fn,
           signature_fn=None,
           input_feature_key=None,
           default_batch_size=1):
    """See BaseEstimator.export.

    Temporarily replaces the wrapped estimator's model function with one built
    around a fresh `tensor_forest.RandomForestDeviceAssigner()`, exports, and
    restores the original model function afterwards — also when the export
    raises.

    Args:
      export_dir: Directory to export into.
      input_fn: Input function, forwarded with `use_deprecated_input_fn=False`.
      signature_fn: Signature function. If falsy, defaults to
        `export.regression_signature_fn` for regression forests and
        `export.classification_signature_fn_with_prob` otherwise.
      input_feature_key: Forwarded unchanged.
      default_batch_size: Forwarded unchanged.

    Returns:
      The value returned by `self._estimator.export`.
    """
    # Reset model function with basic device assigner.
    # Servo doesn't support distributed inference
    # but it will try to respect device assignments if they're there.
    # pylint: disable=protected-access
    orig_model_fn = self._estimator._model_fn
    self._estimator._model_fn = get_model_fn(
        self.params, self.graph_builder_class,
        tensor_forest.RandomForestDeviceAssigner(),
        weights_name=self.weights_name)
    try:
        return self._estimator.export(
            export_dir=export_dir,
            input_fn=input_fn,
            input_feature_key=input_feature_key,
            use_deprecated_input_fn=False,
            signature_fn=(signature_fn or
                          (export.regression_signature_fn
                           if self.params.regression else
                           export.classification_signature_fn_with_prob)),
            default_batch_size=default_batch_size,
            prediction_key=eval_metrics.INFERENCE_PROB_NAME)
    finally:
        # Always restore, so a failed export does not leave the estimator
        # permanently configured with the export-only model function.
        self._estimator._model_fn = orig_model_fn
    # pylint: enable=protected-access
示例5: export
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def export(self,
           export_dir,
           input_fn,
           signature_fn=None,
           default_batch_size=1):
    """See BaseEstimator.export.

    Temporarily replaces the wrapped estimator's model function with one built
    around a fresh `tensor_forest.RandomForestDeviceAssigner()`, exports using
    the deprecated input_fn convention, and restores the original model
    function afterwards — also when the export raises.

    Args:
      export_dir: Directory to export into.
      input_fn: Input function, forwarded with `use_deprecated_input_fn=True`.
        Previously this argument was accepted but silently dropped.
      signature_fn: Signature function. If falsy, defaults to
        `export.regression_signature_fn` for regression forests and
        `export.classification_signature_fn_with_prob` otherwise.
      default_batch_size: Forwarded unchanged.

    Returns:
      The value returned by `self._estimator.export`.
    """
    # Reset model function with basic device assigner.
    # Servo doesn't support distributed inference
    # but it will try to respect device assignments if they're there.
    # pylint: disable=protected-access
    orig_model_fn = self._estimator._model_fn
    self._estimator._model_fn = get_model_fn(
        self.params, self.graph_builder_class,
        tensor_forest.RandomForestDeviceAssigner(),
        weights_name=self.weights_name)
    try:
        return self._estimator.export(
            export_dir=export_dir,
            input_fn=input_fn,
            use_deprecated_input_fn=True,
            signature_fn=(signature_fn or
                          (export.regression_signature_fn
                           if self.params.regression else
                           export.classification_signature_fn_with_prob)),
            default_batch_size=default_batch_size,
            prediction_key=eval_metrics.INFERENCE_PROB_NAME)
    finally:
        # Always restore, so a failed export does not leave the estimator
        # permanently configured with the export-only model function.
        self._estimator._model_fn = orig_model_fn
    # pylint: enable=protected-access
示例6: __init__
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def __init__(self, params, device_assigner=None, model_dir=None,
             graph_builder_class=tensor_forest.RandomForestGraphs,
             config=None, weights_name=None, keys_name=None,
             feature_engineering_fn=None, early_stopping_rounds=100,
             num_trainers=1, trainer_id=0):
    """Initializes a TensorForestEstimator instance.

    Args:
      params: ForestHParams object that holds random forest hyperparameters.
        `params.fill()` is stored as `self.params` and that filled object is
        what gets passed into `model_fn`.
      device_assigner: An `object` instance that controls how trees get
        assigned to devices. It is forwarded as-is to `get_model_fn`; when it
        is `None`, `tensor_forest.RandomForestDeviceAssigner` is expected to
        be used downstream (not enforced in this constructor).
      model_dir: Directory to save model parameters, graph, etc. To continue
        training a previously saved model, load checkpoints saved to this
        directory into an estimator.
      graph_builder_class: An `object` instance that defines how TF graphs for
        random forest training and inference are built. By default will use
        `tensor_forest.RandomForestGraphs`.
      config: `RunConfig` object to configure the runtime settings.
      weights_name: A string defining feature column name representing
        weights. Will be multiplied by the loss of the example. Used to
        downweight or boost examples during training.
      keys_name: A string defining feature column name representing example
        keys. Used by `predict_with_keys` method.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      early_stopping_rounds: Allows training to terminate early if the forest
        is no longer growing. 100 by default.
      num_trainers: Number of training jobs, which will partition trees
        among them.
      trainer_id: Which trainer this instance is.
    """
    self.params = params.fill()
    self.graph_builder_class = graph_builder_class
    self.early_stopping_rounds = early_stopping_rounds
    self.weights_name = weights_name
    # Build the model function from the filled-in params kept on self, not
    # the caller's raw `params`, so every consumer sees the same values.
    self._estimator = estimator.Estimator(
        model_fn=get_model_fn(self.params, graph_builder_class,
                              device_assigner, weights_name=weights_name,
                              keys_name=keys_name, num_trainers=num_trainers,
                              trainer_id=trainer_id),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
    self._skcompat = estimator.SKCompat(self._estimator)
示例7: __init__
# 需要导入模块: from tensorflow.contrib.tensor_forest.python import tensor_forest [as 别名]
# 或者: from tensorflow.contrib.tensor_forest.python.tensor_forest import RandomForestDeviceAssigner [as 别名]
def __init__(self, params, device_assigner=None, model_dir=None,
             graph_builder_class=tensor_forest.RandomForestGraphs,
             config=None, weights_name=None, keys_name=None,
             feature_engineering_fn=None, early_stopping_rounds=100):
    """Initializes a TensorForestEstimator instance.

    Args:
      params: ForestHParams object that holds random forest hyperparameters.
        `params.fill()` is stored as `self.params` and that filled object is
        what gets passed into `model_fn`.
      device_assigner: An `object` instance that controls how trees get
        assigned to devices. It is forwarded as-is to `get_model_fn`; when it
        is `None`, `tensor_forest.RandomForestDeviceAssigner` is expected to
        be used downstream (not enforced in this constructor).
      model_dir: Directory to save model parameters, graph, etc. To continue
        training a previously saved model, load checkpoints saved to this
        directory into an estimator.
      graph_builder_class: An `object` instance that defines how TF graphs for
        random forest training and inference are built. By default will use
        `tensor_forest.RandomForestGraphs`.
      config: `RunConfig` object to configure the runtime settings.
      weights_name: A string defining feature column name representing
        weights. Will be multiplied by the loss of the example. Used to
        downweight or boost examples during training.
      keys_name: A string defining feature column name representing example
        keys. Used by `predict_with_keys` method.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      early_stopping_rounds: Allows training to terminate early if the forest
        is no longer growing. 100 by default.
    """
    self.params = params.fill()
    self.graph_builder_class = graph_builder_class
    self.early_stopping_rounds = early_stopping_rounds
    self.weights_name = weights_name
    # Build the model function from the filled-in params kept on self, not
    # the caller's raw `params`, so every consumer sees the same values.
    self._estimator = estimator.Estimator(
        model_fn=get_model_fn(self.params, graph_builder_class,
                              device_assigner, weights_name=weights_name,
                              keys_name=keys_name),
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)