当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.load_op_library方法代码示例

本文整理汇总了Python中tensorflow.load_op_library方法的典型用法代码示例。如果您想了解tensorflow.load_op_library方法的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


下文共展示了tensorflow.load_op_library方法的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: testShuffle

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def testShuffle(self):
    """Compare the custom ``shuffle`` op against plain NumPy reshapes.

    Two target shapes are checked on the same 3x4 ``arange`` input:
    ``[6, -1]`` must equal an exact (6, 2) reshape, and ``[5, -1]`` must
    equal that (6, 2) reshape with its last row dropped.
    """
    shuffle = tf.load_op_library('shuffle_op.so').shuffle

    source = np.arange(12).reshape((3, 4))
    cases = (
        (np.array([6, -1]), source.reshape((6, 2))),
        (np.array([5, -1]), source.reshape((6, 2))[:-1]),
    )
    for target_shape, expected in cases:
        with self.test_session():
            actual = shuffle(source, target_shape)
            self.assertAllEqual(actual.eval(), expected)
开发者ID:jhetherly,项目名称:EnglishSpeechUpsampler,代码行数:19,代码来源:shuffle_op_test.py

示例2: _load_library

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def _load_library(filename, lib="op"):
  """Load a shared library that lives next to the calling module.

  Args:
    filename: Name of the shared object, resolved relative to the directory
      of the module that called this function.
    lib: "op" loads the file with `tf.load_op_library`; any other value loads
      it with `tf.compat.v1.load_file_system_library`.

  Returns:
    For "op", the loaded op module. Otherwise a bool that is True when
    `load_file_system_library` returned None (treated as "loaded").

  Raises:
    NotImplementedError: if no candidate path could be loaded; the collected
      `errors.NotFoundError` messages are included.
  """
  # Frame 1 is the caller of this function, so the path is resolved
  # relative to *its* source file rather than this one.
  caller_file = inspect.getfile(sys._getframe(1)) # pylint: disable=protected-access
  filenames = [os.path.join(os.path.dirname(caller_file), filename)]

  if lib == "op":
    load_fn = tf.load_op_library
  else:
    def load_fn(path):
      # A None result from the file-system loader counts as success.
      return tf.compat.v1.load_file_system_library(path) is None

  # Try each candidate path; give up only when all of them failed.
  errs = []
  for candidate in filenames:
    try:
      loaded = load_fn(candidate)
    except errors.NotFoundError as e:
      errs.append(str(e))
      continue
    if loaded is not None:
      return loaded
  raise NotImplementedError(
      "unable to open file: " +
      "{}, from paths: {}\ncaused by: {}".format(filename, filenames, errs))
开发者ID:Kaggle,项目名称:docker-python,代码行数:26,代码来源:__init__.py

示例3: register_custom_kernels

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def register_custom_kernels() -> None:
    """Load every custom-op shared object shipped with TensorFlow Addons.

    Raises:
        FileNotFoundError: if no shared objects were found in the custom ops
            directory.
        RuntimeError: if loading a shared object raises
            ``tf.errors.NotFoundError``; the original error is chained as
            ``__cause__`` and the failing path is named in the message.
    """
    all_shared_objects = _get_all_shared_objects()
    if not all_shared_objects:
        raise FileNotFoundError(
            "No shared objects files were found in the custom ops "
            "directory in Tensorflow Addons, check your installation again, "
            "or, if you don't need custom ops, call `tfa.register_all(custom_kernels=False)`"
            " instead."
        )
    for shared_object in all_shared_objects:
        # Keep the try scoped to a single load so the error message always
        # names the object that actually failed.
        try:
            tf.load_op_library(shared_object)
        except tf.errors.NotFoundError as e:
            raise RuntimeError(
                "One of the shared objects ({}) could not be loaded. This may be "
                "due to a number of reasons (incompatible TensorFlow version, building from "
                "source with different flags, broken install of TensorFlow Addons...). If you "
                "wanted to register the shared objects because you needed them when loading your "
                "model, you should fix your install of TensorFlow Addons. If you don't "
                "use custom ops in your model, you can skip registering custom ops with "
                "`tfa.register_all(custom_kernels=False)`".format(shared_object)
            ) from e
开发者ID:tensorflow,项目名称:addons,代码行数:24,代码来源:register.py

