Shortcuts

Source code for quaterion_models.encoders.encoder

from __future__ import annotations

from typing import Any, List

from torch import Tensor, nn
from torch.utils.data.dataloader import default_collate

from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange


class Encoder(nn.Module):
    """Base class for encoder abstraction.

    Concrete encoders are expected to override :attr:`trainable`,
    :attr:`embedding_size`, :meth:`forward`, :meth:`save` and :meth:`load`;
    the base implementations of these raise :class:`NotImplementedError`.
    """

    def __init__(self):
        super().__init__()

    def disable_gradients_if_required(self):
        """Disables gradients of the model if it is declared as not trainable

        When :attr:`trainable` is ``False``, sets ``requires_grad = False`` on
        every parameter returned by :meth:`torch.nn.Module.parameters`
        (i.e. including parameters of submodules). Otherwise does nothing.
        """
        if not self.trainable:
            # Only the parameter tensors are needed; their names are irrelevant.
            for weights in self.parameters():
                weights.requires_grad = False

    @property
    def trainable(self) -> bool:
        """Defines if encoder is trainable.

        This flag affects caching and checkpoint saving of the encoder.

        Raises:
            NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError()

    @property
    def embedding_size(self) -> int:
        """Size of resulting embedding

        Raises:
            NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError()

    def get_collate_fn(self) -> CollateFnType:
        """Provides function that converts raw data batch into suitable model input

        Returns:
            :const:`~quaterion_models.types.CollateFnType`: model's collate function
            (``torch.utils.data.dataloader.default_collate`` by default)
        """
        return default_collate

    @classmethod
    def extract_meta(cls, batch: List[Any]) -> List[dict]:
        """Extracts meta information from the batch

        The default implementation returns one empty dict per batch element.

        Args:
            batch: raw batch of data

        Returns:
            meta information
        """
        return [{} for _ in batch]

    def get_meta_extractor(self) -> MetaExtractorFnType:
        """Provides function that extracts meta information from a raw batch

        Returns:
            :const:`~quaterion_models.types.MetaExtractorFnType`: the
            :meth:`extract_meta` method of this encoder
        """
        return self.extract_meta

    def forward(self, batch: TensorInterchange) -> Tensor:
        """Infer encoder - convert input batch to embeddings

        Args:
            batch: processed batch

        Returns:
            embeddings: shape: (batch_size, embedding_size)

        Raises:
            NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError()

    def save(self, output_path: str):
        """Persist current state to the provided directory

        Args:
            output_path: path to save model

        Raises:
            NotImplementedError: must be overridden by subclasses
        """
        raise NotImplementedError()

    @classmethod
    def load(cls, input_path: str) -> Encoder:
        """Instantiate encoder from saved state.

        If no state required - just call `create` instead

        Args:
            input_path: path to load from

        Returns:
            :class:`~quaterion_models.encoders.encoder.Encoder`: loaded encoder

        Raises:
            NotImplementedError: must be overridden by subclasses
        """
        # NOTE(review): no `create` method is defined on this class; confirm
        # which factory the docstring above is referring to.
        raise NotImplementedError()