BatchClassifierSSL

class capymoa.ssl.classifier.BatchClassifierSSL

Bases: ClassifierSSL, ABC

__init__(batch_size: int, schema: Schema = None, random_seed=1)
abstract train_on_batch(x_batch: ndarray[Any, dtype[float64]], y_indices: ndarray[Any, dtype[int64]])

Train the model on a batch of instances. Some of the instances may be unlabeled; an unlabeled instance is encoded as -1 in y_indices.

Parameters:
  • x_batch – Batched instances of shape (batch_size, num_attributes).

  • y_indices – Batched label indices of shape (batch_size,), with -1 marking an unlabeled instance.
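Because unlabeled rows are marked with -1, a train_on_batch implementation usually has to separate labeled rows from unlabeled ones before using them. The sketch below shows one way to do that with a NumPy boolean mask. It is an illustration under assumptions, not CapyMOA code: `split_batch` and `UNLABELED` are hypothetical names.

```python
import numpy as np
from numpy.typing import NDArray

UNLABELED = -1  # hypothetical name for the sentinel used in y_indices


def split_batch(
    x_batch: NDArray[np.float64], y_indices: NDArray[np.int64]
) -> tuple[NDArray[np.float64], NDArray[np.int64], NDArray[np.float64]]:
    """Hypothetical helper: split a batch into labeled and unlabeled parts."""
    if x_batch.shape[0] != y_indices.shape[0]:
        raise ValueError("x_batch and y_indices must have the same number of rows")
    labeled = y_indices != UNLABELED  # boolean mask of shape (batch_size,)
    # Labeled features, their labels, and the unlabeled features.
    return x_batch[labeled], y_indices[labeled], x_batch[~labeled]


# A batch of 4 instances with 2 attributes; rows 1 and 3 are unlabeled.
x = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
y = np.array([1, -1, 0, -1])
x_lab, y_lab, x_unlab = split_batch(x, y)
print(x_lab.shape, y_lab.tolist(), x_unlab.shape)  # (2, 2) [1, 0] (2, 2)
```

A semi-supervised learner could then fit a supervised loss on `x_lab`/`y_lab` and use `x_unlab` for its unsupervised component.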

train(instance: LabeledInstance)

Add a labeled instance to the batch and train the model once the batch is full.

train_on_unlabeled(instance: Instance)

Add an unlabeled instance to the batch and train the model once the batch is full.
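The batching behavior of train and train_on_unlabeled can be sketched as a fixed-size buffer that stores -1 for missing labels and passes the full arrays to train_on_batch. This is a minimal sketch under stated assumptions, not CapyMOA's implementation: `ToyBatchLearner` is hypothetical, takes plain feature lists and integer labels instead of Instance/LabeledInstance objects, and its train_on_batch only records each batch's labels.

```python
import numpy as np


class ToyBatchLearner:
    """Hypothetical stand-in for a BatchClassifierSSL subclass.

    Instances are plain feature lists here, not CapyMOA Instance objects.
    """

    def __init__(self, batch_size: int, num_attributes: int):
        self.batch_size = batch_size
        self._x = np.zeros((batch_size, num_attributes), dtype=np.float64)
        self._y = np.full(batch_size, -1, dtype=np.int64)
        self._count = 0
        self.batches_seen = []  # label vector of every batch passed to train_on_batch

    def _add(self, features, label: int) -> None:
        self._x[self._count] = features
        self._y[self._count] = label
        self._count += 1
        if self._count == self.batch_size:  # batch is full: train, then reset
            self.train_on_batch(self._x.copy(), self._y.copy())
            self._count = 0

    def train(self, features, label: int) -> None:
        self._add(features, label)

    def train_on_unlabeled(self, features) -> None:
        self._add(features, -1)  # -1 encodes "no label"

    def train_on_batch(self, x_batch, y_indices) -> None:
        self.batches_seen.append(y_indices.tolist())


learner = ToyBatchLearner(batch_size=3, num_attributes=2)
learner.train([0.1, 0.2], 1)
learner.train_on_unlabeled([0.3, 0.4])
learner.train([0.5, 0.6], 0)  # third instance fills the batch
learner.train_on_unlabeled([0.7, 0.8])  # starts the next, still partial, batch
print(learner.batches_seen)  # [[1, -1, 0]]
```

Copying the buffers before calling train_on_batch lets the same preallocated arrays be reused for the next batch without overwriting the batch that was just handed off.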

abstract predict(instance: Instance) → int | None
abstract predict_proba(instance: Instance) → ndarray[Any, dtype[float64]]
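A concrete subclass must implement both predict and predict_proba. One common pattern, assumed here rather than taken from CapyMOA, is to derive predict from predict_proba by returning the index of the most probable class, and None when no class scores are available. The helper `predict_from_proba` is hypothetical.

```python
from __future__ import annotations

import numpy as np


def predict_from_proba(proba) -> int | None:
    """Hypothetical helper: most probable class index, or None if there are no scores."""
    if proba is None or len(proba) == 0:
        return None
    return int(np.argmax(proba))  # ties resolve to the lowest class index


print(predict_from_proba(np.array([0.2, 0.5, 0.3])))  # 1
print(predict_from_proba(np.array([])))  # None
```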