Shortcuts

Source code for quaterion_models.encoders.encoder

from __future__ import annotations

from typing import Any, List

from torch import Tensor, nn
from torch.utils.data.dataloader import default_collate

from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange


class Encoder(nn.Module):
    """Base class for encoder abstraction.

    Subclasses are expected to override :attr:`trainable`,
    :attr:`embedding_size`, :meth:`forward`, :meth:`save` and :meth:`load`;
    the base-class versions of these members raise
    :class:`NotImplementedError`.
    """

    def __init__(self):
        super().__init__()

    def disable_gradients_if_required(self):
        """Disables gradients of the model if it is declared as not trainable.

        When :attr:`trainable` is ``False``, ``requires_grad`` is switched off
        for every parameter returned by :meth:`parameters` (which by default
        recurses into submodules). Trainable encoders are left untouched.
        """
        if self.trainable:
            return
        # Only the tensors matter here, so iterate parameters() rather than
        # named_parameters() and discard the names.
        for weights in self.parameters():
            weights.requires_grad_(False)

    @property
    def trainable(self) -> bool:
        """Defines if encoder is trainable.

        This flag affects caching and checkpoint saving of the encoder.
        """
        raise NotImplementedError()

    @property
    def embedding_size(self) -> int:
        """Size of resulting embedding"""
        raise NotImplementedError()

    def get_collate_fn(self) -> CollateFnType:
        """Provides function that converts raw data batch into suitable model input

        Returns:
            :const:`~quaterion_models.types.CollateFnType`: model's collate function
        """
        return default_collate

    @classmethod
    def extract_meta(cls, batch: List[Any]) -> List[dict]:
        """Extracts meta information from the batch

        Args:
            batch: raw batch of data

        Returns:
            meta information - by default one empty dict per batch element
        """
        return [{} for _ in batch]

    def get_meta_extractor(self) -> MetaExtractorFnType:
        """Provides function that extracts meta information from a raw batch

        Returns:
            :const:`~quaterion_models.types.MetaExtractorFnType`: the
            :meth:`extract_meta` classmethod, bound to this encoder's class
        """
        return self.extract_meta

    def forward(self, batch: TensorInterchange) -> Tensor:
        """Infer encoder - convert input batch to embeddings

        Args:
            batch: processed batch

        Returns:
            embeddings: shape: (batch_size, embedding_size)
        """
        raise NotImplementedError()

    def save(self, output_path: str):
        """Persist current state to the provided directory

        Args:
            output_path: path to save model
        """
        raise NotImplementedError()

    @classmethod
    def load(cls, input_path: str) -> Encoder:
        """Instantiate encoder from saved state.

        If no state required - just call `create` instead

        Args:
            input_path: path to load from

        Returns:
            :class:`~quaterion_models.encoders.encoder.Encoder`: loaded encoder
        """
        raise NotImplementedError()

Qdrant

Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem-solving with similarity learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community