otx.algorithms.visual_prompting.tasks.train

Visual Prompting Task.

Classes

TrainingTask(task_environment[, output_path])

Training Task for Visual Prompting.

class otx.algorithms.visual_prompting.tasks.train.TrainingTask(task_environment: TaskEnvironment, output_path: str | None = None)

Bases: InferenceTask, ITrainingTask

Training Task for Visual Prompting.

Parameters:
  • task_environment (TaskEnvironment) – Task environment describing the task to run.

  • output_path (str | None) – Optional output path for the task. Defaults to None.
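The base classes InferenceTask and ITrainingTask mean that TrainingTask reuses the state and inference behaviour of InferenceTask while also fulfilling the training interface, whose contract is the train() method. The following is a minimal, self-contained sketch of that composition pattern using Python's abc module. The classes `InferenceTaskSketch`, `ITrainingTaskSketch` and `TrainingTaskSketch` are hypothetical stand-ins, not the OTX classes, and their method bodies are placeholders.

```python
from abc import ABC, abstractmethod


class ITrainingTaskSketch(ABC):
    """Hypothetical stand-in for a training interface: subclasses must implement train()."""

    @abstractmethod
    def train(self, dataset, output_model, train_parameters, seed=None, deterministic=False):
        """Train a model and store the result in output_model."""


class InferenceTaskSketch:
    """Hypothetical stand-in for an inference base class holding shared task state."""

    def __init__(self, task_environment, output_path=None):
        self.task_environment = task_environment
        self.output_path = output_path

    def infer(self, dataset):
        # Placeholder: a real task would run the loaded model on the dataset.
        return [f"prediction for {item}" for item in dataset]


class TrainingTaskSketch(InferenceTaskSketch, ITrainingTaskSketch):
    """Inherits __init__ and infer() from the inference base and implements train()."""

    def train(self, dataset, output_model, train_parameters, seed=None, deterministic=False):
        # Placeholder: only record how many samples would have been trained on.
        output_model["trained_on"] = len(dataset)


task = TrainingTaskSketch(task_environment="env")
model = {}
task.train(["a", "b"], model, train_parameters=None)
print(model)              # {'trained_on': 2}
print(task.infer(["a"]))  # ['prediction for a']
```

Because the interface marks train() as abstract, a subclass that forgot to implement it could not be instantiated; the inference behaviour, by contrast, comes for free from the first base class in the method resolution order.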

train(dataset: DatasetEntity, output_model: ModelEntity, train_parameters: TrainParameters, seed: int | None = None, deterministic: bool = False) → None

Train a new model using the model currently loaded by the task.

If training was successful, the new model should be used for subsequent calls (e.g. optimize or infer).

The new model weights should be saved in output_model.

The task has two options:

  • Set the output model weights, if the task was able to improve itself (according to its own measures).

  • Set the model state as failed, if it failed to improve itself (according to its own measures).
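The two outcomes can be illustrated with a minimal sketch, assuming the task scores itself with a single higher-is-better metric compared against a baseline. `OutputModelSketch`, `ModelStatusSketch` and `finish_training` are hypothetical names for illustration only; they are not the OTX ModelEntity API, and how OTX actually records a failed state is not shown here.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional


class ModelStatusSketch(Enum):
    """Hypothetical status values for the output model."""
    NOT_READY = "not_ready"
    SUCCESS = "success"
    FAILED = "failed"


@dataclass
class OutputModelSketch:
    """Hypothetical stand-in for an output model container."""
    weights: Optional[Dict[str, float]] = None
    status: ModelStatusSketch = ModelStatusSketch.NOT_READY


def finish_training(output_model: OutputModelSketch,
                    new_weights: Dict[str, float],
                    new_score: float,
                    baseline_score: float) -> None:
    """Keep the new weights only if the task's own (assumed higher-is-better) metric improved."""
    if new_score > baseline_score:
        output_model.weights = new_weights
        output_model.status = ModelStatusSketch.SUCCESS
    else:
        # Leave the weights untouched and mark the training attempt as failed.
        output_model.status = ModelStatusSketch.FAILED


improved = OutputModelSketch()
finish_training(improved, {"w": 0.5}, new_score=0.82, baseline_score=0.75)
print(improved.status.name)  # SUCCESS

regressed = OutputModelSketch()
finish_training(regressed, {"w": 0.1}, new_score=0.60, baseline_score=0.75)
print(regressed.status.name, regressed.weights)  # FAILED None
```

The "own measures" in the description correspond here to the comparison inside finish_training; a real task might use a different metric or comparison rule.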

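Parameters:
  • dataset (DatasetEntity) – Dataset containing the training and validation splits to use for training.

  • output_model (ModelEntity) – Output model in which the trained weights are stored.

  • train_parameters (TrainParameters) – Training parameters.

  • seed (int | None) – Random seed for training, intended to be a value other than 0. Defaults to None.

  • deterministic (bool) – Sets the PyTorch Lightning trainer’s deterministic flag. Defaults to False.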
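The seed argument exists so that repeated training runs can be reproduced. The sketch below illustrates the idea with the standard library only: `toy_train` is a hypothetical stand-in for a training loop that shuffles the data order, not part of OTX. In a PyTorch Lightning setup the usual counterparts are `seed_everything(seed)` and `Trainer(deterministic=True)`; whether this task calls them in exactly that way is an assumption not confirmed here.

```python
import random
from typing import List, Optional


def toy_train(data: List[int], seed: Optional[int] = None) -> List[int]:
    """Hypothetical training-loop stand-in: shuffles the data order with a local RNG.

    The same seed reproduces the same order; seed=None seeds from OS randomness,
    so runs are not reproducible.
    """
    rng = random.Random(seed)  # independent generator, seeded per run
    order = list(data)
    rng.shuffle(order)
    return order


run_a = toy_train(list(range(10)), seed=42)
run_b = toy_train(list(range(10)), seed=42)
print(run_a == run_b)  # True
```

Using a local random.Random instance instead of the global random.seed keeps the seeding scoped to one run, so other code that uses the global generator is not affected.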