示例4: _custom_cpp_op

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def _custom_cpp_op(op: CompilableOp, stateful, name):
    """ Compiles and registers a custom C++ Tensorflow operator.

    Args:
        op: Operator description (name, source files, inputs/outputs,
            CMake options, definitions, output folder) to compile.
        stateful: Not used by this function.
        name: Not used by this function.

    Returns:
        A ``TFCompiledOp`` wrapping ``op``, the forward op function
        (``tf_op<name>``), the gradient op function (``tf_op_grad<name>``)
        and the ctypes handle to the compiled library.

    Raises:
        ValueError: if the compiled library does not export all of
            ``create_new_op``, ``is_cuda_supported`` and ``report``.
    """
    # Compile the .so file, passing the running TF install's directory to CMake.
    tf_path = os.path.abspath(os.path.dirname(tf.__file__))
    # Any CUDA source among the files switches on CUDA compilation.
    uses_cuda = any(f.endswith('.cu') for f in op.files)

    so_file = TFCompiler().compile_op(op.name, op.files,
        op.inputs, op.outputs,
        uses_cuda, op.live_output,
        additional_cmake_options=['-DTENSORFLOW_PATH=' + tf_path] + op.cmake_options,
        additional_definitions=op.defs, output_folder=op.output_folder)

    # Load the compiled library into Tensorflow
    op_module = tf.load_op_library(so_file)
    op_func = getattr(op_module, 'tf_op' + op.name)
    op_grad_func = getattr(op_module, 'tf_op_grad' + op.name)

    # Create the deep500 custom op object. Validate every exported symbol
    # used below so a malformed library fails with ValueError instead of an
    # AttributeError from ctypes partway through setup.
    lib = ctypes.CDLL(so_file)
    for symbol in ('create_new_op', 'is_cuda_supported', 'report'):
        if not getattr(lib, symbol, False):
            raise ValueError('Invalid custom operator library file')
    lib.create_new_op.restype = ctypes.c_int64
    lib.is_cuda_supported.restype = ctypes.c_bool
    lib.report.restype = ctypes.c_int64

    return TFCompiledOp(op, op_func, op_grad_func, lib)
开发者ID:deep500,项目名称:deep500,代码行数:27,代码来源:tf.py

示例5: setupCTC

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def setupCTC(self):
    """ Create CTC loss and decoder and return them

    Under the ``CTC_Loss`` name scope this builds placeholders for the sparse
    ground-truth labels and the per-sample sequence lengths, plus the mean
    CTC loss. Under ``CTC_Decoder`` it builds the decoder selected by
    ``self.decoderType``:

    * ``BestPath``: ``tf.nn.ctc_greedy_decoder``
    * ``BeamSearch``: ``tf.nn.ctc_beam_search_decoder`` (beam width 50)
    * ``WordBeamSearch``: custom op loaded from ``./TFWordBeamSearch.so``

    Returns:
        ``(self.loss, self.decoder)``.
    """
    # BxTxC -> TxBxC
    self.ctcIn3dTBC = tf.transpose(self.rnnOut3d, [1, 0, 2])

    # Ground truth text as sparse tensor (indices, values, dense_shape)
    with tf.name_scope('CTC_Loss'):
        self.gtTexts = tf.SparseTensor(
            tf.placeholder(tf.int64, shape=[None, 2]),
            tf.placeholder(tf.int32, [None]),
            tf.placeholder(tf.int64, [2]))
        # Calculate loss for batch
        self.seqLen = tf.placeholder(tf.int32, [None])
        self.loss = tf.reduce_mean(tf.nn.ctc_loss(
            labels=self.gtTexts, inputs=self.ctcIn3dTBC,
            sequence_length=self.seqLen, ctc_merge_repeated=True,
            ignore_longer_outputs_than_inputs=True))
    with tf.name_scope('CTC_Decoder'):
        # Decoder: Best path decoding or Word beam search decoding
        if self.decoderType == DecoderType.BestPath:
            self.decoder = tf.nn.ctc_greedy_decoder(
                inputs=self.ctcIn3dTBC, sequence_length=self.seqLen)
        elif self.decoderType == DecoderType.BeamSearch:
            self.decoder = tf.nn.ctc_beam_search_decoder(
                inputs=self.ctcIn3dTBC, sequence_length=self.seqLen,
                beam_width=50, merge_repeated=True)
        elif self.decoderType == DecoderType.WordBeamSearch:
            # Import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
            word_beam_search_module = tf.load_op_library(
                './TFWordBeamSearch.so')

            # Prepare: dictionary, characters in dataset, characters forming words.
            # Each file is read inside a `with` block so its handle is closed.
            # NOTE(review): `FilePaths.wordCharList.txt` and `FilePaths.corpus.txt`
            # read a `.txt` attribute off another attribute, which looks like a
            # mangled file-path constant (compare `FilePaths.fnWordCharList`).
            # FilePaths is not visible here -- confirm these attributes exist.
            with codecs.open(FilePaths.wordCharList.txt, 'r') as f:
                chars = f.read()
            with codecs.open(FilePaths.fnWordCharList, 'r') as f:
                wordChars = f.read()
            with codecs.open(FilePaths.corpus.txt, 'r') as f:
                corpus = f.read()

            # # Decoder using the "NGramsForecastAndSample": restrict number of (possible) next words to at most 20 words: O(W) mode of word beam search
            # decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(ctcIn3dTBC, dim=2), 25, 'NGramsForecastAndSample', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))

            # Decoder using the "Words": only use dictionary, no scoring: O(1) mode of word beam search
            self.decoder = word_beam_search_module.word_beam_search(
                tf.nn.softmax(self.ctcIn3dTBC, dim=2), 25, 'Words', 0.0,
                corpus.encode('utf8'), chars.encode('utf8'),
                wordChars.encode('utf8'))

    # Return a CTC operation to compute the loss and CTC operation to decode the RNN output
    return self.loss, self.decoder
