"""kiez.neighbors.approximate.faiss: wrapper around the `faiss` (A)NN library."""

import logging
from types import MappingProxyType
from typing import Optional

import numpy as np

from kiez.neighbors.neighbor_algorithm_base import NNAlgorithm

try:
    import faiss
except ImportError:  # pragma: no cover
    faiss = None

try:
    import torch  # noqa: I001
    import faiss.contrib.torch_utils
except ImportError:  # pragma: no cover
    torch = None


class Faiss(NNAlgorithm):
    """Wrapper for `faiss` library.

    Faiss implements a number of (A)NN algorithms and enables the use of GPUs.

    Parameters
    ----------
    n_candidates: int, default = 5
        number of nearest neighbors used in search
    metric: str, default = 'l2'
        distance measure used in search
        possible measures are found in :obj:`Faiss.valid_metrics`
        Euclidean is the same as l2, except for taking the sqrt of the result
    index_key: str, default = 'Flat'
        index name, used as input for :meth:`faiss.index_factory`
        Indexes that require training (e.g. IVF- or PQ-based ones) are
        trained on the fitted data before the data is added.
    index_param: str, default = None
        Hyperparameter string for the indexing algorithm
        See https://github.com/facebookresearch/faiss/wiki/Index-IO,-cloning-and-hyper-parameter-tuning#auto-tuning-the-runtime-parameters
        for info
    use_gpu: bool, default = False
        If true uses all available gpus
    verbose: int, default = logging.WARNING
        verbosity level, stored as ``self.verbose``

    Examples
    --------
    >>> import numpy as np
    >>> from kiez import Kiez
    >>> source = np.random.rand(1000, 512)
    >>> target = np.random.rand(100, 512)
    >>> k_inst = Kiez(algorithm="Faiss")
    >>> k_inst.fit(source, target)
    >>> k_inst = Kiez(algorithm="Faiss",algorithm_kwargs={"metric":"euclidean","index_key":"Flat"})

    supply hyperparameters for indexing algorithm

    >>> k_inst = Kiez(algorithm="Faiss",algorithm_kwargs={"index_key":"HNSW32","index_param":"efSearch=16383"})

    Notes
    -----
    For details about configuring faiss consult their wiki:
    https://github.com/facebookresearch/faiss/wiki
    """

    if torch:
        # torch tensors are accepted in addition to numpy arrays only when
        # torch could be imported together with faiss.contrib.torch_utils
        _ALLOWED_INPUT_TYPES = (np.ndarray, torch.Tensor)

    _METRIC_MAP = MappingProxyType({})  # type: ignore [var-annotated]
    if faiss:
        _METRIC_MAP = MappingProxyType(
            {
                "euclidean": faiss.METRIC_L2,
                "l2": faiss.METRIC_L2,
                "l1": faiss.METRIC_L1,
                "ip": faiss.METRIC_INNER_PRODUCT,
                "innerproduct": faiss.METRIC_INNER_PRODUCT,
                # cosine = inner product on L2-normalized vectors
                "cosine": faiss.METRIC_INNER_PRODUCT,
                "braycurtis": faiss.METRIC_BrayCurtis,
                "canberra": faiss.METRIC_Canberra,
                "jensenshannon": faiss.METRIC_JensenShannon,
                "chebyshev": faiss.METRIC_Linf,
                "linf": faiss.METRIC_Linf,
            }
        )
    valid_metrics = tuple(_METRIC_MAP.keys())

    def __init__(
        self,
        n_candidates: int = 5,
        metric: str = "l2",
        index_key: str = "Flat",
        index_param: Optional[str] = None,
        use_gpu: bool = False,
        verbose: int = logging.WARNING,
    ):
        """Validate the configuration and store it on the instance.

        Raises
        ------
        ImportError
            If the `faiss` package is not installed.
        ValueError
            If `metric` is not one of :obj:`Faiss.valid_metrics`.
        """
        if faiss is None:  # pragma: no cover
            raise ImportError(
                "Please install the `faiss` package, before using this class.\nSee here"
                " for installation instructions:"
                " https://github.com/facebookresearch/faiss/blob/main/INSTALL.md"
            )
        if metric not in self.__class__.valid_metrics:
            raise ValueError(
                f"Unknown metric {metric}, please use one of {self.valid_metrics}"
            )
        super().__init__(n_candidates=n_candidates, metric=metric, n_jobs=None)
        self.index_key = index_key
        self.index_param = index_param
        self.use_gpu = use_gpu
        self.verbose = verbose
        # faiss metric constant corresponding to the user-facing metric name
        self._faiss_metric = self.__class__._METRIC_MAP[metric]

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(n_candidates={self.n_candidates},"
            f"metric={self.metric},"
            f"index_key={self.index_key},"
            f"index_param={{{self.index_param}}},"
            f"use_gpu={self.use_gpu})"
        )

    def _normalize_if_needed(self, vec):
        """Return `vec` L2-normalized row-wise if the metric is cosine.

        For any other metric `vec` is returned unchanged. The caller's
        data is never modified.
        """
        if self.metric != "cosine":
            return vec
        # see https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances#how-can-i-index-vectors-for-cosine-similarity
        if torch and isinstance(vec, torch.Tensor):
            # functional.normalize returns a new tensor
            return torch.nn.functional.normalize(vec)
        # faiss.normalize_L2 works in place and expects a C-contiguous
        # float32 array, so always normalize a private copy; otherwise a
        # float32 input array of the caller would be silently mutated.
        vec = np.array(vec, dtype=np.float32, order="C", copy=True)
        faiss.normalize_L2(vec)
        return vec

    def _fit(self, data, is_source: bool):
        """Build a faiss index over `data` and return it.

        The index is created via `faiss.index_factory` from `index_key`.
        It is moved to all GPUs if `use_gpu` is set, configured with
        `index_param`, trained if the index type requires it, and filled
        with `data` (normalized first for the cosine metric).
        """
        dim = data.shape[1]
        index = faiss.index_factory(dim, self.index_key, self._faiss_metric)
        if self.use_gpu:
            index = faiss.index_cpu_to_all_gpus(index)
            params = faiss.GpuParameterSpace()
        else:
            params = faiss.ParameterSpace()
        if self.index_param is not None:
            params.set_index_parameters(index, self.index_param)
        data = self._normalize_if_needed(data)
        # Indexes such as IVF/PQ must be trained before vectors can be added;
        # flat/HNSW indexes report is_trained == True and skip this step.
        if not index.is_trained:
            index.train(data)
        index.add(data)
        return index

    def _kneighbors(self, query, k, index, return_distance, is_self_querying):
        """Search `index` for the `k` nearest neighbors of `query`.

        If `is_self_querying` is true, the stored source (`self.source_`)
        is used as the query instead. Returns ``(dist, ind)`` when
        `return_distance` is true, otherwise only ``ind``.
        """
        if is_self_querying:
            query = self.source_
        query = self._normalize_if_needed(query)
        dist, ind = index.search(query, k)
        if not return_distance:
            return ind
        if self.metric == "euclidean":
            # faiss METRIC_L2 yields squared distances; approximate indexes
            # can produce tiny negative values from round-off, which would
            # turn into NaN under sqrt, so clamp at zero first.
            if torch and isinstance(dist, torch.Tensor):
                dist = torch.sqrt(torch.clamp(dist, min=0))
            else:
                dist = np.sqrt(np.maximum(dist, 0))
        return dist, ind