Shortcuts

Source code for quaterion_models.heads.softmax_head

from typing import Any, Dict

import torch
from torch.nn import Linear

from quaterion_models.heads.encoder_head import EncoderHead


class SoftmaxEmbeddingsHead(EncoderHead):
    """Head layer producing a concatenation of independent softmax embedding groups.

    The encoder output is projected into ``output_groups * output_size_per_group``
    values. These are split into ``output_groups`` consecutive groups of
    ``output_size_per_group`` values each. Softmax is applied within every group,
    and the groups are concatenated back into one vector. This is useful for
    deriving embedding confidence.

    Schema:

    .. code-block:: none

                 ┌──────────────────┐
                 │      Encoder     │
                 └──┬───────────┬───┘
                    │           │
                    │           │
        ┌───────────┼───────────┼───────────┐
        │           │           │           │
        │ ┌─────────▼──┐     ┌──▼─────────┐ │
        │ │   Linear   │ ... │   Linear   │ │
        │ └─────┬──────┘     └─────┬──────┘ │
        │       │                  │        │
        │ ┌─────┴──────┐     ┌─────┴──────┐ │
        │ │  SoftMax   │ ... │  SoftMax   │ │
        │ └─────┬──────┘     └─────┬──────┘ │
        │       │                  │        │
        │  ┌────┴──────────────────┴─────┐  │
        │  │        Concatenation        │  │
        │  └──────────────┬──────────────┘  │
        │                 │                 │
        └─────────────────┼─────────────────┘

    Args:
        output_groups: number of independent softmax groups.
        output_size_per_group: number of values in each group.
        input_embedding_size: size of the embedding fed into this head.
        dropout: forwarded to ``EncoderHead``.
        **kwargs: forwarded to ``EncoderHead``.
    """

    def __init__(
        self,
        output_groups: int,
        output_size_per_group: int,
        input_embedding_size: int,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__(input_embedding_size, dropout=dropout, **kwargs)
        self.output_groups = output_groups
        self.output_size_per_group = output_size_per_group
        # Not used by this head; kept so existing code that reads the attribute still works.
        self.projectors = []
        # A single Linear produces all groups at once. This is equivalent to one
        # independent Linear per group (as drawn in the schema), because each
        # group owns a disjoint slice of the weight matrix and bias.
        self.projection_layer = Linear(
            self.input_embedding_size, self.output_size_per_group * self.output_groups
        )

    @property
    def output_size(self) -> int:
        """Size of the produced embedding: ``output_size_per_group * output_groups``."""
        return self.output_size_per_group * self.output_groups

    def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
        """Project the input and apply softmax independently within each group.

        Args:
            input_vectors: shape (batch_size, ..., input_dim)

        Returns:
            shape (batch_size, ..., self.output_size_per_group * self.output_groups).
            Each consecutive chunk of ``output_size_per_group`` values sums to 1.
        """
        # shape: [batch_size, ..., self.output_size_per_group * self.output_groups]
        projection = self.projection_layer(input_vectors)
        flat_shape = projection.shape

        # shape: [batch_size, ..., self.output_groups, self.output_size_per_group]
        # ``reshape`` (not ``view``) so a non-contiguous projection cannot raise.
        grouped_projection = projection.reshape(
            *flat_shape[:-1], self.output_groups, self.output_size_per_group
        )
        grouped_projection = torch.softmax(grouped_projection, dim=-1)

        # Concatenate the groups back along the last dimension.
        return grouped_projection.reshape(flat_shape)

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the parent head's config extended with this head's grouping parameters.

        Returns:
            The dict from ``EncoderHead.get_config_dict()`` with ``output_groups``
            and ``output_size_per_group`` added. The keys contributed by the
            parent are not visible here; presumably they include
            ``input_embedding_size`` and ``dropout`` (TODO confirm).
        """
        config = super().get_config_dict()
        config.update(
            {
                "output_groups": self.output_groups,
                "output_size_per_group": self.output_size_per_group,
            }
        )
        return config

Qdrant

Learn more about the Qdrant vector search project and its ecosystem.

Discover Qdrant

Similarity Learning

Explore practical problem-solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community