开发者ID:sushant097,项目名称:Handwritten-Line-Text-Recognition-using-Deep-Learning-with-Tensorflow,代码行数:42,代码来源:Model.py

示例6: ops

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def ops(self):
    """Return the custom-op module, loading and caching it on first use.

    When ``SKIP_CUSTOM_OPS`` is set, the current pytest test is skipped
    instead. Otherwise the library at ``self.relative_path`` is loaded once,
    after ``display_warning_if_incompatible`` has run, and kept in
    ``self._ops`` for later calls.
    """
    if SKIP_CUSTOM_OPS:
        import pytest

        pytest.skip(
            "Skipping the test because a custom ops "
            "was being loaded while --skip-custom-ops was set."
        )
    if self._ops is not None:
        return self._ops
    self.display_warning_if_incompatible()
    self._ops = tf.load_op_library(get_path_to_datafile(self.relative_path))
    return self._ops
开发者ID:tensorflow,项目名称:addons,代码行数:14,代码来源:resource_loader.py

示例7: test_get_all_shared_objects

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def test_get_all_shared_objects():
    """At least four custom-op shared objects exist and each one loads."""
    if resource_loader.SKIP_CUSTOM_OPS:
        pytest.skip(
            "Skipping the test because a custom ops "
            "was being loaded while --skip-custom-ops was set."
        )
    shared_objects = _get_all_shared_objects()
    assert len(shared_objects) >= 4

    # Loading raises on failure, which is what fails the test.
    for so_path in shared_objects:
        tf.load_op_library(so_path)
开发者ID:tensorflow,项目名称:addons,代码行数:13,代码来源:register_test.py

