当前位置: 首页>>代码示例>>Python>>正文


Python faiss.METRIC_L2属性代码示例

本文整理汇总了Python中faiss.METRIC_L2属性的典型用法代码示例。如果您正苦于以下问题：Python faiss.METRIC_L2属性的具体用法？Python faiss.METRIC_L2怎么用？Python faiss.METRIC_L2使用的例子？那么，这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在模块faiss的用法示例。


在下文中一共展示了faiss.METRIC_L2属性的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _build_approximate_index

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def _build_approximate_index(self,
                             data: np.ndarray):
    """Build a GPU IVF-Flat index over ``data`` and store it on ``self.index``.

    The metric follows ``self.kernel_name``: the ``'rbf'`` kernel uses
    (squared) L2 distance, every other kernel uses inner product.

    Args:
        data: 2-D array with one vector per row. Presumably float32, as faiss
            expects -- TODO confirm at call sites.
    """
    dimensionality = data.shape[1]
    # 100 inverted lists for larger datasets, only 2 for tiny ones.
    nlist = 100 if data.shape[0] > 100 else 2

    if self.kernel_name in {'rbf'}:
        quantizer = faiss.IndexFlatL2(dimensionality)
        metric = faiss.METRIC_L2
    else:
        quantizer = faiss.IndexFlatIP(dimensionality)
        # IndexIVFFlat defaults to METRIC_L2. Without an explicit metric, the
        # coarse quantizer would rank by inner product while the index
        # scores candidates by L2 distance.
        metric = faiss.METRIC_INNER_PRODUCT
    cpu_index_flat = faiss.IndexIVFFlat(quantizer, dimensionality, nlist, metric)

    # self.resource is presumably a faiss GPU resources object; GPU 0 is hard-coded.
    gpu_index_ivf = faiss.index_cpu_to_gpu(self.resource, 0, cpu_index_flat)
    gpu_index_ivf.train(data)
    gpu_index_ivf.add(data)
    self.index = gpu_index_ivf
开发者ID:uclnlp,项目名称:gntp,代码行数:18,代码来源:faiss.py

示例2: train_index

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def train_index(data, quantizer_path, trained_index_path, fine_quant='SQ8', cuda=False):
    """Train an IVF index on top of a saved coarse quantizer and write it to disk.

    Args:
        data: Training vectors. Presumably a float32 array of shape (n, d) --
            TODO confirm.
        quantizer_path: Path of a faiss index used as the coarse quantizer.
            Its ``ntotal`` becomes the number of inverted lists.
        trained_index_path: Destination path for the trained index.
        fine_quant: ``'SQ8'`` for 8-bit scalar quantization, or ``'PQ<m>'``
            for product quantization with ``m`` sub-quantizers of 8 bits.
        cuda: Train on GPU 0 when possible. PQ indexes always train on CPU.

    Raises:
        ValueError: If ``fine_quant`` is not recognized.
    """
    quantizer = faiss.read_index(quantizer_path)
    nlist = quantizer.ntotal
    is_pq = fine_quant.startswith('PQ')

    if fine_quant == 'SQ8':
        # The 4th positional argument is the scalar-quantizer type, not the
        # metric. Passing METRIC_L2 there (value 1) selects QT_4bit, so the
        # quantizer type and the metric are both passed explicitly.
        trained_index = faiss.IndexIVFScalarQuantizer(
            quantizer, quantizer.d, nlist,
            faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
    elif is_pq:
        m = int(fine_quant[2:])
        trained_index = faiss.IndexIVFPQ(quantizer, quantizer.d, nlist, m, 8)
    else:
        raise ValueError(fine_quant)

    if cuda and not is_pq:
        res = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(res, 0, trained_index)
        gpu_index.train(data)
        trained_index = faiss.index_gpu_to_cpu(gpu_index)
    else:
        if cuda:
            print('PQ not supported on GPU; keeping CPU.')
        # The CPU fallback must still train; otherwise an untrained index
        # would be written below.
        trained_index.train(data)
    faiss.write_index(trained_index, trained_index_path)
开发者ID:uwnlp,项目名称:denspi,代码行数:23,代码来源:run_index.py

示例3: __init__

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
    """Create a GPU IVF-Flat (L2) index and train it.

    Args:
        cell_size: Dimensionality of the indexed vectors.
        nr_cells: Stored on ``self.nr_cells``. Also sizes the default random
            training set (``nr_cells * 100`` rows).
        K: Stored on ``self.K``; not used in this method.
        num_lists: Number of inverted lists.
        probes: Number of lists probed per query.
        res: Optional ``faiss.StandardGpuResources``. A new one is created
            when omitted.
        train: Optional training data. Defaults to standard-normal random
            vectors of shape ``(nr_cells * 100, cell_size)``.
        gpu_id: GPU to place the index on. ``-1`` keeps the config's default
            device.
    """
    super(FAISSIndex, self).__init__()
    self.cell_size = cell_size
    self.nr_cells = nr_cells
    self.probes = probes
    self.K = K
    self.num_lists = num_lists
    self.gpu_id = gpu_id

    # BEWARE: if this variable gets deallocated, FAISS crashes
    self.res = res if res is not None else faiss.StandardGpuResources()
    self.res.setTempMemoryFraction(0.01)

    config = faiss.GpuIndexIVFFlatConfig()
    if self.gpu_id != -1:
        self.res.initializeForDevice(self.gpu_id)
        # Build the index on the same device the resources were initialized
        # for, instead of the config's default device.
        config.device = self.gpu_id

    train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size)

    self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2, config)
    self.index.setNumProbes(self.probes)
    self.train(train)
