L2PViT

class capymoa.ocl.strategy.l2p.L2PViT

Bases: ABC

Abstract interface for vision transformer backbones used in L2P.

abstract forward_encoder(prompts: Tensor, patch_embed: Tensor) → Tensor

Encode the patch embeddings with the given prompts.

Parameters:
  • prompts – A tensor of shape (batch_size, prompt_length, embed_dim)

  • patch_embed – A tensor of shape (batch_size, num_patches + 1, embed_dim)

Returns:

A tensor of shape (batch_size, num_patches + 1 + prompt_length, embed_dim)
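The shape contract above can be checked with a quick sketch. The concatenation below is only an illustration of the expected output shape, assuming PyTorch; a concrete backbone would also run its transformer blocks over the combined sequence:

```python
import torch

# Hypothetical sizes (not prescribed by the interface).
batch_size, prompt_length, embed_dim = 4, 5, 768
num_patches = 196

prompts = torch.randn(batch_size, prompt_length, embed_dim)
patch_embed = torch.randn(batch_size, num_patches + 1, embed_dim)  # +1 for [CLS]

# Prepending the prompts yields the documented output shape.
encoded = torch.cat([prompts, patch_embed], dim=1)
assert encoded.shape == (batch_size, num_patches + 1 + prompt_length, embed_dim)
```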

abstract forward_query(patch_embed: Tensor) → Tensor

Get the encoded query embedding from the patch embeddings.

Parameters:

patch_embed – A tensor of shape (batch_size, num_patches + 1, embed_dim)

Returns:

A tensor of shape (batch_size, embed_dim)
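The interface does not prescribe how the query is derived; one common choice in L2P-style methods is the [CLS] token of the encoded sequence. The sketch below (assuming PyTorch) only demonstrates the shape reduction from token sequence to a single query vector:

```python
import torch

# Hypothetical sizes (not prescribed by the interface).
batch_size, num_patches, embed_dim = 4, 196, 768
patch_embed = torch.randn(batch_size, num_patches + 1, embed_dim)

# Take the first ([CLS]) token as the query embedding.
query = patch_embed[:, 0]
assert query.shape == (batch_size, embed_dim)
```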

abstract get_embedding_size() → int

Get the dimension of the patch embeddings.

Returns:

The embedding dimension.

abstract get_patch_embed(pixel_values: Tensor) → Tensor

Turn pixel values into patch embeddings.

Parameters:

pixel_values – A tensor of shape (batch_size, channels, height, width)

Returns:

A tensor of shape (batch_size, num_patches + 1, embed_dim)
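Putting the four methods together, a minimal toy backbone can show how the contract fits. This is a sketch only: it mirrors the `L2PViT` interface locally (rather than importing capymoa), uses a linear patchifier plus an identity "encoder" in place of a real vision transformer, and all sizes are made up for illustration:

```python
import torch
from torch import Tensor, nn


class ToyL2PViT(nn.Module):
    """Toy stand-in for an L2PViT backbone (illustration only)."""

    def __init__(self, image_size: int = 32, patch_size: int = 8, embed_dim: int = 64):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # Linear projection of flattened patches, plus a learnable [CLS] token.
        self.proj = nn.Linear(3 * patch_size * patch_size, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def get_embedding_size(self) -> int:
        return self.embed_dim

    def get_patch_embed(self, pixel_values: Tensor) -> Tensor:
        b, c, _, _ = pixel_values.shape
        p = self.patch_size
        # (b, c, h, w) -> (b, num_patches, c * p * p): cut into p x p tiles.
        patches = (
            pixel_values.unfold(2, p, p)
            .unfold(3, p, p)
            .reshape(b, c, -1, p * p)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, c * p * p)
        )
        embed = self.proj(patches)
        cls = self.cls_token.expand(b, -1, -1)
        return torch.cat([cls, embed], dim=1)  # (b, num_patches + 1, embed_dim)

    def forward_query(self, patch_embed: Tensor) -> Tensor:
        # Use the [CLS] position as the query embedding.
        return patch_embed[:, 0]

    def forward_encoder(self, prompts: Tensor, patch_embed: Tensor) -> Tensor:
        # Prepend prompts; a real backbone would run transformer blocks here.
        return torch.cat([prompts, patch_embed], dim=1)


backbone = ToyL2PViT()
pixels = torch.randn(2, 3, 32, 32)
patch_embed = backbone.get_patch_embed(pixels)            # (2, 17, 64): 16 patches + [CLS]
query = backbone.forward_query(patch_embed)               # (2, 64)
prompts = torch.randn(2, 5, backbone.get_embedding_size())
encoded = backbone.forward_encoder(prompts, patch_embed)  # (2, 22, 64)
```

The shapes flow exactly as the method docs above describe: pixels → `num_patches + 1` tokens → a single query per image → a prompt-extended sequence for the encoder.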