Source code for otx.api.usecases.tasks.interfaces.training_interface
"""This module contains the interface class for tasks that can perform training."""
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import abc
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import ModelEntity
from otx.api.entities.train_parameters import TrainParameters
[docs]
class ITrainingTask(metaclass=abc.ABCMeta):
    """A base interface class for tasks which can perform training.

    All methods below are declared with ``abc.abstractmethod``. A concrete task
    must override every one of them before it can be instantiated.
    """

    @abc.abstractmethod
    def save_model(self, output_model: ModelEntity) -> None:
        """Save the model currently loaded by the task to ``output_model``.

        This method is, for instance, used to save the pre-trained weights before
        training when the task has been initialised with pre-trained weights
        rather than an existing model.

        Args:
            output_model (ModelEntity): Output model where the weights should be stored.

        Raises:
            NotImplementedError: Always, in this abstract base; subclasses must override it.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def train(
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: TrainParameters,
    ) -> None:
        """Train a new model using the model currently loaded by the task.

        If training was successful, the new model should be used for subsequent
        calls (e.g. ``optimize`` or ``infer``). The new model weights should be
        saved in the object ``output_model``; nothing is returned.

        The task has two choices:

        - Set the output model weights, if the task was able to improve itself
          (according to its own measures).
        - Set the model state as failed, if it failed to improve itself
          (according to its own measures).

        Args:
            dataset (DatasetEntity): Dataset containing the training and validation
                splits to use for training.
            output_model (ModelEntity): Output model where the weights should be stored.
            train_parameters (TrainParameters): Training parameters.

        Raises:
            NotImplementedError: Always, in this abstract base; subclasses must override it.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def cancel_training(self) -> None:
        """Cancel the currently running training process.

        If training is not running, this should do nothing.

        Raises:
            NotImplementedError: Always, in this abstract base; subclasses must override it.
        """
        raise NotImplementedError