Visual Prompting (Fine-tuning)#

Visual prompting is a computer vision task that uses a combination of an image and prompts, such as texts, bounding boxes, points, and so on, to solve a given problem. The main purpose of this task is to obtain labels from unlabeled datasets, and then either to use the generated label information directly on particular domains or to develop a new model with it.

This section examines the solutions for visual prompting offered by the OpenVINO Training Extensions library. Segment Anything (SAM) is one of the most popular visual prompting methods, and this model is used to adapt to a new dataset domain. Because SAM was trained on a web-scale dataset and has a huge backbone network, fine-tuning the whole network is difficult and requires a lot of resources. Therefore, in this section, we fine-tune only the mask decoder for several epochs to increase performance on the new dataset domain. For fine-tuning SAM, we use the following algorithm components:

  • Pre-processing: Resize the image so that its longest side matches the target input size and zero-pad the rest to make it square (see the resize-and-pad sketch after this list).

  • Optimizer: We use the Adam optimizer.

  • Loss function: We use the standard loss combination used in SAM as is: 20 * focal loss + dice loss + IoU loss (see the loss sketch after this list).

  • Additional training techniques
    • Early stopping: Applied automatically to add adaptability to the training pipeline and prevent overfitting.
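
As a reference for the pre-processing step above, here is a minimal sketch of the resize-and-pad operation in Python. The helper name resize_longest_and_pad and the target size of 1024 (the input resolution of the original SAM) are illustrative assumptions, not the exact code used by the templates:

import numpy as np
import cv2

def resize_longest_and_pad(image: np.ndarray, target: int = 1024) -> np.ndarray:
    # Illustrative sketch: resize so the longest side equals `target`,
    # then zero-pad the bottom/right to obtain a square `target` x `target` image.
    # Assumes an H x W x C image; `target=1024` mirrors the original SAM input size.
    h, w = image.shape[:2]
    scale = target / max(h, w)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    padded = np.zeros((target, target, image.shape[2]), dtype=image.dtype)
    padded[:new_h, :new_w] = resized
    return padded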
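
The loss combination can be sketched as follows. The helpers below (a focal loss without the alpha balancing term, a soft dice loss, and an IoU loss that regresses the predicted IoU score against the actual IoU of the thresholded mask) are simplified illustrations of 20 * focal + dice + IoU, assuming sigmoid mask probabilities of shape (N, H, W) and one predicted IoU score per mask; they are not the library's exact implementation:

import torch
import torch.nn.functional as F

def focal_loss(pred: torch.Tensor, target: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    # Binary focal loss on probabilities (alpha balancing omitted for brevity).
    bce = F.binary_cross_entropy(pred, target, reduction="none")
    p_t = pred * target + (1 - pred) * (1 - target)
    return ((1 - p_t) ** gamma * bce).mean()

def dice_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Soft Dice loss over (N, H, W) masks.
    inter = (pred * target).sum(dim=(1, 2))
    union = pred.sum(dim=(1, 2)) + target.sum(dim=(1, 2))
    return (1 - (2 * inter + eps) / (union + eps)).mean()

def iou_loss(pred_iou: torch.Tensor, pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # MSE between the predicted IoU score and the actual IoU of the thresholded mask.
    binary = (pred > 0.5).float()
    inter = (binary * target).sum(dim=(1, 2))
    union = (binary + target).clamp(0, 1).sum(dim=(1, 2))
    return F.mse_loss(pred_iou, (inter + eps) / (union + eps))

def finetune_loss(pred, pred_iou, target):
    # 20 * focal + dice + IoU, as described above.
    return 20.0 * focal_loss(pred, target) + dice_loss(pred, target) + iou_loss(pred_iou, pred, target)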

Note

Currently, fine-tuning SAM with bounding boxes is the only mode supported in OpenVINO Training Extensions. Fine-tuning with other prompts (points and texts) and continuous fine-tuning with predicted mask information will be supported in the near future.

Note

Currently, only Post-Training Quantization (PTQ) is supported for SAM, not Quantization-Aware Training (QAT).

Dataset Format#

For the dataset handling inside OpenVINO™ Training Extensions, we use Dataset Management Framework (Datumaro).

We support three dataset formats for visual prompting:

If your dataset is organized in a supported format, starting training is very simple. We just need to pass the path to the data root folder and the desired model template to start training:

$ otx train <model_template> \
    --train-data-roots <path_to_data_root> \
    --val-data-roots <path_to_data_root>

Note

During training, mDice for binary masks without label information is used as the train/validation metric. After training, if otx eval is used to evaluate performance, mDice for binary or multi-class masks with label information is used instead. As a result, the reported performance can differ between otx train and otx eval, but if the unlabeled (binary) mask performance is high, the labeled mask performance is high as well.
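
For reference, a Dice score for a single binary mask can be computed as in the sketch below; mDice simply averages this score over masks (and over classes when label information is available). The function is an illustrative definition, not the library's internal metric code:

import numpy as np

def dice_score(pred_mask: np.ndarray, gt_mask: np.ndarray, eps: float = 1e-7) -> float:
    # Dice = 2 * |P ∩ G| / (|P| + |G|) for binary masks P and G.
    pred = pred_mask.astype(bool)
    gt = gt_mask.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return float((2 * inter + eps) / (pred.sum() + gt.sum() + eps))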

Models#

We support the following model templates in the experimental phase:

| Template ID                   | Name         | Complexity (GFLOPs) | Model size (MB) |
|-------------------------------|--------------|---------------------|-----------------|
| Visual_Prompting_SAM_Tiny_ViT | SAM_Tiny_ViT | 38.95               | 47              |
| Visual_Prompting_SAM_ViT_B    | SAM_ViT_B    | 483.71              | 362             |

To check the feasibility of SAM, we ran experiments on three public datasets from different domains: WGISD, Trashcan, and FLARE22, and measured the Dice score. We used sampled training data from Trashcan and FLARE22, and the full training data (110 images) from WGISD.

| Dataset  | #samples            |
|----------|---------------------|
| WGISD    | 110                 |
| Trashcan | 500                 |
| FLARE22  | 1 CT (=100 slices)  |

The table below shows the performance improvement after fine-tuning.

| Model name | WGISD                  | Trashcan               | FLARE22                |
|------------|------------------------|------------------------|------------------------|
| Tiny_ViT   | 90.32 → 92.29 (+1.97)  | 82.38 → 85.01 (+2.63)  | 89.69 → 93.05 (+3.36)  |
| ViT_B      | 92.32 → 92.46 (+0.14)  | 79.61 → 81.50 (+1.89)  | 91.48 → 91.68 (+0.20)  |

Depending on the dataset, the learning rate and batch size can be adjusted as shown below:

$ otx train <model_template> \
    --train-data-roots <path_to_data_root> \
    --val-data-roots <path_to_data_root> \
    params \
    --learning_parameters.dataset.train_batch_size <batch_size_to_be_updated> \
    --learning_parameters.optimizer.lr <learning_rate_to_be_updated>