import functools
import logging
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from typing import Optional, Tuple, List, Union, Dict, Any
import numpy as np
import math
import cv2
from skimage import color
from pathlib import Path
from tqdm.auto import tqdm
from core.config import PREPROCESSING_RESOLUTION, REGISTRATION_RESOLUTION
logger = logging.getLogger(__name__)
# Scale factor relating coarse (preprocessing) resolution to fine (registration) resolution.
# Used to convert nuclei coordinates between the two resolution spaces.
_RESOLUTION_SCALE: float = PREPROCESSING_RESOLUTION / REGISTRATION_RESOLUTION
@functools.lru_cache(maxsize=32)
def _make_gaussian_blur(kernel_width: int, blur_sigma: float) -> transforms.GaussianBlur:
    """Create (or fetch from the LRU cache) a ``GaussianBlur`` transform.

    Caching avoids rebuilding the blur kernel for repeated
    ``(kernel_width, blur_sigma)`` pairs.
    """
    blur = transforms.GaussianBlur(kernel_width, blur_sigma)
    return blur
def gaussian_smoothing(input_tensor: torch.Tensor, blur_sigma: float) -> torch.Tensor:
    """Blur ``input_tensor`` with a Gaussian of standard deviation ``blur_sigma``.

    The kernel width is derived from the sigma and forced to be odd (a
    GaussianBlur requirement). Transform objects are cached, so repeated
    calls with the same parameters reuse the same kernel.
    """
    # ``int(...) | 1`` bumps an even width up by one, i.e. "add one if even".
    kernel_size = (int(blur_sigma * 2.54) + 1) | 1
    with torch.no_grad():
        return _make_gaussian_blur(kernel_size, blur_sigma)(input_tensor)
def build_reference_coordinate_system(input_tensor: Optional[torch.Tensor] = None,
                                      dimensions: Optional[torch.Size] = None,
                                      compute_device: Optional[Union[str, torch.device]] = None) -> torch.Tensor:
    """Build an identity (reference) sampling grid for the given dimensions.

    Args:
        input_tensor: If given, its size (and, when no device is specified,
            its dtype/device) defines the grid. Takes precedence over
            ``dimensions``.
        dimensions: Tensor size ``(N, C, H, W)`` (or the 3-D equivalent) used
            when ``input_tensor`` is not supplied.
        compute_device: Optional device (name or ``torch.device``) for the grid.

    Returns:
        The coordinate grid produced by ``F.affine_grid`` for an identity
        affine transform, one per batch element.

    Raises:
        ValueError: If neither ``input_tensor`` nor ``dimensions`` is given
            (previously this crashed with an opaque ``TypeError``).
    """
    if input_tensor is not None:
        dimensions = input_tensor.size()
    if dimensions is None:
        raise ValueError("Either input_tensor or dimensions must be provided.")
    # Convert string device specification to torch.device
    if isinstance(compute_device, str):
        compute_device = torch.device(compute_device)
    if compute_device is None and input_tensor is not None:
        # Inherit dtype/device from the input tensor.
        base_transform = torch.eye(len(dimensions) - 1)[:-1, :].unsqueeze(0).type_as(input_tensor)
    else:
        base_transform = torch.eye(len(dimensions) - 1, device=compute_device)[:-1, :].unsqueeze(0)
    # One identity transform per batch element.
    base_transform = torch.repeat_interleave(base_transform, dimensions[0], dim=0)
    coordinate_grid = F.affine_grid(base_transform, dimensions, align_corners=False)
    return coordinate_grid
def compute_smoothness_regularization(vector_field: torch.Tensor,
                                      compute_device: torch.device = None,
                                      weight_map: Optional[torch.Tensor] = None) -> torch.Tensor:
    """First-order (diffusion) smoothness penalty for a 2-D displacement field.

    Finite differences along each spatial axis are scaled by that axis'
    extent, squared, and averaged.

    Args:
        vector_field: Field of shape ``(N, H, W, C)`` — only 2-D fields
            (4-D tensors) are implemented.
        compute_device: Unused; kept for call-site compatibility.
        weight_map: Optional per-pixel ``(N, H, W)`` weights applied to the
            squared gradients before averaging.

    Returns:
        Scalar smoothness term.

    Raises:
        ValueError: For any dimensionality other than 2-D. (The previous
            message claimed 3-D was supported; it never was.)
    """
    dim_count = len(vector_field.size()) - 2
    if dim_count == 2:
        x_grad = ((vector_field[:, 1:, :, :] - vector_field[:, :-1, :, :]) *
                  vector_field.shape[1])**2
        y_grad = ((vector_field[:, :, 1:, :] - vector_field[:, :, :-1, :]) *
                  vector_field.shape[2])**2
        if weight_map is not None:
            # Apply spatial weighting if provided
            x_weight = weight_map[:, 1:, :].unsqueeze(-1)
            y_weight = weight_map[:, :, 1:].unsqueeze(-1)
            smoothness_term = (torch.mean(x_grad * x_weight) + torch.mean(y_grad * y_weight)) / 2
        else:
            smoothness_term = (torch.mean(x_grad) + torch.mean(y_grad)) / 2
    else:
        raise ValueError("Unsupported dimensionality: only 2D vector fields are supported.")
    return smoothness_term
