quaterion_models.model module
- MetricModel
alias of SimilarityModel
- class SimilarityModel(encoders: Union[Encoder, Dict[str, Encoder]], head: EncoderHead)
Bases: torch.nn.Module
Main class that combines one or more encoder models with a head layer.
- classmethod collate_fn(batch: List[dict], encoders_collate_fns: Dict[str, CollateFnType], meta_extractors: Dict[str, Callable[[List[Any]], List[dict]]]) → TensorInterchange
Construct batches for all encoders.
- Parameters:
batch – list of raw input records to collate
encoders_collate_fns – Dict (or single) of collate functions associated with encoders
meta_extractors – Dict (or single) of meta extractor functions associated with encoders
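To illustrate what "construct batches for all encoders" means, here is a minimal stdlib-only sketch: every encoder's collate function and every meta extractor is applied to the same raw batch, and the results are keyed by encoder name. The toy collate functions, the `id_extractor` helper, and the `{"data": ..., "meta": ...}` result layout are assumptions for illustration, not the library's confirmed implementation or TensorInterchange format.

```python
from typing import Any, Callable, Dict, List

# Hypothetical per-encoder collate functions: each turns the same raw batch
# into the inputs one encoder expects (plain Python lists here).
def text_collate(batch: List[dict]) -> List[str]:
    return [record["text"].lower() for record in batch]

def length_collate(batch: List[dict]) -> List[int]:
    return [len(record["text"]) for record in batch]

# Hypothetical meta extractor: keeps non-feature information such as ids.
def id_extractor(batch: List[Any]) -> List[dict]:
    return [{"id": record["id"]} for record in batch]

def collate_all(
    batch: List[dict],
    encoders_collate_fns: Dict[str, Callable[[List[dict]], Any]],
    meta_extractors: Dict[str, Callable[[List[Any]], List[dict]]],
) -> Dict[str, Dict[str, Any]]:
    """Apply every encoder's collate function and meta extractor to one batch."""
    data = {name: fn(batch) for name, fn in encoders_collate_fns.items()}
    meta = {name: fn(batch) for name, fn in meta_extractors.items()}
    # Assumed layout: encoder features under "data", metadata under "meta".
    return {"data": data, "meta": meta}

batch = [{"id": 1, "text": "Hello"}, {"id": 2, "text": "World!"}]
result = collate_all(
    batch,
    encoders_collate_fns={"text": text_collate, "length": length_collate},
    meta_extractors={"text": id_extractor},
)
print(result)
# {'data': {'text': ['hello', 'world!'], 'length': [5, 6]}, 'meta': {'text': [{'id': 1}, {'id': 2}]}}
```

Keying both dicts by encoder name lets each encoder receive only its own prepared inputs from a single shared batch.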
- encode(inputs: Union[List[Any], Any], batch_size=32, to_numpy=True) → Union[Tensor, ndarray]
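Encode data in batches.
- Parameters:
inputs – a single input or a list of input data to encode
batch_size – number of inputs processed per batch (default: 32)
to_numpy – if True, return a NumPy array; otherwise return a torch.Tensor (default: True)
- Returns:
NumPy array or torch.Tensor of shape (input_size, embedding_size)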
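The batching behind `encode` can be pictured with the following sketch, which uses only the standard library. It splits the inputs into chunks of `batch_size`, embeds each chunk, and stacks the rows so the result has shape (input_size, embedding_size). The `embed_batch` callback, the toy 2-dimensional embedding, and wrapping a single input as a batch of one are assumptions; the real method runs the model and can return a tensor or NumPy array rather than nested lists.

```python
from typing import Any, Callable, List

def encode_in_batches(
    inputs: Any,
    embed_batch: Callable[[List[Any]], List[List[float]]],
    batch_size: int = 32,
) -> List[List[float]]:
    """Encode inputs chunk by chunk and stack the results row-wise."""
    # Assumption: a single (non-list) input is treated as a batch of one.
    if not isinstance(inputs, list):
        inputs = [inputs]
    embeddings: List[List[float]] = []
    for start in range(0, len(inputs), batch_size):
        chunk = inputs[start:start + batch_size]
        embeddings.extend(embed_batch(chunk))
    return embeddings

calls: List[int] = []

def toy_embed(chunk: List[str]) -> List[List[float]]:
    # Hypothetical 2-dimensional "embedding": (length, number of vowels).
    calls.append(len(chunk))
    return [[float(len(s)), float(sum(c in "aeiou" for c in s))] for s in chunk]

vectors = encode_in_batches(["cat", "horse", "ox", "emu", "yak"], toy_embed, batch_size=2)
print(len(vectors), len(vectors[0]))  # 5 2 -> shape (input_size, embedding_size)
print(calls)  # [2, 2, 1]
```

Smaller `batch_size` values lower peak memory at the cost of more calls; the output shape does not depend on it.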
- forward(batch)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
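The difference between calling the module and calling `forward` directly can be shown with a simplified, torch-free mimic. `ToyModule` is hypothetical and only imitates the idea of forward hooks in `torch.nn.Module`; it is not PyTorch's implementation, which also supports pre-hooks and passes the inputs as a tuple.

```python
from typing import Any, Callable, List

class ToyModule:
    """Simplified, hypothetical mimic of how calling a module differs from forward()."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[[Any, Any, Any], Any]] = []

    def register_forward_hook(self, hook: Callable[[Any, Any, Any], Any]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        return x * 2

    def __call__(self, x: float) -> float:
        # Calling the instance runs forward() and then every registered hook;
        # a hook that returns a value replaces the output.
        output = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, output)
            if result is not None:
                output = result
        return output

module = ToyModule()
module.register_forward_hook(lambda mod, inp, out: out + 1)
print(module(3))          # 7: forward gives 6, the hook adds 1
print(module.forward(3))  # 6: the hook is silently skipped
```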
- get_collate_fn() → Callable
Construct a function to convert input data into neural network inputs.
- Returns:
a callable that converts a batch of input data into neural network inputs
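One natural way to build such a function is to bind the model's per-encoder collate functions and meta extractors with `functools.partial`, leaving a callable that takes only a batch, the shape a DataLoader-style `collate_fn` expects. The sketch below assumes this approach; `ToyModel`, `collate_all`, and the `{"data": ..., "meta": ...}` layout are hypothetical stand-ins, not the library's code.

```python
from functools import partial
from typing import Any, Callable, Dict, List

def collate_all(
    batch: List[Any],
    encoders_collate_fns: Dict[str, Callable[[List[Any]], Any]],
    meta_extractors: Dict[str, Callable[[List[Any]], List[dict]]],
) -> Dict[str, Dict[str, Any]]:
    # Run every encoder's collate function and meta extractor on the batch.
    return {
        "data": {k: fn(batch) for k, fn in encoders_collate_fns.items()},
        "meta": {k: fn(batch) for k, fn in meta_extractors.items()},
    }

class ToyModel:
    """Hypothetical stand-in holding per-encoder collate fns and meta extractors."""

    def __init__(self, collate_fns, meta_extractors) -> None:
        self.collate_fns = collate_fns
        self.meta_extractors = meta_extractors

    def get_collate_fn(self) -> Callable[[List[Any]], Dict[str, Dict[str, Any]]]:
        # Bind the model-specific arguments so the result takes only a batch.
        return partial(
            collate_all,
            encoders_collate_fns=self.collate_fns,
            meta_extractors=self.meta_extractors,
        )

model = ToyModel(
    collate_fns={"words": lambda batch: [s.split() for s in batch]},
    meta_extractors={"words": lambda batch: [{"n_chars": len(s)} for s in batch]},
)
collate = model.get_collate_fn()
print(collate(["a b", "c"]))
# {'data': {'words': [['a', 'b'], ['c']]}, 'meta': {'words': [{'n_chars': 3}, {'n_chars': 1}]}}
```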
- classmethod get_encoders_output_size(encoders: Union[Encoder, Dict[str, Encoder]])
Calculate the total output size of the given encoders.
- Parameters:
encoders – a single encoder or a dict of encoders keyed by name
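A minimal sketch of the size calculation, assuming each encoder exposes an `embedding_size` attribute and that "total" means the sum over a dict of encoders (as when their outputs are concatenated along the feature axis). `ToyEncoder` is a hypothetical stand-in for a real encoder.

```python
from dataclasses import dataclass
from typing import Dict, Union

@dataclass
class ToyEncoder:
    """Hypothetical encoder exposing only its output dimensionality."""
    embedding_size: int

def encoders_output_size(encoders: Union[ToyEncoder, Dict[str, ToyEncoder]]) -> int:
    # A single encoder contributes its own size; for a dict, sum the sizes
    # (assuming the encoder outputs are concatenated feature-wise).
    if isinstance(encoders, dict):
        return sum(enc.embedding_size for enc in encoders.values())
    return encoders.embedding_size

print(encoders_output_size(ToyEncoder(128)))  # 128
print(encoders_output_size({"text": ToyEncoder(384), "image": ToyEncoder(512)}))  # 896
```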
- classmethod load(input_path: str) → SimilarityModel
- train(mode: bool = True)
Sets the module in training mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
- Parameters:
mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns:
Module – self
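The recursive effect of `train(mode)` and the fact that it returns `self` can be illustrated with a simplified, torch-free mimic. `ToyModule` is hypothetical: it imitates how `torch.nn.Module.train` sets the `training` flag on a module and all of its children, but it is not PyTorch's code.

```python
from typing import List

class ToyModule:
    """Simplified, hypothetical mimic of recursive train()/eval() behavior."""

    def __init__(self, *children: "ToyModule") -> None:
        self.training = True
        self.children: List["ToyModule"] = list(children)

    def train(self, mode: bool = True) -> "ToyModule":
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for child in self.children:
            child.train(mode)  # propagate the flag down the module tree
        return self  # returning self allows chaining calls

    def eval(self) -> "ToyModule":
        # Evaluation mode is simply train(False).
        return self.train(False)

leaf = ToyModule()  # stands in for e.g. a Dropout layer
model = ToyModule(ToyModule(leaf))
print(model.eval() is model)   # True
print(leaf.training)           # False
print(model.train().training)  # True
```

Layers such as Dropout and BatchNorm read this flag to switch behavior, which is why setting it on the top-level model must reach every nested module.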
- training: bool