quaterion_models.encoders.encoder module
- class Encoder
Bases: Module
Base class for encoder abstraction
- disable_gradients_if_required()
Disables gradients of the model's parameters if the encoder is declared as not trainable
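A minimal, torch-free sketch of the freezing logic described above. `Param`, `ToyEncoder`, `trainable` and `params` are hypothetical stand-ins. In PyTorch the analogous step sets `requires_grad = False` on each tensor yielded by `module.parameters()`. This is an illustration, not the library's implementation.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Param:
    """Hypothetical stand-in for a trainable tensor such as torch.nn.Parameter."""
    name: str
    requires_grad: bool = True


@dataclass
class ToyEncoder:
    """Hypothetical encoder: `trainable` and `params` are illustrative names."""
    trainable: bool
    params: List[Param] = field(default_factory=list)

    def disable_gradients_if_required(self) -> None:
        # Freeze parameters only when the encoder is declared non-trainable;
        # a trainable encoder is left untouched.
        if not self.trainable:
            for p in self.params:
                p.requires_grad = False


frozen = ToyEncoder(trainable=False, params=[Param("w"), Param("b")])
frozen.disable_gradients_if_required()
print([p.requires_grad for p in frozen.params])  # [False, False]
```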
- classmethod extract_meta(batch: List[Any]) → List[dict]
Extracts meta information from the batch
- Parameters:
batch – raw batch of data
- Returns:
meta information
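A hedged sketch of what an `extract_meta` override could look like. It assumes, purely for illustration, that each raw record is a dict with a "text" payload plus extra fields. The class `ToyEncoder` and that record schema are hypothetical; the sketch only mirrors the documented signature `List[Any] → List[dict]`.

```python
from typing import Any, List


class ToyEncoder:
    """Hypothetical encoder whose raw records look like {"text": ..., **extra}."""

    @classmethod
    def extract_meta(cls, batch: List[Any]) -> List[dict]:
        # Return one dict per raw record, keeping every field except
        # the payload that actually gets embedded ("text").
        return [{k: v for k, v in record.items() if k != "text"} for record in batch]


raw = [{"text": "cat", "id": 1}, {"text": "dog", "id": 2, "lang": "en"}]
print(ToyEncoder.extract_meta(raw))
# [{'id': 1}, {'id': 2, 'lang': 'en'}]
```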
- forward(batch: TensorInterchange) → Tensor
Run encoder inference: convert an input batch into embeddings
- Parameters:
batch – processed batch (the output of the encoder's collate function)
- Returns:
embeddings – shape: (batch_size, embedding_size)
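A torch-free sketch of the shape contract for `forward`: one vector of `embedding_size` values per record, giving shape (batch_size, embedding_size). It assumes, hypothetically, that the processed batch is a list of integer token ids per record and returns a nested list instead of a `Tensor`. `toy_forward` and `EMBEDDING_SIZE` are illustrative names, and the bucket-counting "model" stands in for a real network.

```python
from typing import List

EMBEDDING_SIZE = 4  # hypothetical embedding size


def toy_forward(batch: List[List[int]]) -> List[List[float]]:
    """Hypothetical stand-in for Encoder.forward.

    Maps each record's token ids to exactly EMBEDDING_SIZE floats, so the
    result has shape (batch_size, embedding_size).
    """
    embeddings = []
    for token_ids in batch:
        vec = [0.0] * EMBEDDING_SIZE
        for token in token_ids:
            # Hash each token into one of EMBEDDING_SIZE buckets.
            vec[token % EMBEDDING_SIZE] += 1.0
        embeddings.append(vec)
    return embeddings


out = toy_forward([[99, 97, 116], [104, 105], []])
print(len(out), len(out[0]))  # 3 4
```

Even an empty record yields a full-width vector, which keeps the output rectangular.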
- get_collate_fn() → CollateFnType
Provides a function that converts a raw data batch into suitable model input
- Returns:
CollateFnType – model’s collate function
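A minimal sketch of a `get_collate_fn` implementation, assuming raw inputs are strings and that "model input" is a list of code-point ids per string, truncated to a maximum length. `ToyEncoder`, `ToyCollateFn` and `max_len` are hypothetical; the real `CollateFnType` is defined by the library and is not reproduced here.

```python
from typing import Any, Callable, List

# Hypothetical alias illustrating the idea: raw batch in, model-ready batch out.
ToyCollateFn = Callable[[List[Any]], List[List[int]]]


class ToyEncoder:
    def __init__(self, max_len: int = 8):
        self.max_len = max_len  # hypothetical truncation length

    def _collate(self, batch: List[Any]) -> List[List[int]]:
        # Convert each raw item to code points, truncated to max_len.
        return [[ord(ch) for ch in str(item)[: self.max_len]] for item in batch]

    def get_collate_fn(self) -> ToyCollateFn:
        # Return a bound method so the collate logic can read encoder config.
        return self._collate


collate = ToyEncoder(max_len=3).get_collate_fn()
print(collate(["cat", "horse"]))  # [[99, 97, 116], [104, 111, 114]]
```

Returning a plain callable keeps it usable wherever a collate function is expected, for example the `collate_fn` argument of PyTorch's `torch.utils.data.DataLoader`.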
- get_meta_extractor() → Callable[[List[Any]], List[dict]]
- classmethod load(input_path: str) → Encoder
Instantiate an encoder from a saved state.
If no state is required, just call create instead.
- Parameters:
input_path – path to load from
- Returns:
Encoder – loaded encoder
- save(output_path: str)
Persist the current state to the provided directory
- Parameters:
output_path – directory to save the model to
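A hedged sketch of the `save` / `load` round trip: `save` writes state into a directory and the `load` classmethod rebuilds an instance from it. `ToyEncoder`, its `embedding_size` state and the `config.json` file name are hypothetical. An encoder with real weights would also need to persist them, for example with `torch.save` on its `state_dict()`.

```python
import json
import os
import tempfile


class ToyEncoder:
    """Hypothetical encoder whose whole state is a small config dict."""

    def __init__(self, embedding_size: int):
        self.embedding_size = embedding_size

    def save(self, output_path: str) -> None:
        # Persist state as JSON inside the given directory.
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "config.json"), "w") as f:
            json.dump({"embedding_size": self.embedding_size}, f)

    @classmethod
    def load(cls, input_path: str) -> "ToyEncoder":
        # Rebuild the encoder from the state written by save().
        with open(os.path.join(input_path, "config.json")) as f:
            config = json.load(f)
        return cls(**config)


with tempfile.TemporaryDirectory() as tmp:
    ToyEncoder(embedding_size=128).save(tmp)
    restored = ToyEncoder.load(tmp)
    print(restored.embedding_size)  # 128
```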