开发者ID:ixaxaar,项目名称:pytorch-dnc,代码行数:23,代码来源:faiss_index.py

示例4: fit

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def fit(self, X):
    """Index ``X`` with an IVF-Flat L2 index and keep it on ``self.index``.

    When ``self._metric`` is ``'angular'``, rows are L2-normalized first.
    The data is cast to float32 unless it already has that dtype.
    ``self._n_list`` sets the number of inverted lists.
    """
    if self._metric == 'angular':
        X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')

    X = X if X.dtype == numpy.float32 else X.astype(numpy.float32)

    dim = X.shape[1]
    # The coarse quantizer is kept on self as well, presumably so it stays
    # referenced for as long as the index that uses it.
    self.quantizer = faiss.IndexFlatL2(dim)
    ivf = faiss.IndexIVFFlat(self.quantizer, dim, self._n_list, faiss.METRIC_L2)
    ivf.train(X)
    ivf.add(X)
    self.index = ivf
开发者ID:erikbern,项目名称:ann-benchmarks,代码行数:15,代码来源:faiss.py

示例5: fit

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def fit(self, X):
    """Build a GPU IVF-Flat L2 index over ``X`` and store it on ``self._index``.

    ``self._n_bits`` is used as the number of inverted lists, and
    ``self._n_probes`` as the number of lists probed per query.
    """
    # copy=False avoids a needless copy when X is already float32.
    X = X.astype(numpy.float32, copy=False)
    self._index = faiss.GpuIndexIVFFlat(self._res, X.shape[1], self._n_bits,
                                        faiss.METRIC_L2)
    self._index.train(X)
    self._index.add(X)
    self._index.setNumProbes(self._n_probes)
开发者ID:erikbern,项目名称:ann-benchmarks,代码行数:15,代码来源:faiss_gpu.py

