GDumb

class capymoa.ocl.strategy.GDumb

Bases: BatchClassifier, TestTaskAware

Greedy sampler and a dumb learner.

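Greedy sampler and a dumb learner (GDumb) [1] is a baseline replay strategy. It works by downsampling the stream into a small, class-balanced memory (the coreset) and training the model offline on that memory. Online learners have no separate inference phase in which they could train: they must be ready to predict at any time. GDumb instead defers training until inference, so it is effectively an offline algorithm, but it remains a useful baseline.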
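The greedy sampler described in the GDumb paper keeps its memory roughly class-balanced. The following is a minimal, framework-free sketch of that rule, assuming the paper's formulation: store every sample while memory has room; once it is full, accept a sample only if its class holds fewer than capacity / (number of classes seen) slots, and make room by evicting a random sample from the currently largest class. The class name `GreedyClassBalancedSampler` is hypothetical and not part of CapyMOA's API.

```python
import random
from typing import Dict, List


class GreedyClassBalancedSampler:
    """Hypothetical sketch of GDumb's greedy, class-balanced sampler."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.rng = random.Random(seed)  # seeded for reproducibility
        self.memory: Dict[int, List] = {}  # label -> stored feature vectors

    def __len__(self) -> int:
        return sum(len(bucket) for bucket in self.memory.values())

    def add(self, x, y: int) -> bool:
        """Offer one labelled sample; return True if it was stored."""
        if len(self) < self.capacity:
            self.memory.setdefault(y, []).append(x)
            return True
        num_seen = len(self.memory) + (0 if y in self.memory else 1)
        # Reject the sample if class y already holds its fair share.
        if len(self.memory.get(y, [])) * num_seen >= self.capacity:
            return False
        # Otherwise evict a random sample from the largest class.
        largest = max(self.memory, key=lambda c: len(self.memory[c]))
        bucket = self.memory[largest]
        bucket.pop(self.rng.randrange(len(bucket)))
        if not bucket:
            del self.memory[largest]
        self.memory.setdefault(y, []).append(x)
        return True


sampler = GreedyClassBalancedSampler(capacity=4)
for i in range(4):
    sampler.add([float(i)], 0)
for i in range(4):
    sampler.add([float(i)], 1)
print({label: len(xs) for label, xs in sampler.memory.items()})  # {0: 2, 1: 2}
```

With capacity=4, four class-0 samples followed by four class-1 samples leave two samples of each class in memory: the first class gives up slots only until the newcomer reaches its fair share.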

__init__(
schema: Schema,
model: Module,
epochs: int,
batch_size: int,
capacity: int,
lr: float = 0.001,
device: str | device = 'cpu',
seed: int = 0,
)

batch_predict(x: Tensor) → Tensor

Predict the labels for a batch of instances.

Parameters:

x – Batch of x_dtype-valued feature vectors with shape (batch_size, num_features).

Returns:

Batch of predicted y_dtype-valued labels with shape (batch_size,).

batch_predict_proba(x: Tensor) → Tensor

Predict the probabilities of the classes for a batch of instances.

Parameters:

x – Batch of x_dtype-valued feature vectors with shape (batch_size, num_features).

Returns:

Batch of x_dtype-valued predicted probabilities with shape (batch_size, num_classes).
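The shape contract of batch_predict_proba and batch_predict can be illustrated without PyTorch. This sketch assumes that the underlying model emits one row of logits per instance, that probabilities come from a row-wise softmax, and that labels are the argmax of each row. The function names are hypothetical, and plain lists stand in for the torch.Tensor objects the real methods use.

```python
import math
from typing import List


def softmax_rows(logits: List[List[float]]) -> List[List[float]]:
    """Map (batch_size, num_classes) logits to per-row probabilities."""
    probs = []
    for row in logits:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


def argmax_rows(probs: List[List[float]]) -> List[int]:
    """Map (batch_size, num_classes) probabilities to (batch_size,) labels."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
probs = softmax_rows(logits)
labels = argmax_rows(probs)
print(labels)  # [0, 2]
```

In PyTorch the same two steps would typically be torch.softmax(logits, dim=1) followed by .argmax(dim=1).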

batch_train(x: Tensor, y: Tensor) → None

Train with a batch of instances.

Parameters:
  • x – Batch of x_dtype-valued feature vectors with shape (batch_size, num_features).

  • y – Batch of y_dtype-valued labels with shape (batch_size,).

gdumb_fit() → None

Fit the model on the coreset, the memory of at most capacity instances collected by the greedy sampler.
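This page does not spell out how gdumb_fit uses epochs and batch_size. One plausible reading, offered here as an assumption, is that each call makes epochs passes over the shuffled coreset in minibatches of at most batch_size instances. The generator below (hypothetical name iterate_coreset) sketches only that schedule, with no training code, which makes the number of optimizer steps per refit easy to count.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def iterate_coreset(
    coreset: Sequence[Tuple[List[float], int]],
    epochs: int,
    batch_size: int,
    seed: int = 0,
) -> Iterator[List[Tuple[List[float], int]]]:
    """Yield shuffled minibatches of (x, y) pairs for `epochs` passes."""
    rng = random.Random(seed)
    indices = list(range(len(coreset)))
    for _ in range(epochs):
        rng.shuffle(indices)  # reshuffle at the start of every epoch
        for start in range(0, len(indices), batch_size):
            yield [coreset[i] for i in indices[start:start + batch_size]]


coreset = [([float(i)], i % 2) for i in range(10)]
batches = list(iterate_coreset(coreset, epochs=2, batch_size=4))
print(len(batches), [len(b) for b in batches])  # 6 [4, 4, 2, 4, 4, 2]
```

Under this assumption, a full coreset costs epochs × ceil(capacity / batch_size) optimizer steps per refit, so larger capacities make each refit proportionally more expensive.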

on_test_task(task_id: int) → None

Called when testing on a task starts.
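GDumb's defining trait is that the learner is rebuilt from scratch on the stored samples rather than updated incrementally. The toy class below sketches that lifecycle, assuming that the testing hook triggers a refit on the coreset. Its names are hypothetical, it swaps the neural network for a nearest-class-mean classifier, and it fills its buffer first-come-first-served rather than with the greedy sampler: train only stores samples, on_test_task discards the previous model and refits, and predictions use whatever the last refit produced.

```python
from typing import Dict, List, Optional, Tuple


class DeferredNearestMean:
    """Toy GDumb-style learner: store samples, refit from scratch at test time."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.buffer: List[Tuple[List[float], int]] = []
        self.means: Dict[int, List[float]] = {}

    def train(self, x: List[float], y: int) -> None:
        # No learning happens here; samples are only buffered.
        # (GDumb would fill this buffer with its greedy sampler instead.)
        if len(self.buffer) < self.capacity:
            self.buffer.append((x, y))

    def on_test_task(self, task_id: int) -> None:
        # Discard the previous model and refit on the buffer from scratch.
        sums: Dict[int, List[float]] = {}
        counts: Dict[int, int] = {}
        for x, y in self.buffer:
            acc = sums.setdefault(y, [0.0] * len(x))
            for i, v in enumerate(x):
                acc[i] += v
            counts[y] = counts.get(y, 0) + 1
        self.means = {y: [v / counts[y] for v in s] for y, s in sums.items()}

    def predict(self, x: List[float]) -> Optional[int]:
        if not self.means:
            return None  # nothing has been fitted yet

        def sq_dist(label: int) -> float:
            return sum((a - b) ** 2 for a, b in zip(x, self.means[label]))

        return min(self.means, key=sq_dist)


learner = DeferredNearestMean(capacity=100)
learner.train([0.0, 0.0], 0)
learner.train([1.0, 1.0], 1)
print(learner.predict([0.9, 0.8]))  # None: no refit has happened yet
learner.on_test_task(task_id=0)
print(learner.predict([0.9, 0.8]))  # 1
```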

predict(instance: Instance) → int | None

Predict the label of an instance.

The base implementation calls predict_proba() and returns the label with the highest probability.

Parameters:

instance – The instance to predict the label for.

Returns:

The predicted label or None if the classifier is unable to make a prediction.
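The base predict behaviour described above (argmax over predict_proba, with None passed through) can be sketched in plain Python. The free function predict_from_proba is hypothetical; it only mirrors the documented contract and works on any sequence of probabilities, including a NumPy array.

```python
from typing import Callable, Optional, Sequence


def predict_from_proba(
    predict_proba: Callable[[object], Optional[Sequence[float]]],
    instance: object,
) -> Optional[int]:
    """Return the index of the most probable class, or None if unavailable."""
    proba = predict_proba(instance)
    if proba is None:
        return None  # the classifier cannot make a prediction yet
    return max(range(len(proba)), key=lambda i: proba[i])


print(predict_from_proba(lambda _: [0.1, 0.7, 0.2], "x"))  # 1
print(predict_from_proba(lambda _: None, "x"))  # None
```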

predict_proba(instance: Instance) → ndarray[Any, dtype[float64]] | None

Calls batch_predict_proba() with a batch of size 1.

train(instance: LabeledInstance) → None

Calls batch_train() with a batch of size 1.
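predict_proba and train adapt the single-instance API to the batch API by wrapping one instance into a batch of size 1. The stdlib sketch below mirrors that adapter pattern. The class names are hypothetical, and plain lists stand in for the x_dtype/y_dtype tensors the real methods build (in PyTorch, adding the batch dimension is typically x.unsqueeze(0)).

```python
from typing import List, Optional


class BatchOnly:
    """Hypothetical batch learner that records the labels it was trained on."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def batch_train(self, x: List[List[float]], y: List[int]) -> None:
        self.seen.extend(y)

    def batch_predict_proba(self, x: List[List[float]]) -> List[List[float]]:
        return [[0.25, 0.75] for _ in x]  # fixed output, for illustration only


class SingleInstanceAdapter(BatchOnly):
    def train(self, x: List[float], y: int) -> None:
        # Wrap one labelled instance into a batch of size 1.
        self.batch_train([x], [y])

    def predict_proba(self, x: List[float]) -> Optional[List[float]]:
        # Batch of size 1 in, first (and only) row out.
        return self.batch_predict_proba([x])[0]


model = SingleInstanceAdapter()
model.train([0.5, 1.5], 1)
print(model.seen, model.predict_proba([0.5, 1.5]))  # [1] [0.25, 0.75]
```

Keeping the learning logic in the batch methods means a subclass only has to implement batch_train and batch_predict_proba once to support both calling styles.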

device: device = device(type='cpu')

Device on which the batch will be processed.

random_seed: int

The random seed for reproducibility.

When implementing a classifier, ensure that its random number generators are seeded.

schema: Schema

The schema representing the instances.

x_dtype: dtype = torch.float32

Data type for the input features.

y_dtype: dtype = torch.int64

Data type for the target values (labels).