# Source code for otx.algorithms.visual_prompting.tasks.train
"""Visual Prompting Task."""
# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from typing import Optional
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
TQDMProgressBar,
)
from pytorch_lightning.loggers import CSVLogger
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets import (
OTXVisualPromptingDataModule,
)
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.metrics import Performance, ScoreMetric
from otx.api.entities.model import ModelEntity
from otx.api.entities.train_parameters import TrainParameters
from otx.api.usecases.tasks.interfaces.training_interface import ITrainingTask
from otx.utils.logger import get_logger
from .inference import InferenceTask
logger = get_logger()
class TrainingTask(InferenceTask, ITrainingTask):
    """Training Task for Visual Prompting.

    Args:
        dataset (DatasetEntity): Input dataset.
        output_model (ModelEntity): Output model to save the model weights.
        train_parameters (TrainParameters): Training parameters
        seed (Optional[int]): Setting seed to a value other than 0
        deterministic (bool): Setting PytorchLightning trainer's deterministic flag.
    """

    def train(  # noqa: D102
        self,
        dataset: DatasetEntity,
        output_model: ModelEntity,
        train_parameters: TrainParameters,
        seed: Optional[int] = None,
        deterministic: bool = False,
    ) -> None:
        """Train the visual prompting model and record its best performance.

        Fits the model on ``dataset``, stores the best checkpoint path on
        ``self._model_ckpt``, saves the weights into ``output_model`` and
        attaches a ``Performance`` entity built from the best monitored score.
        Returns early (with an error log) if no checkpoint was produced.
        """
        logger.info("Training the model.")

        self.seed = seed
        self.deterministic = deterministic
        self.set_seed()
        # "warn" makes Lightning emit a warning instead of raising when an op
        # has no deterministic implementation.
        self.config.trainer.deterministic = "warn" if deterministic else deterministic

        logger.info("Training Configs '%s'", self.config)

        self.model = self.load_model(otx_model=self.task_environment.model)

        datamodule = OTXVisualPromptingDataModule(
            config=self.config.dataset, dataset=dataset, train_type=self.train_type
        )
        loggers = CSVLogger(save_dir=self.output_path, name=".", version=self.timestamp)
        callbacks = [
            TQDMProgressBar(),
            ModelCheckpoint(dirpath=loggers.log_dir, filename="{epoch:02d}", **self.config.callback.checkpoint),
            LearningRateMonitor(),
            EarlyStopping(**self.config.callback.early_stopping),
        ]

        self.trainer = Trainer(**self.config.trainer, logger=loggers, callbacks=callbacks)
        self.trainer.fit(model=self.model, datamodule=datamodule)

        model_ckpt = self.trainer.checkpoint_callback.best_model_path
        if not model_ckpt:
            logger.error("cannot find final checkpoint from the results.")
            return
        # update checkpoint to the newly trained model
        self._model_ckpt = model_ckpt

        # compose performance statistics
        best_score = self.trainer.checkpoint_callback.best_model_score
        if best_score is None:
            # No score was recorded during fit (e.g. monitored metric never
            # logged) — run a validation pass to obtain one.
            results = self.trainer.validate(model=self.model, datamodule=datamodule)
            best_score = results[0].get(self.config.callback.checkpoint.monitor)
        # BUGFIX: ModelCheckpoint.best_model_score is a torch.Tensor, while
        # ScoreMetric expects a plain float — convert before composing the
        # performance entity.
        if best_score is not None:
            best_score = float(best_score)

        # save resulting model
        self.save_model(output_model)
        performance = Performance(
            score=ScoreMetric(value=best_score, name=self.trainer.checkpoint_callback.monitor)
            # TODO (sungchul): dashboard? -> only for Geti
        )
        # Lazy %-formatting, consistent with the other log calls above.
        logger.info("Final model performance: %s", performance)
        output_model.performance = performance
        logger.info("train done.")