Source code for quaterion_models.heads.widening_head

from typing import Any, Dict

from quaterion_models.heads.stacked_projection_head import StackedProjectionHead

class WideningHead(StackedProjectionHead):
    """Implements narrow-wide-narrow architecture.

    Widen the dimensionality by a factor of `expansion_factor` and narrow it
    down back to `input_embedding_size`.

    Args:
        input_embedding_size: Dimensionality of the input to this head layer.
            It is also the output dimensionality of the head.
        expansion_factor: Widen the dimensionality by this factor in the
            intermediate layer. The intermediate width is
            ``int(input_embedding_size * expansion_factor)`` (truncated) and
            must be at least 1.
        activation_fn: Name of the activation function to apply after the
            intermediate layer. Must be an attribute of `torch.nn.functional`
            and defaults to `relu`.
        dropout: Probability of Dropout. If `dropout > 0.`, apply dropout
            layer on embeddings before applying head layer transformations.
        **kwargs: Accepted and ignored (not forwarded to the parent).
            NOTE(review): presumably this lets a dict produced by
            `get_config_dict` (which may also carry parent-level keys such as
            `output_sizes`) be passed straight back to the constructor;
            confirm against `StackedProjectionHead.get_config_dict`.

    Raises:
        ValueError: If the intermediate width computed from
            `input_embedding_size * expansion_factor` is smaller than 1.
    """

    def __init__(
        self,
        input_embedding_size: int,
        expansion_factor: float = 4.0,
        activation_fn: str = "relu",
        dropout: float = 0.0,
        **kwargs,
    ):
        # int() truncation is preserved so layer sizes stay identical to
        # earlier versions of this head.
        hidden_size = int(input_embedding_size * expansion_factor)
        if hidden_size < 1:
            # A zero or negative width would build a degenerate (or invalid)
            # intermediate layer; fail early with a clear message instead.
            raise ValueError(
                f"expansion_factor={expansion_factor} with "
                f"input_embedding_size={input_embedding_size} yields an "
                f"intermediate layer of size {hidden_size}; it must be >= 1"
            )

        self._expansion_factor = expansion_factor
        self._activation_fn = activation_fn
        super().__init__(
            input_embedding_size=input_embedding_size,
            # Two stacked projections: widen to hidden_size, then narrow back.
            output_sizes=[
                hidden_size,
                input_embedding_size,
            ],
            activation_fn=activation_fn,
            dropout=dropout,
        )

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the parent head's config extended with `expansion_factor`.

        Returns:
            The dict from `StackedProjectionHead.get_config_dict`, updated
            with the ``"expansion_factor"`` key for this head.
        """
        config = super().get_config_dict()
        config.update(
            {
                "expansion_factor": self._expansion_factor,
            }
        )
        return config


Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning


Find people dealing with similar problems and get answers to your questions

Join Community