quaterion_models.encoders.switch_encoder module
- class SwitchEncoder(options: Dict[str, Encoder])
Bases: Encoder
Allows using alternative encoders, and therefore alternative embeddings, depending on the input data.
For example, to train a shared embedding representation for images and texts, the image encoder should be used when the input is an image and the text encoder otherwise.
- disable_gradients_if_required()
Disables gradients of the model if it is declared as not trainable.
- classmethod encoder_selection(record: Any) → str
Decide which encoder to use for a given record.
- Parameters:
record – piece of input data
- Returns:
name of the related encoder
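As an illustration of overriding encoder_selection, here is a minimal sketch that assumes the options are keyed "text" and "image", with plain strings treated as text and everything else as images. The class name ImageTextSwitch is hypothetical and deliberately does not subclass the real SwitchEncoder, so the snippet runs without quaterion_models installed; in practice the same classmethod would be defined on a SwitchEncoder subclass.

```python
from typing import Any


class ImageTextSwitch:
    """Hypothetical stand-in for a SwitchEncoder subclass (not part of quaterion_models)."""

    @classmethod
    def encoder_selection(cls, record: Any) -> str:
        # Route plain strings to the "text" option and anything else
        # (e.g. raw image bytes or arrays) to the "image" option.
        # The returned names must match the keys of the `options` dict.
        return "text" if isinstance(record, str) else "image"


print(ImageTextSwitch.encoder_selection("a photo of a cat"))  # text
print(ImageTextSwitch.encoder_selection(b"\x89PNG..."))  # image
```

The returned string acts purely as a lookup key, so the routing rule can be as simple or as elaborate as the data requires.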
- classmethod extract_meta(batch: List[Any]) → List[dict]
Extracts meta information from the batch.
- Parameters:
batch – raw batch of data
- Returns:
meta information
- forward(batch: TensorInterchange) → Tensor
Infer encoder: convert the input batch to embeddings.
- Parameters:
batch – processed batch
- Returns:
embeddings – shape: (batch_size, embedding_size)
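Because each option only sees its own sub-batch, forward presumably has to put the per-option embeddings back into the original record order to honour the (batch_size, embedding_size) shape. The following is a minimal sketch of that reassembly step, assuming the collate step recorded each record's original position per option; the `ordering` structure and the `reassemble` helper are hypothetical. Plain lists stand in for tensors so the snippet has no dependencies; with PyTorch the same idea is an index assignment into a preallocated tensor.

```python
from typing import Dict, List


def reassemble(
    embeddings: Dict[str, List[List[float]]],
    ordering: Dict[str, List[int]],
) -> List[List[float]]:
    """Hypothetical helper: put per-encoder embeddings back into batch order."""
    batch_size = sum(len(ids) for ids in ordering.values())
    result: List[List[float]] = [[] for _ in range(batch_size)]
    for name, ids in ordering.items():
        # The i-th embedding produced by encoder `name` belongs to
        # position ids[i] of the original batch.
        for position, vector in zip(ids, embeddings[name]):
            result[position] = vector
    return result


out = reassemble(
    embeddings={"text": [[0.1, 0.2]], "image": [[0.3, 0.4], [0.5, 0.6]]},
    ordering={"text": [1], "image": [0, 2]},
)
print(out)  # [[0.3, 0.4], [0.1, 0.2], [0.5, 0.6]]
```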
- get_collate_fn() → CollateFnType
Provides a function that converts a raw data batch into suitable model input.
- Returns:
CollateFnType – model’s collate function
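A data loader calls its collate function with the batch alone, whereas switch_collate_fn also needs the collate function of every option. One natural way to bridge the two, shown here as an assumption rather than the library's confirmed code, is to bind the per-option collates with functools.partial. ToyEncoder and the routing rule inside switch_collate_fn are hypothetical stand-ins.

```python
from functools import partial
from typing import Any, Callable, Dict, List


class ToyEncoder:
    """Hypothetical option encoder exposing only get_collate_fn()."""

    def __init__(self, tag: str):
        self.tag = tag

    def get_collate_fn(self) -> Callable[[List[Any]], Any]:
        # Toy collate: tag the records so we can see which encoder handled them.
        return lambda batch: (self.tag, list(batch))


def switch_collate_fn(batch: List[Any], encoder_collates: Dict[str, Callable]) -> Dict[str, Any]:
    # Toy routing rule: strings go to "text", everything else to "image".
    groups: Dict[str, List[Any]] = {}
    for record in batch:
        groups.setdefault("text" if isinstance(record, str) else "image", []).append(record)
    return {name: encoder_collates[name](records) for name, records in groups.items()}


options = {"text": ToyEncoder("T"), "image": ToyEncoder("I")}
# Bind the per-option collates once; the data loader then only passes `batch`.
collate_fn = partial(
    switch_collate_fn,
    encoder_collates={name: enc.get_collate_fn() for name, enc in options.items()},
)
print(collate_fn(["hi", b"img", "yo"]))
# {'text': ('T', ['hi', 'yo']), 'image': ('I', [b'img'])}
```

Binding once at get_collate_fn time keeps the returned object a plain one-argument callable, which is what data loaders expect.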
- classmethod load(input_path: str) → Encoder
Instantiate encoder from saved state.
If no state is required, just call create instead.
- Parameters:
input_path – path to load from
- Returns:
Encoder – loaded encoder
- save(output_path: str)
Persist the current state to the provided directory.
- Parameters:
output_path – path to save model
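Since a switch encoder wraps several encoders, save and load have to persist and restore each option. The following is a minimal sketch of one possible layout, assuming one subdirectory per option key plus an index file; ToyEncoder, save_options, load_options and the file names are all hypothetical. A real implementation would also need to record each option's class so that load can call the right class's load.

```python
import json
import os
import tempfile
from typing import Dict


class ToyEncoder:
    """Hypothetical encoder with save/load (stands in for an Encoder option)."""

    def __init__(self, dim: int):
        self.dim = dim

    def save(self, output_path: str) -> None:
        with open(os.path.join(output_path, "state.json"), "w") as f:
            json.dump({"dim": self.dim}, f)

    @classmethod
    def load(cls, input_path: str) -> "ToyEncoder":
        with open(os.path.join(input_path, "state.json")) as f:
            return cls(**json.load(f))


def save_options(options: Dict[str, ToyEncoder], output_path: str) -> None:
    # One subdirectory per option key, plus an index of the keys.
    for name, encoder in options.items():
        sub_dir = os.path.join(output_path, name)
        os.makedirs(sub_dir, exist_ok=True)
        encoder.save(sub_dir)
    with open(os.path.join(output_path, "options.json"), "w") as f:
        json.dump(sorted(options), f)


def load_options(input_path: str) -> Dict[str, ToyEncoder]:
    with open(os.path.join(input_path, "options.json")) as f:
        names = json.load(f)
    return {name: ToyEncoder.load(os.path.join(input_path, name)) for name in names}


with tempfile.TemporaryDirectory() as tmp:
    save_options({"text": ToyEncoder(64), "image": ToyEncoder(64)}, tmp)
    restored = load_options(tmp)
    print({name: enc.dim for name, enc in restored.items()})  # {'image': 64, 'text': 64}
```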
- classmethod switch_collate_fn(batch: List[Any], encoder_collates: Dict[str, CollateFnType]) → TensorInterchange
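Judging from its signature, switch_collate_fn receives the raw batch together with each option's collate function. The sketch below shows one plausible behaviour, which is an assumption and not the quaterion_models code: group records by encoder_selection, collate each group with its option's function, and record original positions so the embeddings can later be restored to batch order. The class SwitchCollateSketch and the "batches"/"ordering" keys are hypothetical.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List


class SwitchCollateSketch:
    """Hypothetical illustration; not the quaterion_models implementation."""

    @classmethod
    def encoder_selection(cls, record: Any) -> str:
        return "text" if isinstance(record, str) else "image"

    @classmethod
    def switch_collate_fn(
        cls, batch: List[Any], encoder_collates: Dict[str, Callable[[List[Any]], Any]]
    ) -> Dict[str, Any]:
        sub_batches: Dict[str, List[Any]] = defaultdict(list)
        ordering: Dict[str, List[int]] = defaultdict(list)
        for position, record in enumerate(batch):
            name = cls.encoder_selection(record)
            sub_batches[name].append(record)
            # Remember where each record came from so the order can be restored later.
            ordering[name].append(position)
        return {
            "batches": {name: encoder_collates[name](recs) for name, recs in sub_batches.items()},
            "ordering": dict(ordering),
        }


collates = {"text": lambda recs: [len(r) for r in recs], "image": lambda recs: len(recs)}
print(SwitchCollateSketch.switch_collate_fn(["ab", b"x", "cde"], collates))
# {'batches': {'text': [2, 3], 'image': 1}, 'ordering': {'text': [0, 2], 'image': [1]}}
```

Calling cls.encoder_selection inside a classmethod means a subclass only has to override the routing rule to change how batches are split.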
- property embedding_size: int
Size of the resulting embedding.
- property trainable: bool
Defines whether the encoder is trainable.
This flag affects caching and checkpoint saving of the encoder.
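Both properties have to summarise several wrapped encoders. A minimal sketch under two stated assumptions: all options must share one embedding size, since they fill the same (batch_size, embedding_size) output, and the switch counts as trainable when any option is trainable. ToyOption and the two helper functions are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class ToyOption:
    """Hypothetical option encoder exposing only the two properties."""
    embedding_size: int
    trainable: bool


def switch_embedding_size(options: Dict[str, ToyOption]) -> int:
    sizes = {opt.embedding_size for opt in options.values()}
    # All options write into the same (batch_size, embedding_size) output,
    # so they must agree on the embedding size.
    if len(sizes) != 1:
        raise ValueError(f"options disagree on embedding size: {sorted(sizes)}")
    return sizes.pop()


def switch_trainable(options: Dict[str, ToyOption]) -> bool:
    # Treat the switch as trainable if any option is trainable.
    return any(opt.trainable for opt in options.values())


opts = {"text": ToyOption(512, False), "image": ToyOption(512, True)}
print(switch_embedding_size(opts), switch_trainable(opts))  # 512 True
```

Under the "any" rule, a single trainable option is enough to exclude the whole switch from embedding caching, which is the conservative choice.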
- training: bool