# Source code for core.preprocessing.tissuemask
import logging
import numpy as np
import torch
import shutil
import cv2
import os
from skimage import morphology, measure
from scipy import ndimage
from vision_agent.tools import florence2_sam2_instance_segmentation
from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor
from tiatoolbox.models.architecture.unet import UNetModel
from pillow_heif import register_heif_opener
import core.preprocessing.stainnorm as stainnorm
register_heif_opener()
logger = logging.getLogger(__name__)
class FlorenceTissueMaskExtractor:
    """Extract binary tissue masks from histology images.

    Extraction is attempted in order of decreasing quality:
        1. Florence-2 + SAM2 prompt-based instance segmentation.
        2. A UNet model, if a checkpoint path was supplied.
        3. Otsu thresholding with morphological cleanup.

    All paths return a uint8 mask with values 0 or 255.
    """

    def __init__(self, unet_model_path: str = "", unet_device: str = "cuda"):
        """
        Args:
            unet_model_path (str): Optional path to a UNet checkpoint used as a
                secondary extractor. An empty string disables the UNet stage.
            unet_device (str): Device for the UNet model ('cuda' or 'cpu').
        """
        # Primary prompt plus ordered fallback prompts for Florence-2/SAM2.
        self.default_prompt = "tissue,stain"
        self.backup_prompts = ["tissue,stain", "tissue", "cell,tissue", "histology"]
        self.unet_model_path = unet_model_path
        self.unet_device = unet_device

    def extract(self, image: np.ndarray, artefacts: bool) -> np.ndarray:
        """
        Extracts the tissue mask from an image using instance segmentation or fallback methods.

        Extraction order:
            1. Florence-2 + SAM2 prompt-based instance segmentation.
            2. If that fails and a UNet model path is provided, the UNet extractor.
            3. Final fallback: Otsu threshold with morphological cleanup.

        Args:
            image (np.ndarray): Input RGB image.
            artefacts (bool): When True, return only the first (largest) segment mask
                so that control tissue artefacts are isolated.

        Returns:
            np.ndarray: Binary tissue mask (uint8, values 0 or 255).
        """
        # Try instance segmentation with the default prompt first.
        segments = self._segment_with_prompts(image, self.default_prompt)
        if not segments:
            # Defensive: guarantees `prompt` is bound even if backup_prompts
            # is ever emptied (the for...else below would otherwise raise
            # NameError on an empty list).
            prompt = self.default_prompt
            for prompt in self.backup_prompts:
                segments = self._segment_with_prompts(image, prompt)
                if segments:
                    break
            else:
                # Every prompt failed on the raw image: stain-normalise and
                # retry once with the last prompt tried.
                stain = stainnorm.StainNormalizer()
                norm, h, e = stain.process(image)
                segments = self._segment_with_prompts(norm, prompt)
        if artefacts:
            if segments:
                return (segments[0]['mask'] * 255).astype(np.uint8)
            # No segments found for artefact extraction — fall through to
            # the UNet / Otsu fallbacks below.
        elif segments:
            # Union of all segment masks.
            combined_mask = np.zeros_like(segments[0]['mask'], dtype=np.uint8)
            for segment in segments:
                combined_mask = np.maximum(
                    combined_mask, (segment['mask'] * 255).astype(np.uint8)
                )
            return combined_mask
        # Try UNet-based extraction if a model path has been provided.
        if self.unet_model_path:
            return self._unet_mask(image)
        # Final fallback.
        return self._fallback_mask(image)

    @staticmethod
    def _segment_with_prompts(image: np.ndarray, prompt: str):
        """Run Florence-2/SAM2 instance segmentation; return [] on any failure."""
        try:
            return florence2_sam2_instance_segmentation(prompt, image)
        except Exception:
            return []

    def _unet_mask(self, image: np.ndarray) -> np.ndarray:
        """Extract tissue mask using the UNet model, falling back to Otsu on error."""
        try:
            extractor = UNetTissueMaskExtractor(
                model_path=self.unet_model_path,
                device=self.unet_device,
            )
            mask = extractor.extract_masks(image)
            if mask is not None:
                # Scale the UNet's 0/1 output to the documented 0/255 contract.
                return (mask > 0).astype(np.uint8) * 255
        except Exception as exc:
            logger.warning("UNet tissue mask extraction failed: %s", exc)
        return self._fallback_mask(image)

    def _fallback_mask(self, image: np.ndarray) -> np.ndarray:
        """Fallback method using Otsu threshold and morphology (0/255 uint8)."""
        logger.info("Applying fallback tissue mask extraction.")
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        _, threshold_mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # Morphological open/close to remove speckle and fill small gaps.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask = cv2.morphologyEx(threshold_mask, cv2.MORPH_OPEN, kernel, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        mask_binary = (mask > 0).astype(np.uint8)
        # Invert so tissue (darker than the bright background after Otsu)
        # becomes the foreground.
        mask_binary = 1 - mask_binary
        # Keep only the largest connected component.
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_binary, connectivity=8)
        if num_labels <= 1:
            # Scale 0/1 to 0/255 to honour the documented mask contract.
            return mask_binary * 255
        largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
        return (labels == largest_label).astype(np.uint8) * 255
