otx.algorithms.visual_prompting.tasks.train
Visual Prompting Task.
Classes
TrainingTask – Training Task for Visual Prompting.
- class otx.algorithms.visual_prompting.tasks.train.TrainingTask(task_environment: TaskEnvironment, output_path: str | None = None)
Bases: InferenceTask, ITrainingTask
Training Task for Visual Prompting.
- Parameters:
dataset (DatasetEntity) – Input dataset.
output_model (ModelEntity) – Output model to save the model weights.
train_parameters (TrainParameters) – Training parameters.
seed (Optional[int]) – Random seed to set for training (a value other than 0).
deterministic (bool) – Sets the PyTorch Lightning trainer's deterministic flag.
- train(dataset: DatasetEntity, output_model: ModelEntity, train_parameters: TrainParameters, seed: int | None = None, deterministic: bool = False) None
Train a new model using the model currently loaded by the task.
If training was successful, the new model should be used for subsequent calls (e.g. optimize or infer).
The new model weights should be saved in the object output_model.
The task has two choices:
  - Set the output model weights if the task was able to improve itself (according to its own measures).
  - Set the model state as failed if it failed to improve itself (according to its own measures).
- Parameters:
dataset (DatasetEntity) – Dataset containing the training and validation splits to use for training.
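output_model (ModelEntity) – Output model where the weights should be stored.
train_parameters (TrainParameters) – Training parameters.
seed (Optional[int]) – Random seed to set for training (a value other than 0). Defaults to None.
deterministic (bool) – Sets the PyTorch Lightning trainer's deterministic flag. Defaults to False.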
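The two outcomes of train() can be illustrated with a toy, self-contained sketch. It does not use OTX: ToyTrainingTask, OutputModel and ModelStatus are hypothetical stand-ins for the task, ModelEntity and its model state, and the "training" step and scoring metric are placeholders. The sketch only shows the contract: export the new weights and keep using the new model when the task's own measure improves, otherwise mark the output model as failed.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional


class ModelStatus(Enum):
    """Hypothetical stand-in for the state of an output model (not the OTX enum)."""

    NOT_TRAINED = "not_trained"
    SUCCESS = "success"
    FAILED = "failed"


@dataclass
class OutputModel:
    """Hypothetical stand-in for ModelEntity: a weights slot plus a status."""

    weights: Optional[Dict[str, float]] = None
    status: ModelStatus = ModelStatus.NOT_TRAINED


class ToyTrainingTask:
    """Toy task that follows the two-outcome train() contract."""

    def __init__(self, weights: Dict[str, float], best_score: float) -> None:
        self.weights = weights  # the model "currently loaded by the task"
        self.best_score = best_score  # the task's own measure to beat

    @staticmethod
    def _score(weights: Dict[str, float], dataset: List[float]) -> float:
        # Placeholder metric, higher is better: negative mean squared error.
        return -sum((x - weights["w"]) ** 2 for x in dataset) / len(dataset)

    def _fit(self, dataset: List[float]) -> Dict[str, float]:
        # Placeholder "training" step: move the single weight halfway to the data mean.
        mean = sum(dataset) / len(dataset)
        current = self.weights["w"]
        return {"w": current + 0.5 * (mean - current)}

    def train(self, dataset: List[float], output_model: OutputModel) -> None:
        candidate = self._fit(dataset)
        score = self._score(candidate, dataset)
        if score > self.best_score:
            # Improved by the task's own measure: export the weights and keep
            # using the new model for subsequent calls.
            self.weights = candidate
            self.best_score = score
            output_model.weights = dict(candidate)
            output_model.status = ModelStatus.SUCCESS
        else:
            # No improvement: leave the weights unset and mark the model as failed.
            output_model.status = ModelStatus.FAILED


# A fresh task with no baseline improves, so the output model receives weights.
task = ToyTrainingTask(weights={"w": 0.0}, best_score=float("-inf"))
out = OutputModel()
task.train([1.0, 2.0, 3.0], out)
print(out.status, out.weights)  # ModelStatus.SUCCESS {'w': 1.0}
```

With a task whose best score is already 0.0 (a perfect fit under this placeholder metric, so it cannot be beaten), the same call leaves output_model.weights as None and sets the status to FAILED, while the task keeps its previous weights.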