
Python tf_inspect.getargspec Method Code Examples

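This article collects typical usage examples of tensorflow.python.util.tf_inspect.getargspec in Python, drawn from open-source projects. tf_inspect provides TFDecorator-aware replacements for functions of Python's inspect module, and getargspec returns an ArgSpec tuple describing a callable's positional argument names, its *args and **kwargs names, and its default values. If you want to see how tf_inspect.getargspec is used in practice, the selected examples below may help. You can also look further into the other functions of the tensorflow.python.util.tf_inspect module.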


Fifteen examples of tf_inspect.getargspec are shown below, ordered by popularity by default.
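The examples read four fields from the result: args, varargs, keywords and defaults. As a rough, stdlib-only illustration of that tuple (the standard library's own inspect.getargspec was removed in Python 3.11), the sketch below builds an equivalent from inspect.getfullargspec. The ArgSpec namedtuple and getargspec_compat are hypothetical stand-ins, not TensorFlow's code, and keyword-only parameters are simply ignored.

```python
import collections
import inspect

# Hypothetical local stand-in with the four fields the examples read; not TF's class.
ArgSpec = collections.namedtuple("ArgSpec", ["args", "varargs", "keywords", "defaults"])


def getargspec_compat(fn):
    """Approximates tf_inspect.getargspec with the standard library.

    Keyword-only parameters and annotations are ignored by this sketch.
    """
    full = inspect.getfullargspec(fn)
    # getfullargspec calls the **kwargs name `varkw`; the legacy tuple calls it `keywords`.
    return ArgSpec(args=full.args, varargs=full.varargs,
                   keywords=full.varkw, defaults=full.defaults)


def example(a, b=1, *rest, **options):
    return a, b, rest, options


spec = getargspec_compat(example)
print(spec)
# ArgSpec(args=['a', 'b'], varargs='rest', keywords='options', defaults=(1,))
```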

Example 1: _model_fn_args

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
# Also uses: from tensorflow.python.util import tf_decorator
def _model_fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  _, fn = tf_decorator.unwrap(fn)
  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
    # Handle functools.partial and similar objects.
    return tuple([
        arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
        if arg not in set(fn.keywords.keys())
    ])
  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 23, Source: estimator.py
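A minimal stdlib sketch of the same idea, assuming inspect.getfullargspec as a substitute for tf_inspect.getargspec and omitting the tf_decorator.unwrap step: for a functools.partial, arguments bound positionally are sliced off the front and keyword-bound ones are filtered out. partial_arg_names and model_fn are hypothetical names, and unlike the duck-typed hasattr check above, this version tests isinstance(fn, functools.partial).

```python
import functools
import inspect


def partial_arg_names(fn):
    """Sketch of _model_fn_args using inspect.getfullargspec instead of tf_inspect.

    For a functools.partial, positionally bound arguments are dropped from the
    front of the list and keyword-bound arguments are filtered out.
    """
    if isinstance(fn, functools.partial):
        bound_keywords = set(fn.keywords)
        remaining = inspect.getfullargspec(fn.func).args[len(fn.args):]
        return tuple(arg for arg in remaining if arg not in bound_keywords)
    return tuple(inspect.getfullargspec(fn).args)


def model_fn(features, labels, mode, params, config):
    return features


print(partial_arg_names(model_fn))
# ('features', 'labels', 'mode', 'params', 'config')
print(partial_arg_names(functools.partial(model_fn, {}, params={"lr": 0.1})))
# ('labels', 'mode', 'config')
```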

Example 2: filter_sk_params

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def filter_sk_params(self, fn, override=None):
    """Filters `sk_params` and return those in `fn`'s arguments.

    Arguments:
        fn : arbitrary function
        override: dictionary, values to override sk_params

    Returns:
        res : dictionary dictionary containing variables
            in both sk_params and fn's arguments.
    """
    override = override or {}
    res = {}
    fn_args = tf_inspect.getargspec(fn)[0]
    for name, value in self.sk_params.items():
      if name in fn_args:
        res.update({name: value})
    res.update(override)
    return res 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 21, Source: scikit_learn.py
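The filtering step can be sketched without TensorFlow, assuming inspect.signature in place of tf_inspect.getargspec. Note that signature parameters also include keyword-only arguments, which getargspec(fn)[0] would not list. filter_params and build_model are hypothetical.

```python
import inspect


def filter_params(fn, params, override=None):
    """Keeps only the entries of `params` that `fn` accepts, then applies `override`.

    Mirrors the filter_sk_params pattern with inspect.signature; `params` is a
    plain dict here rather than the wrapper's `sk_params` attribute.
    """
    accepted = inspect.signature(fn).parameters
    result = {name: value for name, value in params.items() if name in accepted}
    # As in the original, overrides are applied last and are not filtered.
    result.update(override or {})
    return result


def build_model(units=32, dropout=0.0):
    return units, dropout


params = {"units": 64, "epochs": 10, "dropout": 0.5}
print(filter_params(build_model, params, override={"dropout": 0.2}))
# {'units': 64, 'dropout': 0.2}
```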

Example 3: _args

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def _args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
    # Handle functools.partial and similar objects.
    return tuple([
        arg for arg in tf_inspect.getargspec(fn.func).args
        if arg not in set(fn.keywords.keys())
    ])
  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 19, Source: metric_spec.py
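Unlike Example 1, _args only filters keyword-bound arguments, so an argument bound positionally by functools.partial is still reported. The stdlib sketch below shows the difference and, as an alternative technique (not what the code above uses), how inspect.signature resolves partials on its own. metric and bound are hypothetical names.

```python
import functools
import inspect


def metric(predictions, labels, weights=None):
    return predictions, labels, weights


bound = functools.partial(metric, [0.1, 0.9], weights=[1.0, 1.0])

# Keyword-only filtering, as in _args: the positionally bound
# `predictions` is still reported.
spec_names = tuple(arg for arg in inspect.getfullargspec(bound.func).args
                   if arg not in bound.keywords)
print(spec_names)  # ('predictions', 'labels')

# inspect.signature understands partials: positionally bound parameters
# disappear, keyword-bound ones remain (now keyword-only, with a new default).
sig_names = tuple(inspect.signature(bound).parameters)
print(sig_names)  # ('labels', 'weights')
```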

Example 4: expand_thunks

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def expand_thunks(pat):
  """Expands thunks (zero-argument functions) in a pattern by calling them.

  Args:
    pat: The pattern to expand, possibly containing thunks.

  Returns:
    The expanded pattern.
  """
  def is_thunk(x):
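    if callable(x):
      spec = tf_inspect.getargspec(x)
      # Count defaults with len(), not len(set(...)): equal default values
      # (e.g. two Nones) must not be merged, and defaults may be unhashable.
      num_free_args = len(spec.args) - len(spec.defaults or ())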
      return num_free_args == 0
    return False
  while is_thunk(pat):
    pat = pat()
  if isinstance(pat, (tuple, list)):
    return type(pat)(map(expand_thunks, pat))
  return pat
Developer ID: tensorflow, Project: kfac, Lines of code: 25, Source: graph_matcher.py
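A standalone sketch of the thunk test, assuming inspect.getfullargspec instead of tf_inspect.getargspec. It counts defaults with len() rather than len(set(...)), because collapsing equal defaults into a set makes a function such as make_pattern(a=None, b=None) look like it still has a free argument. Keyword-only parameters are not considered; is_thunk and make_pattern are illustrative names.

```python
import inspect


def is_thunk(x):
    """Returns True if `x` can be called with no arguments (a sketch, not kfac's code).

    Keyword-only parameters are ignored by this sketch.
    """
    if not callable(x):
        return False
    spec = inspect.getfullargspec(x)
    # Every positional argument must have a default for a zero-argument call.
    num_free_args = len(spec.args) - len(spec.defaults or ())
    return num_free_args == 0


def make_pattern(a=None, b=None):
    return ("Add", a, b)


print(is_thunk(make_pattern))  # True
# The set-based count sees only one default here, i.e. one "free" argument:
print(len(set(inspect.getfullargspec(make_pattern).defaults)))  # 1
```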

Example 5: has_arg

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Arguments:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: What to return if there is no parameter called `name`
                  but the function accepts a `**kwargs` argument.

  Returns:
      bool, whether `fn` accepts a `name` keyword argument.
  """
  arg_spec = tf_inspect.getargspec(fn)
  if accept_all and arg_spec.keywords is not None:
    return True
  return name in arg_spec.args 
Developer ID: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 18, Source: generic_utils.py
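For comparison, here is a sketch of has_arg built on inspect.signature (a swapped-in technique, not Keras's implementation). It also recognises keyword-only parameters, which the args field of an ArgSpec does not contain, and excludes positional-only parameters, which cannot be passed by keyword. The function named call is a hypothetical example target.

```python
import inspect


def has_arg(fn, name, accept_all=False):
    """Sketch of has_arg using inspect.signature instead of tf_inspect.getargspec."""
    params = inspect.signature(fn).parameters
    # With accept_all, any **kwargs parameter means `name` is accepted.
    if accept_all and any(p.kind is inspect.Parameter.VAR_KEYWORD
                          for p in params.values()):
        return True
    param = params.get(name)
    # Only parameters that can actually be passed by keyword count.
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY)


def call(inputs, *, training=False, **kwargs):
    return inputs


print(has_arg(call, "training"))               # True (keyword-only)
print(has_arg(call, "mask"))                   # False
print(has_arg(call, "mask", accept_all=True))  # True (via **kwargs)
```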

Example 6: function

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def function(inputs, outputs, updates=None, **kwargs):
  """Instantiates a Keras function.

  Arguments:
      inputs: List of placeholder tensors.
      outputs: List of output tensors.
      updates: List of update ops.
      **kwargs: Passed to `tf.Session.run`.

  Returns:
      Output values as Numpy arrays.

  Raises:
      ValueError: if invalid kwargs are passed in.
  """
  if kwargs:
    for key in kwargs:
      if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
          key not in tf_inspect.getargspec(Function.__init__)[0]):
        msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
               'backend') % key
        raise ValueError(msg)
  return Function(inputs, outputs, updates=updates, **kwargs) 
Developer ID: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 25, Source: backend.py
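The validation can be sketched with the standard library, assuming inspect.getfullargspec stands in for tf_inspect.getargspec. Here the allowed names are collected once instead of re-inspecting both functions for every key; validate_kwargs, run and Runner are hypothetical.

```python
import inspect


def validate_kwargs(kwargs, *targets):
    """Raises ValueError for any key that none of `targets` accepts by name.

    `targets` stands in for Session.run and Function.__init__ in the code above.
    """
    allowed = set()
    for target in targets:
        allowed.update(inspect.getfullargspec(target).args)
    unknown = sorted(key for key in kwargs if key not in allowed)
    if unknown:
        raise ValueError("Invalid arguments passed: %s" % ", ".join(unknown))


def run(fetches, feed_dict=None, options=None, run_metadata=None):
    pass


class Runner:
    def __init__(self, inputs, outputs, updates=None, name=None):
        pass


validate_kwargs({"options": None, "name": "f"}, run, Runner.__init__)  # passes
```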

Example 7: call

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def call(self, inputs, mask=None):
    arguments = dict(self.arguments)  # shallow copy, so mask is not stored on the layer
    arg_spec = tf_inspect.getargspec(self.function)
    if 'mask' in arg_spec.args:
      arguments['mask'] = mask
    return self.function(inputs, **arguments) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 8, Source: core.py
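Writing mask into self.arguments directly would keep it on the layer between calls. The sketch below, with a hypothetical LambdaLike class and inspect.signature in place of tf_inspect.getargspec, copies the stored dict before adding the optional argument.

```python
import inspect


class LambdaLike:
    """Hypothetical stand-in for the layer above: wraps `function` plus fixed kwargs."""

    def __init__(self, function, arguments=None):
        self.function = function
        self.arguments = arguments or {}

    def call(self, inputs, mask=None):
        # Copy so that the per-call mask is not written into self.arguments.
        arguments = dict(self.arguments)
        if "mask" in inspect.signature(self.function).parameters:
            arguments["mask"] = mask
        return self.function(inputs, **arguments)


def scale(x, factor=1, mask=None):
    # Keep only the entries whose mask value is True (all of them if no mask).
    return [v * factor for v, m in zip(x, mask or [True] * len(x)) if m]


layer = LambdaLike(scale, {"factor": 2})
print(layer.call([1, 2, 3], mask=[True, False, True]))  # [2, 6]
print(layer.arguments)  # {'factor': 2}
```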

Example 8: check_params

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
# Also uses: import types
def check_params(self, params):
    """Checks for user typos in "params".

    Arguments:
        params: dictionary; the parameters to be checked

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    legal_params_fns = [
        Sequential.fit, Sequential.predict, Sequential.predict_classes,
        Sequential.evaluate
    ]
    if self.build_fn is None:
      legal_params_fns.append(self.__call__)
    elif (not isinstance(self.build_fn, types.FunctionType) and
          not isinstance(self.build_fn, types.MethodType)):
      legal_params_fns.append(self.build_fn.__call__)
    else:
      legal_params_fns.append(self.build_fn)

    legal_params = []
    for fn in legal_params_fns:
      legal_params += tf_inspect.getargspec(fn)[0]
    legal_params = set(legal_params)

    for params_name in params:
      if params_name not in legal_params:
        if params_name != 'nb_epoch':
          raise ValueError('{} is not a legal parameter'.format(params_name)) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 32, Source: scikit_learn.py
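A sketch of collecting the legal parameter names, assuming inspect.signature instead of tf_inspect.getargspec. signature already handles callable instances (and drops the bound self), so the branch that picks build_fn.__call__ is not needed here; legal_param_names, BuildFn and fit are hypothetical.

```python
import inspect


def legal_param_names(*fns):
    """Union of parameter names accepted by `fns` (a sketch, not Keras's code).

    inspect.signature handles plain functions, bound methods and instances
    that define __call__; the bound `self` is omitted automatically.
    """
    names = set()
    for fn in fns:
        names.update(inspect.signature(fn).parameters)
    return names


class BuildFn:
    def __call__(self, hidden_units=16, activation="relu"):
        return hidden_units, activation


def fit(x, y, batch_size=32, epochs=1):
    pass


print(sorted(legal_param_names(fit, BuildFn())))
# ['activation', 'batch_size', 'epochs', 'hidden_units', 'x', 'y']
```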

Example 9: export

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def export(self,
             estimator,
             export_path,
             checkpoint_path=None,
             eval_result=None):
    """Exports the given Estimator to a specific format.

    Args:
      estimator: the Estimator to export.
      export_path: A string containing a directory where to write the export.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the strategy may locate a checkpoint (e.g. the most recent) by itself.
      eval_result: The output of Estimator.evaluate on this checkpoint.  This
        should be set only if checkpoint_path is provided (otherwise it is
        unclear which checkpoint this eval refers to).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if the export_fn does not have the required signature
    """
    # don't break existing export_fns that don't accept checkpoint_path and
    # eval_result
    export_fn_args = tf_inspect.getargspec(self.export_fn).args
    kwargs = {}
    if 'checkpoint_path' in export_fn_args:
      kwargs['checkpoint_path'] = checkpoint_path
    if 'eval_result' in export_fn_args:
      if 'checkpoint_path' not in export_fn_args:
        raise ValueError('An export_fn accepting eval_result must also accept '
                         'checkpoint_path.')
      kwargs['eval_result'] = eval_result

    return self.export_fn(estimator, export_path, **kwargs) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 37, Source: export_strategy.py
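The backward-compatible call can be exercised without TensorFlow. This sketch lifts the logic into a free function, call_export_fn (hypothetical), with inspect.getfullargspec standing in for tf_inspect.getargspec, and shows that an older export_fn without the optional parameters still works.

```python
import inspect


def call_export_fn(export_fn, estimator, export_path,
                   checkpoint_path=None, eval_result=None):
    """Passes optional arguments only if `export_fn` declares them (a sketch).

    Keeps the rule above: an export_fn that takes eval_result must also take
    checkpoint_path.
    """
    accepted = inspect.getfullargspec(export_fn).args
    kwargs = {}
    if "checkpoint_path" in accepted:
        kwargs["checkpoint_path"] = checkpoint_path
    if "eval_result" in accepted:
        if "checkpoint_path" not in accepted:
            raise ValueError("An export_fn accepting eval_result must also "
                             "accept checkpoint_path.")
        kwargs["eval_result"] = eval_result
    return export_fn(estimator, export_path, **kwargs)


def legacy_export(estimator, export_path):
    return export_path


def new_export(estimator, export_path, checkpoint_path=None, eval_result=None):
    return export_path, checkpoint_path, eval_result


print(call_export_fn(legacy_export, None, "/tmp/out", "ckpt-1", {"loss": 0.3}))
# /tmp/out
print(call_export_fn(new_export, None, "/tmp/out", "ckpt-1", {"loss": 0.3}))
# ('/tmp/out', 'ckpt-1', {'loss': 0.3})
```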

Example 10: _get_arguments

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
# Also uses: from tensorflow.python.util import tf_decorator
def _get_arguments(func):
  """Returns a spec of given func."""
  _, func = tf_decorator.unwrap(func)
  if hasattr(func, "__code__"):
    # Regular function.
    return tf_inspect.getargspec(func)
  elif hasattr(func, "__call__"):
    # Callable object.
    return _get_arguments(func.__call__)
  elif hasattr(func, "func"):
    # Partial function.
    return _get_arguments(func.func) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 14, Source: head.py
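Checking for __call__ before the partial case sends a functools.partial into endless recursion: its __call__ is a method-wrapper that has no __code__ but does have its own __call__. The stdlib sketch below (get_arguments, head_fn and Scorer are hypothetical; inspect.getfullargspec replaces tf_inspect.getargspec) resolves partials first and raises TypeError instead of recursing on other built-in callables.

```python
import functools
import inspect


def get_arguments(func):
    """Returns the FullArgSpec of the Python function behind `func` (a sketch)."""
    if hasattr(func, "__code__"):
        # Plain function or bound method (getfullargspec still lists `self`).
        return inspect.getfullargspec(func)
    if isinstance(func, functools.partial):
        # Resolve partials before the generic __call__ case.
        return get_arguments(func.func)
    call = getattr(type(func), "__call__", None)
    if call is not None and hasattr(call, "__code__"):
        # Instance of a Python class that defines __call__.
        return inspect.getfullargspec(func.__call__)
    raise TypeError("cannot inspect %r" % (func,))


def head_fn(features, labels, mode):
    pass


class Scorer:
    def __call__(self, logits, labels):
        pass


print(get_arguments(functools.partial(head_fn, mode="train")).args)
# ['features', 'labels', 'mode']
print(get_arguments(Scorer()).args)
# ['self', 'logits', 'labels']
```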

Example 11: data_parallelism_from_flags

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def data_parallelism_from_flags(daisy_chain_variables=True, all_workers=False):
  """Over which devices do we split each training batch.

  In old-fashioned async mode, we split the batch over all GPUs on the
  current worker.

  In sync mode, we split the batch over all the parameter server GPUs.

  This function returns an expert_utils.Parallelism object, which can be used
  to build the model.  It is configured in a way that any variables created
  by `tf.get_variable` will be assigned to the parameter servers and shared
  between datashards.

  Args:
    daisy_chain_variables: whether to copy variables in a daisy chain on GPUs.
    all_workers: whether the devices are all async workers or just this one.

  Returns:
    an expert_utils.Parallelism object.
  """
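  # Uses tf_inspect.getargspec, matching the import above; the standard
  # library's inspect.getargspec was removed in Python 3.11.
  dp_arg_names = tf_inspect.getargspec(data_parallelism).args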

  blacklist = ["daisy_chain_variables", "all_workers"]

  kwargs = {}
  for arg in dp_arg_names:
    if arg in blacklist:
      continue
    kwargs[arg] = getattr(tf.flags.FLAGS, arg)

  return data_parallelism(
      daisy_chain_variables=daisy_chain_variables,
      all_workers=all_workers,
      **kwargs) 
Developer ID: tensorflow, Project: tensor2tensor, Lines of code: 36, Source: devices.py
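The flags-to-kwargs pattern can be sketched with argparse.Namespace standing in for tf.flags.FLAGS and inspect.getfullargspec for tf_inspect.getargspec. kwargs_from_flags is a hypothetical helper, and the data_parallelism signature below is a hypothetical stand-in, not tensor2tensor's.

```python
import argparse
import inspect


def kwargs_from_flags(fn, flags, exclude=()):
    """Collects a value from `flags` for every named parameter of `fn` (a sketch).

    Names in `exclude` are left for the caller to pass explicitly; a missing
    flag raises AttributeError, as getattr on FLAGS would.
    """
    return {name: getattr(flags, name)
            for name in inspect.getfullargspec(fn).args
            if name not in exclude}


def data_parallelism(daisy_chain_variables=True, all_workers=False,
                     worker_gpu=1, ps_replicas=0):
    # Hypothetical stand-in; the real function builds a Parallelism object.
    return worker_gpu, ps_replicas


flags = argparse.Namespace(worker_gpu=4, ps_replicas=2, all_workers=True)
kwargs = kwargs_from_flags(data_parallelism, flags,
                           exclude=("daisy_chain_variables", "all_workers"))
print(kwargs)  # {'worker_gpu': 4, 'ps_replicas': 2}
print(data_parallelism(daisy_chain_variables=False, all_workers=False, **kwargs))
# (4, 2)
```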

Example 12: _get_arg_spec

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
# Also uses: import six
def _get_arg_spec(f, params):
  args = tf_inspect.getargspec(f).args
  if params is None:
    if not args:
      raise ValueError("When params is None the differentiated function cannot"
                       " only take arguments by *args and **kwds.")
    return range(len(args))
  elif all(isinstance(x, six.string_types) for x in params):
    return [args.index(n) for n in params]
  elif all(isinstance(x, int) for x in params):
    return params
  else:
    raise ValueError(
        "params must be all strings or all integers; got %s." % (params,))
Developer ID: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 16, Source: backprop.py
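A Python 3 sketch of the same index resolution, using str in place of six.string_types and inspect.getfullargspec in place of tf_inspect.getargspec; param_indices and loss are hypothetical. The `% (params,)` form keeps the error message working when params is a tuple.

```python
import inspect


def param_indices(f, params=None):
    """Positions of the arguments to differentiate with respect to (sketch)."""
    args = inspect.getfullargspec(f).args
    if params is None:
        if not args:
            raise ValueError("f takes no named positional arguments.")
        return list(range(len(args)))
    if all(isinstance(p, str) for p in params):
        # Names are mapped to their positions in the signature.
        return [args.index(p) for p in params]
    if all(isinstance(p, int) for p in params):
        return list(params)
    raise ValueError("params must be all strings or all integers; got %s." % (params,))


def loss(w, b, x):
    return w * x + b


print(param_indices(loss))              # [0, 1, 2]
print(param_indices(loss, ["b", "w"]))  # [1, 0]
print(param_indices(loss, [2]))         # [2]
```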

Example 13: _recompute_grad

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
  """See recompute_grad."""
  has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
  for arg in args:
    if not isinstance(arg, framework_ops.Tensor):
      raise ValueError("All inputs to function must be Tensors")
  use_data_dep_ = use_data_dep
  if use_data_dep_ == _USE_DEFAULT:
    use_data_dep_ = _is_on_tpu()

  # Use custom_gradient and return a grad_fn that recomputes on the backwards
  # pass.
  @custom_gradient.custom_gradient
  def fn_with_recompute(*args):
    """Wrapper for fn."""
    # Capture the variable and arg scopes so we can re-enter them when
    # recomputing.
    vs = variable_scope.get_variable_scope()
    arg_scope = contrib_framework_ops.current_arg_scope()
    # Track all variables touched in the function.
    with backprop.GradientTape() as tape:
      fn_kwargs = {}
      if has_is_recompute_kwarg:
        fn_kwargs["is_recomputing"] = False
      outputs = fn(*args, **fn_kwargs)
    original_vars = set(tape.watched_variables())

    def _grad_fn(output_grads, variables=None):
      # Validate that custom_gradient passes the right variables into grad_fn.
      if original_vars:
        assert variables, ("Fn created variables but the variables were not "
                           "passed to the gradient fn.")
        if set(variables) != original_vars:
          raise ValueError(_WRONG_VARS_ERR)

      return _recomputing_grad_fn(
          compute_fn=fn,
          original_args=args,
          original_vars=original_vars,
          output_grads=output_grads,
          grad_fn_variables=variables,
          use_data_dep=use_data_dep_,
          tupleize_grads=tupleize_grads,
          arg_scope=arg_scope,
          var_scope=vs,
          has_is_recompute_kwarg=has_is_recompute_kwarg)

    # custom_gradient inspects the signature of the function to determine
    # whether the user expects variables passed in the grad_fn. If the function
    # created variables, the grad_fn should accept the "variables" kwarg.
    if original_vars:
      def grad_fn(*output_grads, **kwargs):
        return _grad_fn(output_grads, kwargs["variables"])
    else:
      def grad_fn(*output_grads):
        return _grad_fn(output_grads)

    return outputs, grad_fn

  return fn_with_recompute(*args) 
Developer ID: taehoonlee, Project: tensornets, Lines of code: 62, Source: rev_block_lib.py

Example 14: __call__

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.
    Returns:
      Output tensor(s).
    """
    self._set_scope(kwargs.pop('scope', None))

    # Ensure the Layer, if being reused, is working with inputs from
    # the same graph as where it was created.
    try:
      ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
    except ValueError as e:
      raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    with vs.variable_scope(self._scope,
                           reuse=self.built or self._reuse) as scope:
      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          # Check input assumptions set before layer building, e.g. input rank.
          self._assert_input_compatibility(inputs)
          input_list = [
              ops.convert_to_tensor(x, name='input')
              for x in nest.flatten(inputs)]
          input_shapes = [x.get_shape() for x in input_list]
          if len(input_shapes) == 1:
            self.build(input_shapes[0])
          else:
            self.build(input_shapes)
        if 'scope' in tf_inspect.getargspec(self.call).args:
          kwargs['scope'] = scope
        # Check input assumptions set after layer building, e.g. input shape.
        self._assert_input_compatibility(inputs)
        outputs = self.call(inputs, *args, **kwargs)

        # Apply activity regularization.
        # Note that it should be applied every time the layer creates a new
        # output, since it is output-specific.
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
          output_list = _to_list(outputs)
          for output in output_list:
            with ops.name_scope('ActivityRegularizer'):
              activity_regularization = self.activity_regularizer(output)
            self.add_loss(activity_regularization)
            _add_elements_to_collection(
                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    self.built = True
    return outputs 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 58, Source: base.py

Example 15: __call__

# Module to import: from tensorflow.python.util import tf_inspect [as alias]
# Or: from tensorflow.python.util.tf_inspect import getargspec [as alias]
def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("func %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError("Functions with argument defaults or keyword "
                       "arguments are not supported.")

    # Computes how many arguments 'func' has.
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # Bound method: its first argument (self or cls) is already bound,
      # so drop it before counting both the minimum and the maximum.
      argnames = argnames[1:]
    min_args = len(argnames)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types (%d) is not compatible with "
            "the number of arguments of the function." % num)
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs) 
Developer ID: ryfeus, Project: lambda-packs, Lines of code: 62, Source: function.py
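The arity computation can be sketched on its own with inspect.getfullargspec, assumed here as a stand-in for tf_inspect.getargspec. The bound self or cls is dropped before counting so that both the minimum and the maximum reflect it, and None marks an unbounded maximum instead of a large sentinel. positional_arity, Model and combine are hypothetical, and keyword-only parameters are ignored.

```python
import inspect


def positional_arity(func):
    """Returns (min_args, max_args, argnames) for a Defun-style check (a sketch)."""
    spec = inspect.getfullargspec(func)
    if spec.varkw or spec.defaults:
        raise ValueError("Functions with argument defaults or keyword "
                         "arguments are not supported.")
    argnames = spec.args
    if inspect.ismethod(func):
        # getfullargspec still lists the already-bound first argument.
        argnames = argnames[1:]
    min_args = len(argnames)
    max_args = None if spec.varargs else min_args
    return min_args, max_args, argnames


class Model:
    def step(self, x, y):
        return x + y


def combine(a, b, *rest):
    return a


print(positional_arity(Model().step))  # (2, 2, ['x', 'y'])
print(positional_arity(combine))       # (2, None, ['a', 'b'])
```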


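Note: The tensorflow.python.util.tf_inspect.getargspec examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers, and copyright of the source code remains with the original authors. Refer to each project's license before distributing or using the code, and do not reproduce this article without permission.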