示例8: f_segm_match

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def f_segm_match(iou, s_gt):
  """Matching between segmentation output and groundtruth.

  The custom Hungarian op is loaded from ./hungarian.so on the first call and
  cached in the module-level ``hungarian_module``. The IoU matrix is masked
  by ``s_gt`` along both of its last two axes and rounded to 1e-6. ``eps`` is
  then added, and the op's first output is masked the same way.

  Args:
    iou: IoU scores between segmentation outputs and groundtruth. Both masks
      are built from ``s_gt``, so its last two dims must broadcast against
      [T, T] (presumably [B, T, T] -- TODO confirm).
    s_gt: [B, T], groudtruth score sequence

  Returns:
    match: first output of the Hungarian op, masked by ``s_gt`` on both axes.
  """
  global hungarian_module
  if hungarian_module is None:
    mod_name = './hungarian.so'
    hungarian_module = tf.load_op_library(mod_name)
    log.info('Loaded library "{}"'.format(mod_name))

  # Mask X, [B, M] => [B, 1, M]
  mask_x = tf.expand_dims(s_gt, dim=1)
  # Mask Y, [B, M] => [B, M, 1]
  mask_y = tf.expand_dims(s_gt, dim=2)
  iou_mask = iou * mask_x * mask_y

  # Keep certain precision so that we can get optimal matching within
  # reasonable time.
  eps = 1e-5
  precision = 1e6
  iou_mask = tf.round(iou_mask * precision) / precision
  # NOTE(review): eps is added before matching, presumably so no weight is
  # exactly zero -- confirm against the hungarian op's expectations.
  match_eps = hungarian_module.hungarian(iou_mask + eps)[0]

  # Mask the graph algorithm output.
  match = match_eps * mask_x * mask_y

  return match
开发者ID:renmengye,项目名称:rec-attend-public,代码行数:36,代码来源:modellib.py

示例9: find_kaldi_io_library

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def find_kaldi_io_library():
    """Check that libtf_kaldi_io.so can be found. If it can, ensure that
    Tensorflow's tf.load_op_library() can find it by potentially adding it to
    the LD_LIBRARY_PATH as necessary.

    If it is not found, raise a helpful and informative error.

    Returns:
        The entry for "libtf_kaldi_io.so" in the mapping returned by
        ``find_shared_library("tf_kaldi_io")``.

    Raises:
        RuntimeError: with ``MISSING_LIBRARY_ERROR`` when
            ``find_shared_library`` returns nothing.
    """
    try:
        libtf_kaldi_io = resource_filename(__package__, "libtf_kaldi_io.so")
        found = os.path.isfile(libtf_kaldi_io)
    except ImportError:
        # If we can't import tf_kaldi_io, definitely can't get its resources.
        found = False

    if found:
        # If we have a libtf_kaldi_io.so from the tf_kaldi_io Python package,
        # then ensure it gets on the path. We stick it on the front of the
        # path, because it would be confusing if a tf_kaldi_io package used a
        # libtf_kaldi_io.so that didn't correspond to it, just because the user
        # happened to have a custom LD_LIBRARY_PATH set.
        # NOTE(review): the dynamic loader reads LD_LIBRARY_PATH at process
        # start, so this presumably matters for find_shared_library (not
        # visible here) and child processes rather than dlopen in this
        # process -- confirm.
        lib_dir = os.path.dirname(libtf_kaldi_io)
        old_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
        # Only join with ":" when there is a previous value: an empty entry
        # (as in "dir:") makes the loader also search the current directory.
        if old_ld_library_path:
            os.environ["LD_LIBRARY_PATH"] = lib_dir + ":" + old_ld_library_path
        else:
            os.environ["LD_LIBRARY_PATH"] = lib_dir

    # Ensure that at this point, no matter what, Tensorflow should be able to
    # load libtf_kaldi_io.so as an op library.
    kaldi_io_lib_paths = find_shared_library("tf_kaldi_io")
    if kaldi_io_lib_paths:
        return kaldi_io_lib_paths["libtf_kaldi_io.so"]
    else:
        raise RuntimeError(MISSING_LIBRARY_ERROR)


# Find the path to the KaldiIO shared library. 
开发者ID:open-speech,项目名称:tf_kaldi_io,代码行数:35,代码来源:__init__.py

示例10: testLoadTwice

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def testLoadTwice(self):
    """Re-loading the kernel library returns the module zero_out_op_1 holds."""
    library_path = os.path.join(tf.resource_loader.get_data_files_path(),
                                'zero_out_op_kernel_1.so')
    reloaded_module = tf.load_op_library(library_path)
    self.assertEqual(reloaded_module, zero_out_op_1._zero_out_module)
开发者ID:tensorflowkorea,项目名称:tensorflow-kr,代码行数:6,代码来源:zero_out_1_test.py

