L2PViT
- class capymoa.ocl.strategy.l2p.L2PViT
Bases: ABC
Abstract interface for vision transformer backbones used in L2P.
- abstract forward_encoder( ) → Tensor
Encode the patch embeddings with the given prompts.
- Parameters:
prompts – A tensor of shape (batch_size, prompt_length, embed_dim)
patch_embed – A tensor of shape (batch_size, num_patches + 1, embed_dim)
- Returns:
A tensor of shape (batch_size, num_patches + 1 + prompt_length, embed_dim)
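The shape contract above can be illustrated with a minimal sketch. It is not the library's implementation: `toy_forward_encoder` is a hypothetical standalone function, the argument order is an assumption, and so is the choice to place the prompts in front of the patch tokens. A single `nn.TransformerEncoderLayer` stands in for a real ViT encoder.

```python
import torch
from torch import Tensor, nn


def toy_forward_encoder(prompts: Tensor, patch_embed: Tensor, encoder: nn.Module) -> Tensor:
    """Hypothetical helper: join prompts and patch tokens, then encode them."""
    # Concatenate along the token dimension (dim=1). Putting the prompts
    # first is an assumption; the interface only fixes the output shape.
    tokens = torch.cat([prompts, patch_embed], dim=1)
    return encoder(tokens)


batch_size, prompt_length, num_patches, embed_dim = 2, 5, 16, 32

# Stand-in for a ViT encoder; batch_first=True keeps (batch, tokens, dim) layout.
encoder = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=4, dim_feedforward=64, batch_first=True
)
encoder.eval()

prompts = torch.randn(batch_size, prompt_length, embed_dim)
patch_embed = torch.randn(batch_size, num_patches + 1, embed_dim)  # +1 for the class token

with torch.no_grad():
    encoded = toy_forward_encoder(prompts, patch_embed, encoder)

print(encoded.shape)
```

The token dimension of the result is the sum of the two inputs' token dimensions, which matches the documented `num_patches + 1 + prompt_length`.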
- abstract forward_query(patch_embed: Tensor) → Tensor
Get the encoded query embedding from the patch embeddings.
- Parameters:
patch_embed – A tensor of shape (batch_size, num_patches + 1, embed_dim)
- Returns:
A tensor of shape (batch_size, embed_dim)
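One way to satisfy this shape contract is sketched below. It assumes that the class token sits at index 0 of the token dimension and that the query is its encoded representation. `toy_forward_query` is a hypothetical standalone function, not the library's code, and a single `nn.TransformerEncoderLayer` stands in for a pretrained backbone.

```python
import torch
from torch import Tensor, nn


def toy_forward_query(patch_embed: Tensor, encoder: nn.Module) -> Tensor:
    """Hypothetical helper: encode without prompts and return the class token."""
    # The query is computed without gradients in this sketch, treating the
    # encoder as a frozen feature extractor (an assumption).
    with torch.no_grad():
        encoded = encoder(patch_embed)
    # Assumes the class token is the first token of the sequence.
    return encoded[:, 0]


batch_size, num_patches, embed_dim = 2, 16, 32

# Stand-in for a pretrained ViT encoder.
encoder = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=4, dim_feedforward=64, batch_first=True
)
encoder.eval()

patch_embed = torch.randn(batch_size, num_patches + 1, embed_dim)
query = toy_forward_query(patch_embed, encoder)

print(query.shape)
```

Indexing with `[:, 0]` drops the token dimension, so each input in the batch is reduced to a single `embed_dim`-sized vector.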