PatchCore
This is the implementation of the PatchCore paper.
Model Type: Segmentation
Description
The PatchCore algorithm is based on the idea that an image can be classified as anomalous as soon as a single patch is anomalous. The input image is tiled, and the tiles act as patches that are fed into the neural network. The model consists of a single pre-trained network that extracts “mid”-level patch features. “Mid”-level here refers to the layers of the network from which the features are extracted: lower-level features are generally too broad, and higher-level features are too specific to the dataset the network was trained on. The features extracted during the training phase are stored in a memory bank of neighbourhood-aware, patch-level features.
Before inference (in this implementation, when validation starts), this memory bank is coreset subsampled. Coreset subsampling selects a subset that best approximates the structure of the full set, so that a solution found on the subset approximates the solution on the full set. Using this subset reduces the cost of the nearest-neighbour search. The image-level anomaly score is the maximum, over all patches of the test image, of the distance from each test patch to its nearest neighbour in the memory bank.
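The scoring rule described above can be sketched in a few lines. This is a minimal illustration using only the standard library, not the library's implementation; `image_anomaly_score`, `memory_bank`, and `test_patches` are hypothetical names, and the feature vectors are made-up 2-D points.

```python
import math


def image_anomaly_score(test_patches, memory_bank):
    """Return (image-level score, per-patch scores)."""
    # For every test patch, find the distance to its closest memory-bank entry.
    patch_scores = [
        min(math.dist(patch, ref) for ref in memory_bank)
        for patch in test_patches
    ]
    # One anomalous patch is enough to flag the image, so take the maximum.
    return max(patch_scores), patch_scores


# Hypothetical 2-D "features" of normal patches collected during training.
memory_bank = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
# Two test patches close to the normal data, and one far away.
test_patches = [(0.0, 0.0), (1.0, 1.0), (4.0, 4.0)]

score, patch_scores = image_anomaly_score(test_patches, memory_bank)
print(score, patch_scores)
```

The far-away patch dominates the image-level score, which is the behaviour the description asks for: a single anomalous patch is sufficient.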
Architecture

Usage
$ python tools/train.py --model patchcore
PyTorch model for the PatchCore model implementation.
- class anomalib.models.patchcore.torch_model.PatchcoreModel(input_size: tuple[int, int], layers: list[str], backbone: str = 'wide_resnet50_2', pre_trained: bool = True, num_neighbors: int = 9)
Bases: DynamicBufferModule, Module
Patchcore Module.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- compute_anomaly_score(patch_scores: Tensor, locations: Tensor, embedding: Tensor) → Tensor
Compute Image-Level Anomaly Score.
- Parameters:
patch_scores (Tensor) – Patch-level anomaly scores
locations (Tensor) – Memory bank locations of the nearest neighbor for each patch location
embedding (Tensor) – The feature embeddings that generated the patch scores
- Returns:
Image-level anomaly scores
- Return type:
Tensor
- forward(input_tensor: Tensor) → Tensor | tuple[Tensor, Tensor]
Return Embedding during training, or a tuple of anomaly map and anomaly score during testing.
Steps performed:
1. Get features from a CNN.
2. Generate embedding based on the features.
3. Compute anomaly map in test mode.
- Parameters:
input_tensor (Tensor) – Input tensor
- Returns:
Embedding for training; anomaly map and anomaly score for testing.
- Return type:
Tensor | tuple[Tensor, Tensor]
- generate_embedding(features: dict[str, Tensor]) → Tensor
Generate embedding from hierarchical feature map.
- Parameters:
features (dict[str, Tensor]) – Hierarchical feature map from a CNN (ResNet18 or WideResnet)
- Returns:
Embedding vector
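One way to build such an embedding is to upsample the deeper, coarser feature maps to the spatial size of the shallowest one and concatenate everything along the channel axis. The sketch below illustrates that idea with NumPy standing in for PyTorch. The function name, the layer names, the shapes, and the choice of nearest-neighbour upsampling are all assumptions made for illustration, not a statement of what the library does.

```python
import numpy as np


def generate_embedding_sketch(features):
    """Concatenate feature maps of shape [B, C_i, H_i, W_i] along channels."""
    layers = list(features)
    embedding = features[layers[0]]
    target_h, target_w = embedding.shape[-2:]
    for name in layers[1:]:
        fmap = features[name]
        # Nearest-neighbour upsampling by an integer factor
        # (assumes the spatial sizes divide evenly).
        factor_h = target_h // fmap.shape[-2]
        factor_w = target_w // fmap.shape[-1]
        fmap = fmap.repeat(factor_h, axis=-2).repeat(factor_w, axis=-1)
        embedding = np.concatenate([embedding, fmap], axis=1)
    return embedding


# Hypothetical feature maps from two layers of a backbone.
features = {
    "layer2": np.zeros((1, 4, 8, 8)),
    "layer3": np.ones((1, 6, 4, 4)),
}
embedding = generate_embedding_sketch(features)
print(embedding.shape)
```

Each spatial location of the result then carries channels from every selected layer, which is what makes the per-patch feature "hierarchical".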
- nearest_neighbors(embedding: Tensor, n_neighbors: int) → tuple[Tensor, Tensor]
Find the nearest neighbours using a brute-force search with the Euclidean norm.
- Parameters:
embedding (Tensor) – Features to compare the distance with the memory bank.
n_neighbors (int) – Number of neighbors to look at
- Returns:
Patch scores and locations of the nearest neighbor(s).
- Return type:
tuple[Tensor, Tensor]
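A brute-force Euclidean search of this kind can be sketched as follows. NumPy stands in for PyTorch, and the function takes the memory bank as an explicit argument, whereas the documented method receives only `embedding` and `n_neighbors`. The name `nearest_neighbors_sketch` and the toy data are hypothetical.

```python
import numpy as np


def nearest_neighbors_sketch(embedding, memory_bank, n_neighbors):
    """Brute-force k-NN: return (patch_scores, locations), each [N, n_neighbors]."""
    # Pairwise Euclidean distances between N queries and M memory rows: [N, M].
    diff = embedding[:, None, :] - memory_bank[None, :, :]
    distances = np.sqrt((diff ** 2).sum(axis=-1))
    # Indices of the n_neighbors closest memory-bank rows, nearest first.
    locations = np.argsort(distances, axis=1)[:, :n_neighbors]
    patch_scores = np.take_along_axis(distances, locations, axis=1)
    return patch_scores, locations


memory_bank = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
embedding = np.array([[0.0, 0.0], [6.0, 8.0]])
patch_scores, locations = nearest_neighbors_sketch(
    embedding, memory_bank, n_neighbors=2
)
print(patch_scores)
print(locations)
```

The full distance matrix costs memory proportional to N times M, which is one reason a smaller, subsampled memory bank makes the search cheaper.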
- static reshape_embedding(embedding: Tensor) → Tensor
Reshape Embedding.
Reshapes the embedding from [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding].
- Parameters:
embedding (Tensor) – Embedding tensor extracted from CNN features.
- Returns:
Reshaped embedding tensor.
- Return type:
Tensor
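The shape change can be illustrated with a small array. This is a sketch with NumPy standing in for PyTorch; `reshape_embedding_sketch` is a hypothetical name, and moving the channel axis last before flattening is an assumption about how the documented shapes are reached.

```python
import numpy as np


def reshape_embedding_sketch(embedding):
    """[Batch, Embedding, Patch, Patch] -> [Batch*Patch*Patch, Embedding]."""
    embedding_size = embedding.shape[1]
    # Move the channel axis last so that each spatial location becomes one row.
    return embedding.transpose(0, 2, 3, 1).reshape(-1, embedding_size)


batch = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)
rows = reshape_embedding_sketch(batch)
print(rows.shape)
```

Each row of the result is the feature vector of one patch, so the rows can be compared directly against the rows of the memory bank.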
- subsample_embedding(embedding: Tensor, sampling_ratio: float) → None
Subsample embedding based on coreset sampling and store to memory.
- Parameters:
embedding (Tensor) – Embedding tensor from the CNN
sampling_ratio (float) – Coreset sampling ratio
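A common way to build a coreset is greedy k-center (farthest-point) selection: repeatedly add the point that is farthest from everything selected so far. The sketch below shows that technique with NumPy. The source does not say which selection method the library uses, so treat the method, the function name `greedy_coreset_sketch`, and the `seed` parameter as assumptions.

```python
import numpy as np


def greedy_coreset_sketch(embedding, sampling_ratio, seed=0):
    """Select rows of `embedding` by greedy k-center (farthest-point) sampling."""
    n_samples = embedding.shape[0]
    n_select = max(1, int(n_samples * sampling_ratio))
    rng = np.random.default_rng(seed)
    # Start from a random point.
    selected = [int(rng.integers(n_samples))]
    # Distance from every point to its closest selected point so far.
    min_dist = np.linalg.norm(embedding - embedding[selected[0]], axis=1)
    while len(selected) < n_select:
        # Add the point that is currently worst covered by the coreset.
        idx = int(np.argmax(min_dist))
        selected.append(idx)
        new_dist = np.linalg.norm(embedding - embedding[idx], axis=1)
        min_dist = np.minimum(min_dist, new_dist)
    return embedding[selected], selected


# Hypothetical embedding: 200 patch features of dimension 16.
embedding = np.random.default_rng(42).normal(size=(200, 16))
coreset, indices = greedy_coreset_sketch(embedding, sampling_ratio=0.1)
print(coreset.shape)
```

With a sampling ratio of 0.1, the nearest-neighbour search at test time compares against a tenth of the original rows.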
- training: bool
Towards Total Recall in Industrial Anomaly Detection.
Paper https://arxiv.org/abs/2106.08265.
- class anomalib.models.patchcore.lightning_model.Patchcore(input_size: tuple[int, int], backbone: str, layers: list[str], pre_trained: bool = True, coreset_sampling_ratio: float = 0.1, num_neighbors: int = 9)
Bases: AnomalyModule
PatchcoreLightning Module to train PatchCore algorithm.
- Parameters:
input_size (tuple[int, int]) – Size of the model input.
backbone (str) – Backbone CNN network
layers (list[str]) – Layers to extract features from the backbone CNN
pre_trained (bool, optional) – Whether to use a pre-trained backbone. Defaults to True.
coreset_sampling_ratio (float, optional) – Coreset sampling ratio to subsample embedding. Defaults to 0.1.
num_neighbors (int, optional) – Number of nearest neighbors. Defaults to 9.
- configure_optimizers() → None
Configure optimizers.
- Returns:
None, so that no optimizer is set.
- Return type:
None
- on_validation_start() → None
Apply subsampling to the embedding collected from the training set.
- training_step(batch: dict[str, str | Tensor], *args, **kwargs) → None
Generate feature embedding of the batch.
- Parameters:
batch (dict[str, str | Tensor]) – Batch containing image filename, image, label and mask
- Returns:
None. The embedding is collected for the subsampling applied in on_validation_start().
- Return type:
None
- validation_step(batch: dict[str, str | Tensor], *args, **kwargs) → STEP_OUTPUT
Get batch of anomaly maps from input image batch.
- Parameters:
batch (dict[str, str | Tensor]) – Batch containing image filename, image, label and mask
- Returns:
Image filenames, test images, GT and predicted label/masks
- Return type:
dict[str, Any]
- class anomalib.models.patchcore.lightning_model.PatchcoreLightning(hparams)
Bases: Patchcore
PatchcoreLightning Module to train PatchCore algorithm.
- Parameters:
hparams (DictConfig | ListConfig) – Model params
Anomaly Map Generator for the PatchCore model implementation.
- class anomalib.models.patchcore.anomaly_map.AnomalyMapGenerator(input_size: ListConfig | tuple, sigma: int = 4)
Bases: Module
Generate Anomaly Heatmap.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- compute_anomaly_map(patch_scores: Tensor) → Tensor
Pixel Level Anomaly Heatmap.
- Parameters:
patch_scores (Tensor) – Patch-level anomaly scores
- Returns:
Map of the pixel-level anomaly scores
- Return type:
Tensor
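Turning a coarse grid of patch scores into a pixel-level map typically involves upsampling the grid to the input size and smoothing it. The sketch below shows one such pipeline with NumPy. The `sigma` default of 4 comes from the constructor signature above; the upsample-then-blur order, the nearest-neighbour upsampling, the kernel radius, the edge padding, and all function names are assumptions for illustration.

```python
import numpy as np


def gaussian_kernel_1d(sigma):
    """Normalised 1-D Gaussian kernel with a radius of about 4 * sigma."""
    radius = int(4 * sigma + 0.5)
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return kernel / kernel.sum()


def blur_along_axis(array, kernel, axis):
    """Convolve every 1-D slice along `axis`, keeping the original length."""
    radius = len(kernel) // 2
    pad = [(0, 0)] * array.ndim
    pad[axis] = (radius, radius)
    # Edge padding keeps constant regions constant near the border.
    padded = np.pad(array, pad, mode="edge")
    return np.apply_along_axis(np.convolve, axis, padded, kernel, mode="valid")


def anomaly_map_sketch(patch_scores, input_size, sigma=4):
    """Upsample a [P, P] grid of patch scores to input_size, then blur it."""
    height, width = input_size
    # Nearest-neighbour upsampling (assumes the sizes divide evenly).
    factor_h = height // patch_scores.shape[0]
    factor_w = width // patch_scores.shape[1]
    upsampled = patch_scores.repeat(factor_h, axis=0).repeat(factor_w, axis=1)
    # Separable Gaussian blur: filter along rows, then along columns.
    kernel = gaussian_kernel_1d(sigma)
    blurred = blur_along_axis(upsampled, kernel, axis=1)
    return blur_along_axis(blurred, kernel, axis=0)


# Hypothetical 4x4 grid with a single high-scoring patch.
patch_scores = np.zeros((4, 4))
patch_scores[1, 2] = 1.0
anomaly_map = anomaly_map_sketch(patch_scores, input_size=(32, 32))
print(anomaly_map.shape)
```

The blur spreads each patch score over its neighbourhood, so the heatmap peaks inside the anomalous patch and falls off smoothly around it.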
- forward(patch_scores: Tensor) → Tensor
Return the anomaly map.
- Parameters:
patch_scores (Tensor) – Patch-level anomaly scores
Example
>>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
>>> map = anomaly_map_generator(patch_scores=patch_scores)
- Returns:
anomaly_map
- Return type:
Tensor
- training: bool