def scale_tensor_to_dimensions(input_tensor: torch.Tensor,
                               target_dimensions: torch.Size,
                               interpolation_method: str = 'bilinear') -> torch.Tensor:
    """Resize ``input_tensor`` to the given spatial dimensions.

    Args:
        input_tensor: Tensor of shape ``(N, C, *spatial)``.
        target_dimensions: Desired spatial output size.
        interpolation_method: Any mode accepted by ``F.interpolate``.

    Returns:
        The resized tensor.
    """
    # ``align_corners`` is only legal for the interpolating modes; passing it
    # with 'nearest' or 'area' makes F.interpolate raise a ValueError.
    if interpolation_method in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        return F.interpolate(input_tensor, size=target_dimensions,
                             mode=interpolation_method, align_corners=False)
    return F.interpolate(input_tensor, size=target_dimensions, mode=interpolation_method)
def compute_normalized_cross_correlation(sources: torch.Tensor,
targets: torch.Tensor,
device: Optional[Union[str, torch.device]] = None,
**config_params) -> torch.Tensor:
ndim = len(sources.size()) - 2
if ndim not in [2, 3]:
raise ValueError("Unsupported number of dimensions.")
size = config_params.get('size', 7)
window = (size, ) * ndim
if device is None:
sum_filt = torch.ones([1, 1, *window]).type_as(sources)
else:
sum_filt = torch.ones([1, 1, *window], device=device)
pad_no = math.floor(window[0] / 2)
stride = ndim * (1,)
padding = ndim * (pad_no,)
conv_fn = getattr(F, 'conv%dd' % ndim)
sources_denom = sources**2
targets_denom = targets**2
numerator = sources*targets
sources_sum = conv_fn(sources, sum_filt, stride=stride, padding=padding)
targets_sum = conv_fn(targets, sum_filt, stride=stride, padding=padding)
sources_denom_sum = conv_fn(sources_denom, sum_filt, stride=stride, padding=padding)
targets_denom_sum = conv_fn(targets_denom, sum_filt, stride=stride, padding=padding)
numerator_sum = conv_fn(numerator, sum_filt, stride=stride, padding=padding)
size = np.prod(window)
u_sources = sources_sum / size
u_targets = targets_sum / size
cross = numerator_sum - u_targets * sources_sum - u_sources * targets_sum + u_sources * u_targets * size
sources_var = sources_denom_sum - 2 * u_sources * sources_sum + u_sources * u_sources * size
targets_var = targets_denom_sum - 2 * u_targets * targets_sum + u_targets * u_targets * size
ncc = cross * cross / (sources_var * targets_var + 1e-5)
return -torch.mean(ncc)
def create_multiscale_representation(input_tensor: torch.Tensor,
                                     level_count: int,
                                     interpolation_method: str = 'bilinear',
                                     scale_factor: float = 2.0,
                                     smoothing_sigma: float = 1) -> List[torch.Tensor]:
    """Build a multi-scale image pyramid (coarsest at index 0).

    Each level is produced by Gaussian-smoothing the next-finer level (to
    prevent aliasing) and downsampling its spatial axes by ``scale_factor``.

    Args:
        input_tensor: Finest-resolution tensor of shape ``(N, C, *spatial)``.
        level_count: Number of pyramid levels; the last holds ``input_tensor``.
        interpolation_method: Interpolation mode used for downsampling.
        scale_factor: Per-level spatial reduction factor.
        smoothing_sigma: Sigma of the anti-aliasing blur (default 1, the
            previously hard-coded value).

    Returns:
        List of tensors ordered coarse-to-fine.
    """
    pyramid_levels: List[Optional[torch.Tensor]] = [None] * level_count
    # Build from fine to coarse.
    for i in range(level_count - 1, -1, -1):
        if i == level_count - 1:
            # Original resolution.
            pyramid_levels[i] = input_tensor
        else:
            prev_size = pyramid_levels[i + 1].size()
            # Shrink spatial axes only (indices >= 2); clamp to >= 1 so deep
            # pyramids on small inputs never request a zero-sized level.
            spatial_dims = torch.Size(
                tuple(max(1, int(prev_size[j] / scale_factor)) for j in range(2, len(prev_size)))
            )
            # Smooth to prevent aliasing, then downsample.
            smoothed = gaussian_smoothing(pyramid_levels[i + 1], smoothing_sigma)
            pyramid_levels[i] = scale_tensor_to_dimensions(smoothed, spatial_dims,
                                                           interpolation_method)
    return pyramid_levels
