quaterion_models.encoders.encoder module
- class Encoder
Bases: Module
Base class for encoder abstraction
- disable_gradients_if_required()
Disables gradients of the model if it is declared as not trainable
- classmethod extract_meta(batch: List[Any]) → List[dict]
Extracts meta information from the batch
- Parameters:
batch – raw batch of data
- Returns:
meta information
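The contract described above (a raw batch in, one meta dict per record out) can be sketched as follows. This is only an illustration: `MetaSketch` is a stand-in class, not the library's Encoder, and the record layout with `id` and `text` keys is an assumption made for the example.

```python
from typing import Any, List


class MetaSketch:
    """Stand-in mirroring only the extract_meta signature; not the library class."""

    @classmethod
    def extract_meta(cls, batch: List[Any]) -> List[dict]:
        # One meta dict per raw record, in the same order as the batch.
        # The "id" key is a hypothetical field of the raw records.
        return [{"id": record.get("id")} for record in batch]


raw_batch = [{"id": 1, "text": "red apple"}, {"id": 2, "text": "green pear"}]
meta = MetaSketch.extract_meta(raw_batch)
```

Because it is a classmethod, it can be called without instantiating the encoder, which is why the sketch calls it on the class itself.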
- forward(batch: TensorInterchange) → Tensor
Run the encoder: convert an input batch into embeddings
- Parameters:
batch – processed batch
- Returns:
embeddings – shape: (batch_size, embedding_size)
- get_collate_fn() → CollateFnType
Provides a function that converts a raw data batch into suitable model input
- Returns:
CollateFnType – model’s collate function
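To make the role of a collate function concrete, here is a minimal pure-Python sketch, assuming the raw batch is a list of strings and the model input is a rectangular list of token ids. `CollateSketch`, the `CollateFn` alias, the vocabulary, and the `pad_id`/`unk_id` parameters are all hypothetical; a real encoder would typically return tensors rather than lists.

```python
from typing import Callable, Dict, List

# Hypothetical alias for this sketch; the library defines its own CollateFnType.
CollateFn = Callable[[List[str]], List[List[int]]]


class CollateSketch:
    """Stand-in exposing only get_collate_fn; not the library class."""

    def __init__(self, vocab: Dict[str, int], pad_id: int = 0, unk_id: int = 1):
        self.vocab = vocab
        self.pad_id = pad_id
        self.unk_id = unk_id

    def get_collate_fn(self) -> CollateFn:
        vocab, pad_id, unk_id = self.vocab, self.pad_id, self.unk_id

        def collate(batch: List[str]) -> List[List[int]]:
            # Map each whitespace token to an id; unknown tokens become unk_id.
            rows = [[vocab.get(tok, unk_id) for tok in text.split()] for text in batch]
            # Pad every row to the longest one so the batch is rectangular.
            width = max((len(row) for row in rows), default=0)
            return [row + [pad_id] * (width - len(row)) for row in rows]

        return collate


collate = CollateSketch({"red": 2, "apple": 3, "pear": 4}).get_collate_fn()
model_input = collate(["red apple", "pear"])
```

Returning a closure keeps the vocabulary and padding settings with the encoder, so the data-loading code only needs the function itself.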
- classmethod load(input_path: str) → Encoder
Instantiate encoder from saved state.
If no state is required, just call create instead.
- Parameters:
input_path – path to load from
- Returns:
Encoder – loaded encoder
- save(output_path: str)
Persist current state to the provided directory
- Parameters:
output_path – path to save model
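The save and load pair can be illustrated with a small round trip. The sketch below is a stand-in under stated assumptions, not the library's implementation: `PersistSketch` and the `config.json` file name are hypothetical, and the only state persisted is the embedding size.

```python
import json
import os
import tempfile


class PersistSketch:
    """Stand-in mirroring the save/load signatures; not the library class."""

    def __init__(self, embedding_size: int):
        self._embedding_size = embedding_size

    @property
    def embedding_size(self) -> int:
        return self._embedding_size

    def save(self, output_path: str) -> None:
        # Treat output_path as a directory and write one config file into it.
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump({"embedding_size": self._embedding_size}, f)

    @classmethod
    def load(cls, input_path: str) -> "PersistSketch":
        # Rebuild the object from the file written by save().
        with open(os.path.join(input_path, "config.json")) as f:
            config = json.load(f)
        return cls(**config)


with tempfile.TemporaryDirectory() as tmp_dir:
    PersistSketch(embedding_size=128).save(tmp_dir)
    restored = PersistSketch.load(tmp_dir)
```

Making `load` a classmethod means the caller needs only a path, not an existing instance, to get an encoder back.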
- property embedding_size: int
Size of resulting embedding
- property trainable: bool
Defines whether the encoder is trainable.
This flag affects caching and checkpoint saving of the encoder.
- training: bool