Source code for quaterion_models.model
from __future__ import annotations
import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Type, Union
import numpy as np
import torch
from torch import nn
from quaterion_models.encoders import Encoder
from quaterion_models.heads.encoder_head import EncoderHead
from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange
from quaterion_models.utils.classes import restore_class, save_class_import
from quaterion_models.utils.meta import merge_meta
from quaterion_models.utils.tensors import move_to_device
DEFAULT_ENCODER_KEY = "default"
class SimilarityModel(nn.Module):
    """Main class which contains encoder models with the head layer."""
    def __init__(self, encoders: Union[Encoder, Dict[str, Encoder]], head: EncoderHead):
        super().__init__()
        if not isinstance(encoders, dict):
            self.encoders: Dict[str, Encoder] = {DEFAULT_ENCODER_KEY: encoders}
        else:
            self.encoders: Dict[str, Encoder] = encoders
        for key, encoder in self.encoders.items():
            encoder.disable_gradients_if_required()
            self.add_module(key, encoder)
        self.head = head
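    # A minimal construction sketch, assuming hypothetical encoder classes. A single
    # encoder is registered under DEFAULT_ENCODER_KEY; with a dict, each key becomes
    # a named sub-module:
    #
    #     model = SimilarityModel(
    #         encoders={"text": TextEncoder(), "image": ImageEncoder()},
    #         head=some_encoder_head,
    #     )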
    @classmethod
    def collate_fn(
        cls,
        batch: List[dict],
        encoders_collate_fns: Dict[str, CollateFnType],
        meta_extractors: Dict[str, MetaExtractorFnType],
    ) -> TensorInterchange:
        """Construct batches for all encoders
        Args:
            batch:
            encoders_collate_fns: Dict (or single) of collate functions associated with encoders
            meta_extractors: Dict (or single) of meta extractor functions associated with encoders
        """
        data = dict(
            (key, collate_fn(batch)) for key, collate_fn in encoders_collate_fns.items()
        )
        meta = dict(
            (key, meta_extractor_fn(batch))
            for key, meta_extractor_fn in meta_extractors.items()
        )
        return {
            "data": data,
            "meta": merge_meta(meta),
        } 
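    # A minimal sketch of what ``collate_fn`` produces, assuming two hypothetical
    # encoders registered under the keys "text" and "image":
    #
    #     batch = [{"text": "a cat", "image": img_0}, {"text": "a dog", "image": img_1}]
    #     collated = SimilarityModel.collate_fn(
    #         batch,
    #         encoders_collate_fns={"text": text_collate, "image": image_collate},
    #         meta_extractors={"text": text_meta, "image": image_meta},
    #     )
    #     # collated == {
    #     #     "data": {"text": <collated text tensors>, "image": <collated image tensors>},
    #     #     "meta": <meta merged across encoders>,
    #     # }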
    @classmethod
    def get_encoders_output_size(cls, encoders: Union[Encoder, Dict[str, Encoder]]):
        """Calculate total output size of given encoders
        Args:
            encoders:
        """
        encoders = encoders.values() if isinstance(encoders, dict) else [encoders]
        total_size = 0
        for encoder in encoders:
            total_size += encoder.embedding_size
        return total_size 
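    # For example, with two hypothetical encoders of embedding sizes 128 and 512,
    # ``get_encoders_output_size({"text": text_encoder, "image": image_encoder})``
    # would return 640, the width of the concatenated vector that is fed into the head.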
    def train(self, mode: bool = True):
        super().train(mode) 
    def get_collate_fn(self) -> Callable:
        """Construct a function to convert raw input data into neural network inputs
        Returns:
            A collate function which applies the collate function and meta extractor of each encoder to a batch
        """
        return partial(
            SimilarityModel.collate_fn,
            encoders_collate_fns=dict(
                (key, encoder.get_collate_fn())
                for key, encoder in self.encoders.items()
            ),
            meta_extractors=dict(
                (key, encoder.get_meta_extractor())
                for key, encoder in self.encoders.items()
            ),
        ) 
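    # A minimal usage sketch (``raw_records`` is a hypothetical list of input records):
    # the returned callable can be passed directly to a ``torch.utils.data.DataLoader``.
    #
    #     from torch.utils.data import DataLoader
    #
    #     loader = DataLoader(raw_records, batch_size=32, collate_fn=model.get_collate_fn())
    #     for features in loader:
    #         embeddings = model(features)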
    # -------------------------------------------
    # ---------- Inference methods --------------
    # -------------------------------------------
    def encode(
        self, inputs: Union[List[Any], Any], batch_size=32, to_numpy=True
    ) -> Union[torch.Tensor, np.ndarray]:
        """Encode data in batches
        Args:
            inputs: list of input data to encode
            batch_size:
            to_numpy:
        Returns:
            Numpy array or torch.Tensor of shape (input_size, embedding_size)
        """
        self.eval()
        device = next(self.parameters(), torch.tensor(0)).device
        collate_fn = self.get_collate_fn()
        input_was_list = True
        if not isinstance(inputs, list):
            input_was_list = False
            inputs = [inputs]
        all_embeddings = []
        for start_index in range(0, len(inputs), batch_size):
            input_batch = [
                inputs[i]
                for i in range(start_index, min(len(inputs), start_index + batch_size))
            ]
            features = collate_fn(input_batch)
            features = move_to_device(features, device)
            with torch.no_grad():
                embeddings = self.forward(features)
                embeddings = embeddings.detach()
                if to_numpy:
                    embeddings = embeddings.cpu().numpy()
                all_embeddings.append(embeddings)
        if to_numpy:
            all_embeddings = np.concatenate(all_embeddings, axis=0)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)
        if not input_was_list:
            all_embeddings = all_embeddings.squeeze()
        if to_numpy:
            all_embeddings = np.atleast_2d(all_embeddings)
        else:
            all_embeddings = torch.atleast_2d(all_embeddings)
        return all_embeddings 
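    # A minimal usage sketch, assuming ``model`` is an assembled SimilarityModel and
    # ``records`` is a list of raw inputs understood by its encoders:
    #
    #     embeddings = model.encode(records, batch_size=64)
    #     # -> numpy array of shape (len(records), embedding_size)
    #     single = model.encode(records[0])
    #     # -> a single non-list input is still returned as a 2d array of shape (1, embedding_size)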
    def forward(self, batch):
        """Apply all encoders to the collated batch and combine their outputs with the head layer."""
        embeddings = [
            (key, encoder.forward(batch["data"][key]))
            for key, encoder in self.encoders.items()
        ]
        meta = batch["meta"]
        # Order embeddings by key name, to ensure a reproducible concatenation order
        embeddings = sorted(embeddings, key=lambda x: x[0])
        # Only embedding tensors of shape [batch_size x encoder_output_size]
        embedding_tensors = [embedding[1] for embedding in embeddings]
        # Shape: [batch_size x sum( encoders_emb_sizes )]
        joined_embeddings = torch.cat(embedding_tensors, dim=1)
        # Shape: [batch_size x output_emb_size]
        result_embedding = self.head(joined_embeddings, meta=meta)
        return result_embedding 
    # -------------------------------------------
    # ---------- Persistence methods ------------
    # -------------------------------------------
    @classmethod
    def _get_head_path(cls, directory: str):
        return os.path.join(directory, "head")
    @classmethod
    def _get_encoders_path(cls, directory: str):
        return os.path.join(directory, "encoders")
    def save(self, output_path: str):
        """Save the head, all encoders and the model config into ``output_path``."""
        head_path = self._get_head_path(output_path)
        os.makedirs(head_path, exist_ok=True)
        self.head.save(head_path)
        head_config = save_class_import(self.head)
        encoders_path = self._get_encoders_path(output_path)
        os.makedirs(encoders_path, exist_ok=True)
        encoders_config = []
        for encoder_key, encoder in self.encoders.items():
            encoder_path = os.path.join(encoders_path, encoder_key)
            os.makedirs(encoder_path, exist_ok=True)
            encoder.save(encoder_path)
            encoders_config.append({"key": encoder_key, **save_class_import(encoder)})
        with open(os.path.join(output_path, "config.json"), "w") as f_out:
            json.dump(
                {"encoders": encoders_config, "head": head_config}, f_out, indent=2
            ) 
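    # ``save`` produces roughly the following layout on disk (encoder keys are whatever
    # the model was constructed with, "default" for a single unnamed encoder):
    #
    #     output_path/
    #         config.json            # import paths of the head and encoder classes
    #         head/                  # state stored by ``EncoderHead.save``
    #         encoders/
    #             default/           # state stored by each ``Encoder.save``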
    @classmethod
    def load(cls, input_path: str) -> SimilarityModel:
        """Restore a model previously stored with :meth:`save` from ``input_path``."""
        with open(os.path.join(input_path, "config.json")) as f_in:
            config = json.load(f_in)
        head_config = config["head"]
        head_class: Type[EncoderHead] = restore_class(head_config)
        head_path = cls._get_head_path(input_path)
        head = head_class.load(head_path)
        encoders: Dict[str, Encoder] = {}
        encoders_path = cls._get_encoders_path(input_path)
        encoders_config = config["encoders"]
        for encoder_params in encoders_config:
            encoder_key = encoder_params["key"]
            encoder_class = restore_class(encoder_params)
            encoders[encoder_key] = encoder_class.load(
                os.path.join(encoders_path, encoder_key)
            )
        return cls(head=head, encoders=encoders)  
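    # A minimal save/load round trip, assuming ``model`` is a trained SimilarityModel:
    #
    #     model.save("/path/to/servable")
    #     restored = SimilarityModel.load("/path/to/servable")
    #     embeddings = restored.encode(records)
    #
    # Note that ``load`` re-imports the head and encoder classes from the paths stored
    # in config.json, so those classes must be importable in the loading environment.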
# In this framework, the terms Metric Learning and Similarity Learning are considered synonymous.
# However, the word "Metric" overlaps with other concepts in model training.
# In addition, the word "Similarity" is semantically simpler
# and better reflects the basic idea of this training approach.
# That's why we prefer to use Similarity over Metric.
MetricModel = SimilarityModel