Shortcuts

Source code for quaterion_models.encoders.switch_encoder

from __future__ import annotations

import json
import os
from functools import partial
from typing import Any, Dict, List

import torch
from torch import Tensor

from quaterion_models.encoders import Encoder
from quaterion_models.types import CollateFnType, TensorInterchange
from quaterion_models.utils import move_to_device, restore_class, save_class_import


[docs]def inverse_permutation(perm): inv = torch.empty_like(perm) inv[perm] = torch.arange(perm.size(0), device=perm.device) return inv
[docs]class SwitchEncoder(Encoder): """Allows use alternative embeddings based on input data. For example, train shared embedding representation for images and texts. In this case image encoder should be used if input is an image and text encoder in other case. """
[docs] @classmethod def encoder_selection(cls, record: Any) -> str: """Decide which encoder to use for given record. Args: record: input piece of data Returns: name of the related encoder """ raise NotImplementedError()
def __init__(self, options: Dict[str, Encoder]): super(SwitchEncoder, self).__init__() self.options = options embedding_sizes = set() for key, encoder in self.options.items(): self.add_module(key, encoder) embedding_sizes.add(encoder.embedding_size) if len(embedding_sizes) != 1: raise RuntimeError( f"Alternative encoders have inconsistent output size: {embedding_sizes}" ) self._embedding_size = list(embedding_sizes)[0]
[docs] def disable_gradients_if_required(self): for encoder in self.options.values(): encoder.disable_gradients_if_required()
@property def trainable(self) -> bool: return any(encoder.trainable for encoder in self.options.values()) @property def embedding_size(self) -> int: return self._embedding_size
[docs] @classmethod def switch_collate_fn( cls, batch: List[Any], encoder_collates: Dict[str, CollateFnType] ) -> TensorInterchange: switch_batches = dict((key, []) for key in encoder_collates.keys()) switch_ordering = dict((key, []) for key in encoder_collates.keys()) for original_id, record in enumerate(batch): record_encoder = cls.encoder_selection(record) switch_batches[record_encoder].append(record) switch_ordering[record_encoder].append(original_id) switch_batches = { key: encoder_collates[key](batch) for key, batch in switch_batches.items() if len(batch) > 0 } return {"ordering": switch_ordering, "batches": switch_batches}
[docs] def get_collate_fn(self) -> CollateFnType: return partial( self.__class__.switch_collate_fn, encoder_collates=dict( (key, encoder.get_collate_fn()) for key, encoder in self.options.items() ), )
[docs] @classmethod def extract_meta(cls, batch: List[Any]) -> List[dict]: meta = [] for record in batch: meta.append( { "encoder": cls.encoder_selection(record), } ) return meta
[docs] def forward(self, batch: TensorInterchange) -> Tensor: switch_ordering: dict = batch["ordering"] switch_batches: dict = batch["batches"] embeddings = [] ordering = [] for key, batch in switch_batches.items(): embeddings.append(self.options[key].forward(batch)) ordering += switch_ordering[key] ordering_tensor: Tensor = inverse_permutation(torch.tensor(ordering)) embeddings_tensor: Tensor = torch.cat(embeddings) ordering_tensor = move_to_device(ordering_tensor, embeddings_tensor.device) return embeddings_tensor[ordering_tensor]
[docs] def save(self, output_path: str): encoders = {} for key, encoder in self.options.items(): encoders[key] = save_class_import(encoder) encoder_path = os.path.join(output_path, key) os.makedirs(encoder_path, exist_ok=True) encoder.save(encoder_path) with open(os.path.join(output_path, "config.json"), "w") as f_out: json.dump( { "encoders": encoders, }, f_out, indent=2, )
[docs] @classmethod def load(cls, input_path: str) -> "Encoder": with open(os.path.join(input_path, "config.json")) as f_in: config = json.load(f_in) encoders = {} encoders_params: dict = config["encoders"] for key, class_params in encoders_params.items(): encoder_path = os.path.join(input_path, key) encoder_class = restore_class(class_params) encoders[key] = encoder_class.load(encoder_path) return cls(options=encoders)

Qdrant

Learn more about Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community