L2P#

class capymoa.ocl.strategy.l2p.L2P[source]#

Bases: BatchClassifier, TrainTaskAware

Learning to Prompt.

Learning to Prompt (L2P) [1] is a continual learning strategy that leverages a pool of learnable prompts to adapt a pre-trained vision transformer (ViT) to new tasks. For each input, the most relevant prompts are selected from the pool based on the similarity between the input’s embedding and the prompt keys. The selected prompts are then used to condition the ViT, allowing it to effectively learn new tasks while mitigating catastrophic forgetting.
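
The key-matching step described above can be sketched in plain NumPy. All names, dimensions, and the random data here are illustrative, not CapyMOA's internals; the real implementation operates on learnable torch tensors inside the ViT forward pass:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes: 10 prompt keys, 384-dim embeddings (ViT-small width).
pool_size, embed_dim, top_k = 10, 384, 3

prompt_keys = rng.normal(size=(pool_size, embed_dim))  # learnable prompt keys
query = rng.normal(size=(embed_dim,))                  # embedding of the input

# Cosine similarity between the query and every prompt key.
sims = prompt_keys @ query / (
    np.linalg.norm(prompt_keys, axis=1) * np.linalg.norm(query)
)

# Select the top-k most similar prompts; their tokens are what would be
# prepended to the ViT input sequence to condition it.
selected = np.argsort(sims)[-top_k:][::-1]

# The pull constraint nudges the selected keys toward the query by adding
# -coeff * mean(similarity of selected keys) to the training loss.
pull_constraint_coeff = 0.1
pull_loss = -pull_constraint_coeff * sims[selected].mean()
```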

L2P relies on knowledge of the tasks during training to select task-specific prompts but does not require task information during inference.

# Please note this code block is not regularly tested.
from capymoa.ocl.strategy.l2p import L2P
from capymoa.ocl.datasets import SplitCIFAR100
from capymoa.ocl.evaluation import ocl_train_eval_loop
scenario = SplitCIFAR100()
learner = L2P(scenario.schema, scenario.task_mask, device="cuda")
results = ocl_train_eval_loop(
    learner,
    scenario.train_loaders(32),
    scenario.test_loaders(32),
    progress_bar=True
)
print(f"{results.accuracy_final*100:.1f}%")
__init__(
schema: Schema,
task_mask: Tensor,
vit: L2PViT | str = 'facebook/dinov2-small',
prompts_per_task: int = 5,
prompt_length: int = 1,
top_k: int = 3,
pull_constraint_coeff: float = 0.1,
optimizer: Callable[[Any], Optimizer] = lambda params: ...,
device: str = 'cpu',
random_seed: int = 1,
)[source]#

Construct L2P learner.

Parameters:
  • schema – Schema describing the datastream.

  • task_mask – A boolean tensor of shape (num_tasks, num_classes) indicating which classes belong to each task.

  • vit – Vision transformer backbone, or the name of a pretrained model from HuggingFace Transformers. Loading by name requires the transformers package to be installed.

  • prompts_per_task – Number of prompts per task in the prompt pool.

  • prompt_length – Length of each prompt (number of tokens/patches).

  • top_k – Number of top prompts to retrieve per query.

  • pull_constraint_coeff – Coefficient for the pull constraint loss term.

  • optimizer – Function that takes model parameters and returns an optimizer instance.

  • device – Device to run the model on, e.g., “cpu” or “cuda”.

  • random_seed – Random seed for reproducibility.

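
CapyMOA's built-in scenarios (such as SplitCIFAR100 in the example above) already expose a ready-made scenario.task_mask, but for a custom stream a mask could be assembled by hand. The NumPy sketch below assumes a contiguous class split (classes 0-9 in task 0, and so on); the constructor expects a torch tensor, so the result would still need a torch.from_numpy(task_mask) conversion:

```python
import numpy as np

num_tasks, num_classes = 10, 100  # e.g. Split-CIFAR100: 10 tasks of 10 classes

# Boolean mask of shape (num_tasks, num_classes): entry [t, c] is True
# when class c belongs to task t. Classes are split contiguously here.
task_mask = np.zeros((num_tasks, num_classes), dtype=bool)
classes_per_task = num_classes // num_tasks
for t in range(num_tasks):
    task_mask[t, t * classes_per_task:(t + 1) * classes_per_task] = True

# Sanity check: each class belongs to exactly one task.
assert (task_mask.sum(axis=0) == 1).all()
```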

batch_predict(x: Tensor) Tensor[source]#

Predict the labels for a batch of instances.

Parameters:

x – Batch of x_dtype valued feature vectors with shape (batch_size, num_features).

Returns:

Predicted batch of y_dtype valued labels with shape (batch_size,).

batch_predict_proba(x: Tensor) Tensor[source]#

Predict the probabilities of the classes for a batch of instances.

Parameters:

x – Batch of x_dtype valued feature vectors with shape (batch_size, num_features).

Returns:

Batch of x_dtype valued predicted probabilities with shape (batch_size, num_classes).

batch_train(x: Tensor, y: Tensor) None[source]#

Train with a batch of instances.

Parameters:
  • x – Batch of x_dtype valued feature vectors with shape (batch_size, num_features).

  • y – Batch of y_dtype valued labels with shape (batch_size,).

on_train_task(task_id: int)[source]#

Called when a new training task starts.
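
The TrainTaskAware contract can be illustrated with a minimal stand-in learner (RecordingLearner and its attributes are hypothetical; only the on_train_task and batch_train method names come from the API above). The training loop calls the hook once per task, before streaming that task's batches, which is how L2P knows which task-specific prompts to train:

```python
class RecordingLearner:
    """Stand-in learner that records which task each hook call announces."""

    def __init__(self):
        self.task_log = []
        self.current_task = None

    def on_train_task(self, task_id: int):
        # L2P would select the task's prompt slice here; we just record it.
        self.current_task = task_id
        self.task_log.append(task_id)

    def batch_train(self, x, y):
        pass  # the actual training step would go here


learner = RecordingLearner()
for task_id, task_batches in enumerate([["batch0", "batch1"], ["batch2"]]):
    learner.on_train_task(task_id)  # called once per task, before its data
    for batch in task_batches:
        learner.batch_train(batch, None)

print(learner.task_log)  # [0, 1]
```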

predict(instance: Instance) int | None[source]#

Predict the label of an instance.

The base implementation calls predict_proba() and returns the label with the highest probability.

Parameters:

instance – The instance to predict the label for.

Returns:

The predicted label or None if the classifier is unable to make a prediction.

predict_proba(
instance: Instance,
) ndarray[tuple[Any, ...], dtype[float64]] | None[source]#

Calls batch_predict_proba() with a batch of size 1.

train(instance: LabeledInstance) None[source]#

Calls batch_train() with a batch of size 1.
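
The batch-of-size-1 delegation used by predict(), predict_proba(), and train() can be sketched as follows; the stand-in model and the free functions are illustrative, not CapyMOA code:

```python
import numpy as np

def batch_predict_proba(x):
    # Stand-in for the model: uniform probabilities over 4 classes.
    return np.full((x.shape[0], 4), 0.25)

def predict_proba(features):
    # Single-instance wrapper: add a leading batch dimension of size 1,
    # call the batch method, then drop the batch dimension again.
    probs = batch_predict_proba(features[None, :])
    return probs[0]

def predict(features):
    # Label = index of the highest probability, as in the base implementation.
    return int(np.argmax(predict_proba(features)))

x = np.zeros(8)  # one instance with 8 features
probs = predict_proba(x)   # shape (4,)
label = predict(x)
```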

device: device = device(type='cpu')#

Device on which the batch will be processed.

random_seed: int#

The random seed for reproducibility.

When implementing a classifier, ensure random number generators are seeded.

schema: Schema#

The schema representing the instances.

x_dtype: dtype = torch.float32#

Data type for the input features.

y_dtype: dtype = torch.int64#

Data type for the target value/labels.