def convert_image_to_tensor(img_array: np.ndarray, compute_device: Union[str, torch.device] = "cpu") -> torch.Tensor:
    """Convert a numpy image to a batched PyTorch tensor.

    Args:
        img_array: ``(H, W)`` grayscale or ``(H, W, C)`` color image.
        compute_device: Target device (name or ``torch.device``).

    Returns:
        Tensor of shape ``(1, 1, H, W)`` for grayscale input or
        ``(1, C, H, W)`` for color input, with intensities in ``[0, 1]``
        for integer inputs.

    Raises:
        ValueError: For arrays that are neither 2-D nor 3-D.
    """
    # Convert string device specification to torch.device
    if isinstance(compute_device, str):
        compute_device = torch.device(compute_device)
    if img_array.dtype != np.float32 and img_array.dtype != np.float64:
        if img_array.max() > 1.0:
            # Assume 8-bit style intensity range and rescale into [0, 1].
            img_array = img_array.astype(np.float32) / 255.0
        else:
            # Fix: integer arrays already within [0, 1] (e.g. binary masks)
            # were previously returned as integer tensors; cast to float.
            img_array = img_array.astype(np.float32)
    if img_array.ndim == 3:
        # Color image: HWC -> NCHW.
        return torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(compute_device)
    if img_array.ndim == 2:
        # Grayscale image: HW -> N1HW.
        return torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(compute_device)
    raise ValueError(f"Unsupported image dimensions: {img_array.shape}")