示例6: _faiss_knn

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def _faiss_knn(keys: torch.Tensor,
               queries: torch.Tensor,
               num_neighbors: int,
               distance: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Exact k-nearest-neighbour search via ``faiss.bruteForceKnn``.

    Adapted from
    https://github.com/facebookresearch/XLM/blob/master/src/model/memory/utils.py

    Args:
        keys: Database vectors, shape (num_keys, dim).
        queries: Query vectors, shape (num_queries, dim).
        num_neighbors: Number of neighbours returned per query.
        distance: ``'dot_product'`` (inner product) or ``'l2'``.

    Returns:
        ``(scores, indices)``, each of shape (num_queries, num_neighbors).
        Scores are float32, indices are int64, and both are allocated via
        ``keys.new_zeros``.

    Raises:
        RuntimeError: If faiss is not available.
    """
    if not is_faiss_available():
        raise RuntimeError("faiss_knn requires faiss-gpu")
    import faiss

    assert distance in ['dot_product', 'l2']
    assert keys.size(1) == queries.size(1)

    metric = faiss.METRIC_INNER_PRODUCT if distance == 'dot_product' else faiss.METRIC_L2

    # The `True` flags passed to bruteForceKnn declare row-major buffers, and
    # the kernel consumes float data. A non-contiguous view (e.g. a transpose)
    # or another dtype would not match that layout. Both calls are no-ops for
    # contiguous float32 input, and rebinding keeps the normalized tensors
    # alive for the duration of the call.
    keys = keys.to(torch.float32).contiguous()
    queries = queries.to(torch.float32).contiguous()

    k_ptr = _tensor_to_ptr(keys)
    q_ptr = _tensor_to_ptr(queries)

    scores = keys.new_zeros((queries.size(0), num_neighbors), dtype=torch.float32)
    indices = keys.new_zeros((queries.size(0), num_neighbors), dtype=torch.int64)

    s_ptr = _tensor_to_ptr(scores)
    i_ptr = _tensor_to_ptr(indices)

    faiss.bruteForceKnn(FAISS_RES, metric,
                        k_ptr, True, keys.size(0),
                        q_ptr, True, queries.size(0),
                        queries.size(1), num_neighbors, s_ptr, i_ptr)
    return scores, indices
开发者ID:moskomule,项目名称:homura,代码行数:30,代码来源:knn.py

示例7: execute

# 需要导入模块: import faiss [as 别名]
# 或者: from faiss import METRIC_L2 [as 别名]
def execute(cls, ctx, op):
    """Search the faiss index chunk(s) for the nearest neighbours of ``op.input``.

    Neighbour indices are written to ``ctx`` under the last output's key.
    When ``op.return_distance`` is set, distances are written under the first
    output's key: squared L2 distances are square-rooted, and cosine
    similarities are turned into ``1 - similarity``. This matches
    ``pairwise.euclidean_distances`` / ``pairwise.cosine_distances``.
    """
    (y,), device_id, xp = as_same_device(
        [ctx[op.input.key]], device=op.device, ret_extra=True)
    indexes = [_load_index(ctx, op, ctx[index.key], device_id)
               for index in op.inputs[1:]]

    with device(device_id):
        y = xp.ascontiguousarray(y, dtype=np.float32)

        if len(indexes) == 1:
            index = indexes[0]
        else:
            index = faiss.IndexShards(indexes[0].d)
            # IndexShards starts with the default METRIC_L2 and uses its own
            # metric_type when merging per-shard results (keep smallest vs.
            # largest). Inherit the shards' metric so inner-product (cosine)
            # shards are merged and post-processed correctly.
            index.metric_type = indexes[0].metric_type
            for shard in indexes:
                index.add_shard(shard)

        if op.metric == 'cosine':
            # faiss does not support cosine distances directly;
            # data needs to be normalized before searching,
            # refer to:
            # https://github.com/facebookresearch/faiss/wiki/FAQ#how-can-i-index-vectors-for-cosine-distance
            faiss.normalize_L2(y)

        if op.nprobe is not None:
            # Set nprobe on every underlying index. An attribute set on an
            # IndexShards wrapper would never reach its shards.
            for ind in indexes:
                ind.nprobe = op.nprobe

        if device_id >= 0:  # pragma: no cover
            n = y.shape[0]
            k = op.n_neighbors
            distances = xp.empty((n, k), dtype=xp.float32)
            indices = xp.empty((n, k), dtype=xp.int64)
            index.search_c(n, _swig_ptr_from_cupy_float32_array(y),
                           k, _swig_ptr_from_cupy_float32_array(distances),
                           _swig_ptr_from_cupy_int64_array(indices))
        else:
            distances, indices = index.search(y, op.n_neighbors)
        if op.return_distance:
            if index.metric_type == faiss.METRIC_L2:
                # faiss returns squared L2 distances, which may come out
                # slightly negative for some index types due to rounding.
                # Clamp before sqrt to avoid NaN; this makes the result
                # equivalent to `pairwise.euclidean_distances`.
                distances = xp.maximum(distances, 0, out=distances)
                distances = xp.sqrt(distances, out=distances)
            elif op.metric == 'cosine':
                # make it equivalent to `pairwise.cosine_distances`
                distances = xp.subtract(1, distances, out=distances)
            ctx[op.outputs[0].key] = distances
        ctx[op.outputs[-1].key] = indices
开发者ID:mars-project,项目名称:mars,代码行数:46,代码来源:_faiss.py


注:本文中的faiss.METRIC_L2属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。