示例11: graph_transform_mpi

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def graph_transform_mpi(single_gpu_meta_graph_def, config,
                        op_library_path=None):
    """Transform a single-GPU meta graph for Horovod-based training.

    Optionally loads an extra op library first. The meta graph is then
    imported into a fresh graph, Horovod is initialized, and shard values are
    updated for this worker. Aggregation ops are added for gradients of
    trainable variables (others are logged and skipped), and finally
    broadcast ops are added.

    Returns:
        ``(meta_graph_def, tensor_or_op_name_to_replica_names)``. The mapping
        sends every op name and output-tensor name to a one-element list
        containing that same name.
    """
    if op_library_path is not None:
        tf.load_op_library(op_library_path)

    with tf.Graph().as_default() as replica:
        tf.train.import_meta_graph(single_gpu_meta_graph_def)

        # Each op / tensor maps to a single replica: itself.
        name_to_replicas = {}
        for graph_op in replica.get_operations():
            name_to_replicas[graph_op.name] = [graph_op.name]
            name_to_replicas.update(
                (out.name, [out.name]) for out in graph_op.outputs)

        # Initialize horovod
        hvd.init()
        update_shard_values_for_worker(hvd.size(), hvd.rank())

        op_to_control_consumer_ops = get_all_control_consumers(replica)
        trainable_variable_ops = [
            variable.op
            for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        ]

        for gradients_info in tf.get_collection(tf.GraphKeys.GRADIENTS_INFO):
            target_tensor = gradients_info._target
            if target_tensor.op in trainable_variable_ops:
                _add_aggregation_ops(gradients_info, op_to_control_consumer_ops, config)
            else:
                parallax_log.debug(
                    "Gradient for non-trainable variable %s is created, ignore"
                    % target_tensor.op.name)
        _add_broadcast_ops()

    exported = tf.train.export_meta_graph(graph=replica)
    return exported, name_to_replicas
开发者ID:snuspl,项目名称:parallax,代码行数:40,代码来源:graph_transform.py

示例12: testBasic

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def testBasic(self):
  """Loading invalid_op.so must raise InvalidArgumentError."""
  data_dir = tf.resource_loader.get_data_files_path()
  invalid_library = os.path.join(data_dir, 'invalid_op.so')
  with self.assertRaises(tf.errors.InvalidArgumentError):
    tf.load_op_library(invalid_library)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:7,代码来源:invalid_op_test.py

示例13: testBasic

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def testBasic(self):
  """duplicate_op.so yields an empty OP_LIST, and tf.add still works."""
  data_dir = tf.resource_loader.get_data_files_path()
  duplicate = tf.load_op_library(os.path.join(data_dir, 'duplicate_op.so'))

  # No ops are exposed through the loaded module.
  self.assertEqual(len(duplicate.OP_LIST.op), 0)

  with self.test_session():
    total = tf.add(1, 41)
    self.assertEqual(total.eval(), 42)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:11,代码来源:duplicate_op_test.py

示例14: testBasic

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def testBasic(self):
  """ackermann_op.so exposes one op, Ackermann, returning a fixed string."""
  data_dir = tf.resource_loader.get_data_files_path()
  ackermann = tf.load_op_library(os.path.join(data_dir, 'ackermann_op.so'))

  exposed_ops = ackermann.OP_LIST.op
  self.assertEqual(len(exposed_ops), 1)
  self.assertEqual(exposed_ops[0].name, 'Ackermann')

  with self.test_session():
    result = ackermann.ackermann()
    self.assertEqual(result.eval(), b'A(m, 0) == A(m-1, 1)')
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:12,代码来源:ackermann_test.py

示例15: Load

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import load_op_library [as 别名]
def Load():
  """Load the TopN ops library and return the loaded module.

  The library is loaded only while the module-level ``_topn_ops`` is falsy.
  The result is cached there, and the check-and-load runs inside
  ``_ops_lock`` (presumably a threading lock -- confirm at its definition).

  Raises:
    RuntimeError: if loading produced a falsy result.
  """
  global _topn_ops
  with _ops_lock:
    if not _topn_ops:
      ops_path = tf.resource_loader.get_path_to_datafile(TOPN_OPS_FILE)
      tf.logging.info('data path: %s', ops_path)
      _topn_ops = tf.load_op_library(ops_path)

      # Explicit check instead of `assert`, which is stripped under `python -O`.
      if not _topn_ops:
        raise RuntimeError('Could not load topn_ops.so')
  return _topn_ops
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:13,代码来源:topn_ops.py


注：本文中的tensorflow.load_op_library方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。