class UNetTissueMaskExtractor:
    """Tissue mask extraction with a pretrained tiatoolbox UNet model."""

    def __init__(self, model_path: str, device: str = "cuda"):
        """
        Args:
            model_path (str): Path to the pretrained UNet checkpoint.
            device (str): 'cuda' or 'cpu'.
        """
        self.device = device
        self.model_path = model_path
        self.model = self._load_model()

    @staticmethod
    def convert_pytorch_checkpoint(net_state_dict):
        """Convert checkpoint from DataParallel to single-GPU format.

        A state dict is treated as DataParallel only when *every* key starts
        with the 'module.' prefix; mixed or empty dicts are returned unchanged
        (all() on an empty sequence is True, so the explicit emptiness guard
        keeps an empty dict out of the rename branch).
        """
        keys = list(net_state_dict.keys())
        if keys and all(k.split(".")[0] == "module" for k in keys):
            # Strip the leading 'module.' from every key.
            net_state_dict = {
                ".".join(k.split(".")[1:]): v for k, v in net_state_dict.items()
            }
        return net_state_dict

    @staticmethod
    def post_processing_mask(mask: np.ndarray) -> np.ndarray:
        """Fill holes and keep only the largest object in the binary mask.

        Returns:
            np.ndarray: uint8 mask with values 0/1.
        """
        mask = ndimage.binary_fill_holes(mask, structure=np.ones((3, 3))).astype(int)
        label_img = measure.label(mask)
        # More than one labelled object (plus background): keep the largest.
        if len(np.unique(label_img)) > 2:
            regions = measure.regionprops(label_img)
            mask = mask.astype(bool)
            all_area = [r.area for r in regions]
            # A size threshold just above the second-largest area removes
            # every object except the largest one.
            second_max = max([a for a in all_area if a != max(all_area)], default=0)
            mask = morphology.remove_small_objects(mask, min_size=second_max + 1)
        return mask.astype(np.uint8)

    def _load_model(self):
        """Load the UNet checkpoint and return the initialised model."""
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints.
        pretrained = torch.load(self.model_path, map_location=self.device)
        pretrained = self.convert_pytorch_checkpoint(pretrained)
        model = UNetModel(num_input_channels=3, num_output_channels=3)
        model.load_state_dict(pretrained)
        return model

    def extract_masks(self, image: np.ndarray) -> np.ndarray:
        """
        Generate a tissue mask for a single image using UNet segmentation.

        Args:
            image (np.ndarray): Input image, either grayscale (H, W) or
                RGB (H, W, 3).

        Returns:
            np.ndarray: Processed binary tissue mask (uint8, values 0/1).
        """
        global_save_dir = "./tmp/"
        save_dir = os.path.join(global_save_dir, 'tissue_mask')
        # Clean up any previous run and create fresh directories
        # (os.makedirs also creates the parent ./tmp/ for image.png).
        if os.path.exists(global_save_dir):
            shutil.rmtree(global_save_dir)
        os.makedirs(save_dir)
        # Ensure a 3-channel input: replicate grayscale images across the
        # channel axis; pass RGB input through unchanged. (Unconditionally
        # applying expand_dims + repeat to an (H, W, 3) image would produce
        # a broken (H, W, 3, 3) array.)
        if image.ndim == 2:
            image_rgb = np.repeat(np.expand_dims(image, axis=2), 3, axis=2)
        else:
            image_rgb = image
        image_path = os.path.join(global_save_dir, 'image.png')
        # cv2.imwrite expects BGR channel order; convert so true-RGB input is
        # written correctly (identity for replicated grayscale channels).
        cv2.imwrite(image_path, cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
        # Create segmentor and predict on the saved tile.
        segmentor = SemanticSegmentor(
            model=self.model,
            pretrained_model="unet_tissue_mask_tsef",
            num_loader_workers=4,
            batch_size=4,
        )
        output = segmentor.predict(
            [image_path],
            save_dir=save_dir,
            mode="tile",
            resolution=1.0,
            units="baseline",
            patch_input_shape=[1024, 1024],
            patch_output_shape=[512, 512],
            stride_shape=[512, 512],
            device=self.device,
            crash_on_exception=True,
        )
        # Load the raw per-class prediction and keep class index 2
        # (presumably the tissue class of the 3-output model — TODO confirm).
        mask = np.load(output[0][1] + ".raw.0.npy")
        mask = np.argmax(mask, axis=-1) == 2
        return self.post_processing_mask(mask)