def prepare_image_tensors(source_image: np.ndarray,
                          target_image: np.ndarray,
                          compute_device: Union[str, torch.device],
                          normalize: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a source/target image pair into float32 tensors with gradients.

    3-channel RGB inputs are collapsed to grayscale; optionally each image is
    min-max normalized into [0, 1] before tensor conversion.
    """
    if isinstance(compute_device, str):
        compute_device = torch.device(compute_device)

    def _to_gray(img: np.ndarray) -> np.ndarray:
        # Collapse 3-channel RGB to a single luminance channel.
        if img.ndim == 3 and img.shape[2] == 3:
            return color.rgb2gray(img)
        return img

    def _min_max(img: np.ndarray) -> np.ndarray:
        # Epsilon keeps constant images from dividing by zero.
        return (img - img.min()) / (img.max() - img.min() + 1e-10)

    gray_source = _to_gray(source_image)
    gray_target = _to_gray(target_image)
    if normalize:
        gray_source = _min_max(gray_source)
        gray_target = _min_max(gray_target)
    # Convert, then detach into fresh float32 leaves that track gradients.
    source_tensor = (convert_image_to_tensor(gray_source, compute_device)
                     .clone().detach()
                     .to(dtype=torch.float32, device=compute_device)
                     .requires_grad_(True))
    target_tensor = (convert_image_to_tensor(gray_target, compute_device)
                     .clone().detach()
                     .to(dtype=torch.float32, device=compute_device)
                     .requires_grad_(True))
    return source_tensor, target_tensor
def elastic_image_registration(
    source: np.ndarray,
    target: np.ndarray,
    similarity_metric: str = "ncc",
    similarity_metric_params: Optional[Dict[str, Any]] = None,
    compute_device: Union[str, torch.device] = "cuda",
    verbose: bool = False,
    output_dir: Optional[Union[str, Path]] = None,
    save_intermediate: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Coarse-to-fine elastic (deformable) registration of ``source`` onto ``target``.

    Builds 6-level image pyramids and optimizes a per-level deformation field
    (Adam on coarse levels, LBFGS on the finest) against a local-NCC loss plus
    a smoothness regularizer.

    Args:
        source: Moving image (numpy array).
        target: Fixed image (numpy array); its shape defines the output grid.
        similarity_metric: Currently informational only — NCC is always used.
        similarity_metric_params: Extra kwargs for the NCC loss; defaults to
            ``{"size": 7}``. (Fix: was a shared mutable default dict.)
        compute_device: Device name or ``torch.device``.
        verbose: Show per-level progress bars.
        output_dir: Where to write the final warped image when
            ``save_intermediate`` is set.
        save_intermediate: Save the warped result as a PNG.

    Returns:
        ``(deformation_field, warped_source)`` at the target's resolution.

    NOTE(review): ``initialize_deformation_field``, ``scale_deformation_field``,
    ``apply_deformation_field`` and ``deformation_loss`` are defined elsewhere
    in this module; the field is assumed to be shaped ``(N, H, W, 2)`` based on
    the ``shape[1:3]`` comparison below — confirm against those helpers.
    """
    if similarity_metric_params is None:
        similarity_metric_params = {"size": 7}
    device = torch.device(compute_device) if isinstance(compute_device, str) else compute_device
    # Resize source to target's dimensions (reflect border) before building pyramids.
    aligned_source = cv2.warpAffine(source, np.eye(2, 3), (target.shape[1], target.shape[0]),
                                    borderMode=cv2.BORDER_REFLECT)
    source_t, target_t = prepare_image_tensors(aligned_source, target, device)
    pyramid_levels = 6
    src_pyr = create_multiscale_representation(source_t, pyramid_levels)
    tgt_pyr = create_multiscale_representation(target_t, pyramid_levels)
    # Per-level schedules, coarse (index 0) to fine.
    iterations_per_level = [200, 200, 150, 100, 100, 80]
    learning_rates = [0.01, 0.005, 0.002, 0.002, 0.001, 0.0005]
    regularization_weights = [10.0, 1.5, 1.2, 1.0, 0.8, 0.4]
    prev_def_field = None
    for lvl in range(pyramid_levels):
        curr_src = src_pyr[lvl]
        curr_tgt = tgt_pyr[lvl]
        H, W = curr_src.shape[2:]
        # Initialize at the coarsest level; otherwise upsample the previous field.
        if lvl == 0:
            def_field = initialize_deformation_field(curr_src).detach().clone().requires_grad_(True)
        else:
            def_field = scale_deformation_field(prev_def_field, (H, W)).detach().clone().requires_grad_(True)
        # Optimizer: LBFGS on the final (finest) level, Adam otherwise.
        if lvl == pyramid_levels - 1:
            optimizer = optim.LBFGS([def_field], lr=learning_rates[lvl], max_iter=50,
                                    line_search_fn="strong_wolfe")
        else:
            optimizer = optim.Adam([def_field], lr=learning_rates[lvl])
        weight = regularization_weights[lvl]
        for _ in tqdm(range(iterations_per_level[lvl]), disable=not verbose,
                      desc=f"Level {lvl}/{pyramid_levels-1}"):
            def closure():
                optimizer.zero_grad()
                warped = apply_deformation_field(curr_src, def_field, compute_device=device)
                # Fix: the NCC helper's keyword is ``device``; the old
                # ``compute_device=`` spelling fell silently into
                # ``**config_params`` and the device request was ignored.
                sim_loss = compute_normalized_cross_correlation(warped, curr_tgt, device=device,
                                                                **similarity_metric_params)
                reg_loss = deformation_loss(def_field, compute_device=device)
                loss = sim_loss + weight * reg_loss
                loss.backward()
                return loss
            optimizer.step(closure)
            with torch.no_grad():
                # Clip displacements (in pixels) to limit folding of the field.
                def_field.clamp_(-5.0, 5.0)
        prev_def_field = def_field
    # Upsample the final field to the original resolution if needed.
    full_shape = (source_t.size(2), source_t.size(3))
    if tuple(prev_def_field.shape[1:3]) != full_shape:
        final_def = scale_deformation_field(prev_def_field, full_shape)
    else:
        final_def = prev_def_field
    final_warped = apply_deformation_field(source_t, final_def, compute_device=device)
    if save_intermediate and output_dir:
        os.makedirs(output_dir, exist_ok=True)
        warped_np = (final_warped.detach().cpu().numpy()[0, 0] * 255).astype(np.uint8)
        cv2.imwrite(os.path.join(output_dir, "final_warped.png"), warped_np)
        logger.info("Saved intermediate warped image to %s", output_dir)
    return final_def, final_warped