Source code for quaterion_models.model

from __future__ import annotations

import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Type, Union

import numpy as np
import torch
from torch import nn

from quaterion_models.encoders import Encoder
from quaterion_models.heads.encoder_head import EncoderHead
from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange
from quaterion_models.utils.classes import restore_class, save_class_import
from quaterion_models.utils.meta import merge_meta
from quaterion_models.utils.tensors import move_to_device

DEFAULT_ENCODER_KEY = "default"


class SimilarityModel(nn.Module):
    """Main class which contains encoder models with the head layer."""

    def __init__(
        self, encoders: Union[Encoder, Dict[str, Encoder]], head: EncoderHead
    ):
        super().__init__()
        if not isinstance(encoders, dict):
            self.encoders: Dict[str, Encoder] = {DEFAULT_ENCODER_KEY: encoders}
        else:
            self.encoders: Dict[str, Encoder] = encoders

        for key, encoder in self.encoders.items():
            encoder.disable_gradients_if_required()
            self.add_module(key, encoder)

        self.head = head

    @classmethod
    def collate_fn(
        cls,
        batch: List[dict],
        encoders_collate_fns: Dict[str, CollateFnType],
        meta_extractors: Dict[str, MetaExtractorFnType],
    ) -> TensorInterchange:
        """Construct batches for all encoders

        Args:
            batch: list of raw input records to convert into encoder inputs
            encoders_collate_fns: dict of collate functions, keyed by encoder name
            meta_extractors: dict of meta extractor functions, keyed by encoder name
        """
        data = dict(
            (key, collate_fn(batch))
            for key, collate_fn in encoders_collate_fns.items()
        )
        meta = dict(
            (key, meta_extractor_fn(batch))
            for key, meta_extractor_fn in meta_extractors.items()
        )
        return {
            "data": data,
            "meta": merge_meta(meta),
        }

    @classmethod
    def get_encoders_output_size(cls, encoders: Union[Encoder, Dict[str, Encoder]]):
        """Calculate the total output size of the given encoders

        Args:
            encoders: single encoder or dict of encoders whose embedding sizes are summed
        """
        encoders = encoders.values() if isinstance(encoders, dict) else [encoders]
        total_size = 0
        for encoder in encoders:
            total_size += encoder.embedding_size
        return total_size

    def train(self, mode: bool = True):
        super().train(mode)

    def get_collate_fn(self) -> Callable:
        """Construct a function to convert input data into neural network inputs

        Returns:
            collate function which turns a batch of raw records into neural network inputs
        """
        return partial(
            SimilarityModel.collate_fn,
            encoders_collate_fns=dict(
                (key, encoder.get_collate_fn())
                for key, encoder in self.encoders.items()
            ),
            meta_extractors=dict(
                (key, encoder.get_meta_extractor())
                for key, encoder in self.encoders.items()
            ),
        )

    # -------------------------------------------
    # ---------- Inference methods --------------
    # -------------------------------------------

    def encode(
        self, inputs: Union[List[Any], Any], batch_size=32, to_numpy=True
    ) -> Union[torch.Tensor, np.ndarray]:
        """Encode data in batches

        Args:
            inputs: list of input data to encode
            batch_size: number of inputs processed per forward pass
            to_numpy: if ``True``, return a numpy array, otherwise a ``torch.Tensor``

        Returns:
            Numpy array or torch.Tensor of shape (input_size, embedding_size)
        """
        self.eval()
        device = next(self.parameters(), torch.tensor(0)).device
        collate_fn = self.get_collate_fn()

        input_was_list = True
        if not isinstance(inputs, list):
            input_was_list = False
            inputs = [inputs]

        all_embeddings = []

        for start_index in range(0, len(inputs), batch_size):
            input_batch = [
                inputs[i]
                for i in range(start_index, min(len(inputs), start_index + batch_size))
            ]
            features = collate_fn(input_batch)
            features = move_to_device(features, device)

            with torch.no_grad():
                embeddings = self.forward(features)
                embeddings = embeddings.detach()
                if to_numpy:
                    embeddings = embeddings.cpu().numpy()

            all_embeddings.append(embeddings)

        if to_numpy:
            all_embeddings = np.concatenate(all_embeddings, axis=0)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)

        if not input_was_list:
            all_embeddings = all_embeddings.squeeze()

        if to_numpy:
            all_embeddings = np.atleast_2d(all_embeddings)
        else:
            all_embeddings = torch.atleast_2d(all_embeddings)

        return all_embeddings

    def forward(self, batch):
        embeddings = [
            (key, encoder.forward(batch["data"][key]))
            for key, encoder in self.encoders.items()
        ]
        meta = batch["meta"]

        # Order embeddings by key name to ensure reproducibility
        embeddings = sorted(embeddings, key=lambda x: x[0])

        # Only embedding tensors of shape [batch_size x encoder_output_size]
        embedding_tensors = [embedding[1] for embedding in embeddings]

        # Shape: [batch_size x sum( encoders_emb_sizes )]
        joined_embeddings = torch.cat(embedding_tensors, dim=1)

        # Shape: [batch_size x output_emb_size]
        result_embedding = self.head(joined_embeddings, meta=meta)

        return result_embedding

    # -------------------------------------------
    # ---------- Persistence methods ------------
    # -------------------------------------------

    @classmethod
    def _get_head_path(cls, directory: str):
        return os.path.join(directory, "head")

    @classmethod
    def _get_encoders_path(cls, directory: str):
        return os.path.join(directory, "encoders")

    def save(self, output_path: str):
        head_path = self._get_head_path(output_path)
        os.makedirs(head_path, exist_ok=True)
        self.head.save(head_path)
        head_config = save_class_import(self.head)

        encoders_path = self._get_encoders_path(output_path)
        os.makedirs(encoders_path, exist_ok=True)

        encoders_config = []
        for encoder_key, encoder in self.encoders.items():
            encoder_path = os.path.join(encoders_path, encoder_key)
            os.mkdir(encoder_path)
            encoder.save(encoder_path)
            encoders_config.append({"key": encoder_key, **save_class_import(encoder)})

        with open(os.path.join(output_path, "config.json"), "w") as f_out:
            json.dump(
                {"encoders": encoders_config, "head": head_config}, f_out, indent=2
            )

    @classmethod
    def load(cls, input_path: str) -> SimilarityModel:
        with open(os.path.join(input_path, "config.json")) as f_in:
            config = json.load(f_in)

        head_config = config["head"]
        head_class: Type[EncoderHead] = restore_class(head_config)
        head_path = cls._get_head_path(input_path)
        head = head_class.load(head_path)

        encoders: Union[Encoder, Dict[str, Encoder]] = {}
        encoders_path = cls._get_encoders_path(input_path)
        encoders_config = config["encoders"]
        for encoder_params in encoders_config:
            encoder_key = encoder_params["key"]
            encoder_class = restore_class(encoder_params)
            encoders[encoder_key] = encoder_class.load(
                os.path.join(encoders_path, encoder_key)
            )

        return cls(head=head, encoders=encoders)

# In this framework, the terms Metric Learning and Similarity Learning are considered synonymous.
# However, the word "Metric" overlaps with other concepts in model training.
# In addition, the semantics of the word "Similarity" are simpler.
# It better reflects the basic idea of this training approach.
# That's why we prefer to use Similarity over Metric.
MetricModel = SimilarityModel
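
The class above is typically used end to end: build a model from one or more encoders plus a head, call encode for inference, and use save/load for persistence. Below is a minimal usage sketch based only on the methods shown on this page. MyEncoder is a hypothetical placeholder for any concrete Encoder implementation you provide, and the EmptyHead import path and constructor are assumptions about the heads package, not something defined in this module; any EncoderHead subclass can be substituted.

    from quaterion_models.model import SimilarityModel
    from quaterion_models.heads import EmptyHead  # assumed head class; any EncoderHead works

    # Hypothetical Encoder subclass defined elsewhere in your project
    encoder = MyEncoder()

    # The head input size must match the summed embedding size of all encoders
    head = EmptyHead(SimilarityModel.get_encoders_output_size(encoder))

    model = SimilarityModel(encoders=encoder, head=head)

    # Inference: returns an array of shape (num_inputs, embedding_size)
    embeddings = model.encode(["first record", "second record"], batch_size=16)

    # Persistence round-trip: save() writes config.json plus head/ and encoders/<key>/
    model.save("/path/to/model_dir")
    restored = SimilarityModel.load("/path/to/model_dir")

Passing a plain Encoder instance stores it under the "default" key; passing a dict of encoders keeps your own keys, and their outputs are concatenated in key order before being fed to the head, as forward above shows.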
