# Source code for core.utils.mha_wsi

"""
MHA Deformation Field Application for CORE

Apply deformation fields computed at low resolution (e.g., 0.625x) 
to whole slide images at high resolution (e.g., 40x).

Features:
- Uses pyvips for memory-efficient large image handling
- Scales deformation fields from low to high resolution
- Properly scales both spatial dimensions and displacement magnitudes
"""

import numpy as np
import pyvips
import SimpleITK as sitk
from pathlib import Path
from typing import Tuple, Optional, Union
import argparse


def get_image_shape(img: Union[np.ndarray, pyvips.Image]) -> Tuple[int, ...]:
    """Return an image's shape as (rows, cols) or (rows, cols, channels).

    Args:
        img: A numpy array or a pyvips Image.

    Returns:
        For numpy input, ``img.shape`` unchanged. For a pyvips Image,
        ``(height, width)`` when it has a single band, otherwise
        ``(height, width, bands)``.
    """
    if not isinstance(img, pyvips.Image):
        return img.shape
    rows_cols = (img.height, img.width)
    return (rows_cols + (img.bands,)) if img.bands > 1 else rows_cols
def array_to_vips(arr: np.ndarray) -> pyvips.Image:
    """Convert a numpy array to a pyvips Image.

    Args:
        arr: Either a 2D ``(H, W)`` array (one band) or a 3D
            ``(H, W, bands)`` array.

    Returns:
        A pyvips Image built from a C-order byte copy of ``arr``.

    Raises:
        ValueError: If ``arr`` is neither 2D nor 3D.
    """
    dtype_to_format = {
        np.dtype('uint8'): 'uchar',
        np.dtype('int8'): 'char',
        np.dtype('uint16'): 'ushort',
        np.dtype('int16'): 'short',
        np.dtype('uint32'): 'uint',
        np.dtype('int32'): 'int',
        np.dtype('float32'): 'float',
        np.dtype('float64'): 'double',
    }

    if arr.ndim == 2:
        height, width = arr.shape
        bands = 1
    elif arr.ndim == 3:
        height, width, bands = arr.shape
    else:
        raise ValueError(f"Expected a 2D or 3D array, got shape: {arr.shape}")

    vips_format = dtype_to_format.get(arr.dtype)
    if vips_format is None:
        # Dtypes libvips has no direct format for (bool, int64, float16,
        # non-native byte order, ...) must be converted. Only relabelling
        # them as 'float' would hand libvips bytes of the wrong size/layout.
        arr = arr.astype(np.float32)
        vips_format = 'float'

    # tobytes() always emits C-order bytes, whatever the memory layout,
    # so no separate flatten() copy is needed.
    return pyvips.Image.new_from_memory(
        arr.tobytes(), width, height, bands, vips_format
    )
def vips_to_array(vips_img: pyvips.Image) -> np.ndarray:
    """Convert a pyvips Image to a numpy array.

    Args:
        vips_img: Image to render into memory.

    Returns:
        ``(H, W)`` array for single-band images, otherwise
        ``(H, W, bands)``. The array wraps the immutable ``bytes`` returned
        by ``write_to_memory()``, so it is read-only; copy it before
        modifying in place.
    """
    format_map = {
        'uchar': np.uint8,
        'char': np.int8,
        'ushort': np.uint16,
        'short': np.int16,
        'uint': np.uint32,
        'int': np.int32,
        'float': np.float32,
        'double': np.float64,
        # libvips complex formats store (real, imag) pairs per band element.
        'complex': np.complex64,
        'dpcomplex': np.complex128,
    }
    dtype = np.dtype(format_map.get(vips_img.format, np.float32))
    np_arr = np.ndarray(
        buffer=vips_img.write_to_memory(),
        dtype=dtype,
        shape=[vips_img.height, vips_img.width, vips_img.bands],
    )
    if np_arr.shape[2] == 1:
        np_arr = np_arr.squeeze(axis=2)
    return np_arr
def resize_displacement_field(
    displacement: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
    target_shape_rc: Tuple[int, int],
    source_shape_rc: Tuple[int, int],
    registered_shape_rc: Tuple[int, int]
) -> np.ndarray:
    """Resize a displacement field to a target resolution.

    Steps:
        1. Resample each component spatially (pyvips ``resize``) to exactly
           ``target_shape_rc``.
        2. Scale the displacement values by the ratio
           ``target_shape_rc / registered_shape_rc`` (x by the column ratio,
           y by the row ratio).

    Args:
        displacement: Displacement field as a (2, H, W) array or a tuple of
            (dx, dy) arrays.
        target_shape_rc: Target output shape (rows, cols).
        source_shape_rc: Shape of the image where the transform was
            computed. Not used in the computation; kept for interface
            compatibility.
        registered_shape_rc: Shape of the registered image at the
            transform resolution.

    Returns:
        Resized displacement field as a (2, rows, cols) array.
    """
    # Convert tuple to array if needed
    if isinstance(displacement, tuple):
        displacement = np.array([displacement[0], displacement[1]])

    target_rows, target_cols = int(target_shape_rc[0]), int(target_shape_rc[1])

    # Magnitude scaling factors (rows, cols)
    scale_factor_rc = np.array(target_shape_rc) / np.array(registered_shape_rc)

    disp_x = displacement[0]
    disp_y = displacement[1]

    # Spatial resampling factors relative to the field's own grid
    scale_x = target_shape_rc[1] / disp_x.shape[1]
    scale_y = target_shape_rc[0] / disp_x.shape[0]

    resized_components = []
    for component in (disp_x, disp_y):
        vips_comp = array_to_vips(component.astype(np.float32))
        resized = vips_comp.resize(scale_x, vscale=scale_y)
        # libvips picks the output size by its own rounding, which can land
        # a pixel away from the target. Force the exact shape so the field
        # lines up element-wise with full-resolution grids. Missing pixels
        # replicate the edge value.
        if resized.width != target_cols or resized.height != target_rows:
            resized = resized.crop(
                0, 0,
                min(resized.width, target_cols),
                min(resized.height, target_rows),
            )
            resized = resized.embed(0, 0, target_cols, target_rows, extend='copy')
        resized_components.append(vips_to_array(resized))

    # Scale the displacement values
    scaled_disp_x = resized_components[0] * scale_factor_rc[1]  # x by column scale
    scaled_disp_y = resized_components[1] * scale_factor_rc[0]  # y by row scale

    return np.array([scaled_disp_x, scaled_disp_y])
def apply_displacement_field(
    img: Union[np.ndarray, pyvips.Image],
    displacement: np.ndarray,
    output_shape_rc: Tuple[int, int],
    source_shape_rc: Tuple[int, int],
    registered_shape_rc: Tuple[int, int],
    background: Optional[Union[int, float, list]] = None,
    interpolation: str = "bicubic"
) -> pyvips.Image:
    """Warp an image with a (backwards) displacement field.

    Output pixel (x, y) samples the input at ``(x + dx, y + dy)``, where
    (dx, dy) is the displacement after it has been upsampled to
    ``output_shape_rc`` and rescaled by ``output_shape_rc /
    registered_shape_rc``.

    The coordinate map is built as a lazy pyvips pipeline (resize ->
    scale -> add ``xyz`` grid). Full-resolution coordinate arrays are never
    materialised in numpy, which would need several float32 images the size
    of the output (prohibitive at 40x WSI scale).

    Args:
        img: Input image to warp (numpy array or pyvips Image).
        displacement: Backwards displacement field (2, H, W), or a
            (dx, dy) tuple.
        output_shape_rc: Output shape (rows, cols).
        source_shape_rc: Shape where the transform was computed. Not used in
            the computation; kept for interface compatibility.
        registered_shape_rc: Shape of the registered image at transform
            resolution. Displacement magnitudes are scaled relative to it.
        background: Background value(s) for samples outside the image. A
            scalar is repeated for every band; ``None`` means 0 for every
            band.
        interpolation: 'nearest', 'bilinear' or 'bicubic'. Unknown values
            fall back to 'bicubic'.

    Returns:
        Warped pyvips Image of size ``output_shape_rc``. It is lazy: pixels
        are computed when written out.
    """
    vips_img = array_to_vips(img) if isinstance(img, np.ndarray) else img

    n_channels = vips_img.bands
    if background is None:
        background = [0] * n_channels
    elif isinstance(background, (int, float)):
        background = [background] * n_channels

    rows, cols = int(output_shape_rc[0]), int(output_shape_rc[1])

    # Pack (dx, dy) into one small 2-band float image at the field's grid.
    disp_x = np.asarray(displacement[0], dtype=np.float32)
    disp_y = np.asarray(displacement[1], dtype=np.float32)
    field = array_to_vips(np.stack([disp_x, disp_y], axis=2))

    # Spatially resample to the output grid; force the exact size because
    # libvips chooses the resized dimensions by its own rounding.
    field = field.resize(cols / disp_x.shape[1], vscale=rows / disp_x.shape[0])
    if field.width != cols or field.height != rows:
        field = field.crop(0, 0, min(field.width, cols), min(field.height, rows))
        field = field.embed(0, 0, cols, rows, extend='copy')

    # Displacements are measured in registered-image pixels; rescale them
    # per band: band 0 (x) by the column ratio, band 1 (y) by the row ratio.
    scale_rows = rows / registered_shape_rc[0]
    scale_cols = cols / registered_shape_rc[1]
    field = field * [scale_cols, scale_rows]

    # xyz() yields band 0 = x (column) and band 1 = y (row) for every output
    # pixel; adding the field gives the source coordinates mapim expects.
    coord_map = (pyvips.Image.xyz(cols, rows) + field).cast('float')

    if interpolation not in ('nearest', 'bilinear', 'bicubic'):
        interpolation = 'bicubic'
    interpolator = pyvips.Interpolate.new(interpolation)

    return vips_img.mapim(coord_map, interpolate=interpolator, background=background)
def load_mha_displacement(mha_path: str) -> Tuple[np.ndarray, dict]:
    """Read an MHA deformation field into a (2, H, W) float32 array.

    The returned array is ordered as:
        - ``displacement[0]``: dx (column / x displacement)
        - ``displacement[1]``: dy (row / y displacement)

    Both an (H, W, 2) array layout (vector pixels) and an already stacked
    (2, H, W) layout are accepted.

    NOTE(review): values are returned as stored. Spacing, origin and
    direction are only reported in the metadata, not applied. This matches
    ``save_displacement_as_mha``, which writes pixel displacements
    regardless of spacing. Confirm that external producers also store pixel
    units.

    Args:
        mha_path: Path to the MHA file.

    Returns:
        ``(displacement, metadata)`` where ``metadata`` holds the image's
        'size', 'spacing', 'origin' and 'direction'.

    Raises:
        ValueError: If the array is not 3D, or is 3D but has no axis of
            length 2 in the first or last position.
    """
    print(f"Loading deformation field: {mha_path}")
    field_img = sitk.ReadImage(mha_path)

    metadata = dict(
        size=field_img.GetSize(),
        spacing=field_img.GetSpacing(),
        origin=field_img.GetOrigin(),
        direction=field_img.GetDirection(),
    )
    print(f" Size: {metadata['size']}")
    print(f" Spacing: {metadata['spacing']}")

    raw = sitk.GetArrayFromImage(field_img)

    if raw.ndim != 3:
        raise ValueError(f"Expected 3D array, got shape: {raw.shape}")

    if raw.shape[2] == 2:
        # Channels-last (H, W, 2)
        components = (raw[:, :, 0], raw[:, :, 1])
    elif raw.shape[0] == 2:
        # Already channels-first (2, H, W)
        components = (raw[0], raw[1])
    else:
        raise ValueError(f"Unexpected deformation field shape: {raw.shape}")

    displacement = np.array(components, dtype=np.float32)
    print(f" Displacement shape: {displacement.shape}")
    return displacement, metadata
def _save_warped_image(warped: pyvips.Image, output_path: str) -> None:
    """Write ``warped`` to ``output_path`` with a writer chosen by file extension."""
    path = Path(output_path)
    # Ensure output directory exists
    path.parent.mkdir(parents=True, exist_ok=True)

    ext = path.suffix.lower()
    name = path.name.lower()
    if ext == '.mha':
        # Convert to numpy and save with SimpleITK (needs the image in memory).
        warped_arr = vips_to_array(warped)
        # A multi-band (H, W, bands) array must be flagged as a 2D vector
        # image; otherwise SimpleITK builds a 3D volume with `bands` slices.
        sitk_img = sitk.GetImageFromArray(warped_arr, isVector=warped_arr.ndim == 3)
        sitk.WriteImage(sitk_img, output_path, useCompression=True)
    elif name.endswith(('.ome.tiff', '.ome.tif')):
        # Save as OME-TIFF (pyramidal). Checked before plain .tiff because
        # its suffix is also .tiff/.tif.
        warped.write_to_file(
            output_path,
            tile=True,
            pyramid=True,
            compression='lzw',
            tile_width=256,
            tile_height=256,
        )
    elif ext in ('.tiff', '.tif'):
        # Save as TIFF with pyvips
        warped.write_to_file(output_path, compression='lzw')
    else:
        # Default: let pyvips pick the saver from the extension
        warped.write_to_file(output_path)


def register_wsi_with_mha(
    mha_path: str,
    wsi_path: str,
    output_path: str,
    source_magnification: float = 0.625,
    target_magnification: float = 40.0,
    interpolation: str = "bicubic",
    background: int = 0
):
    """Apply an MHA deformation field to a WSI and save the registered result.

    The warp derives all scaling from the actual field and WSI dimensions.
    The magnification arguments are only used to print the nominal scale
    factor.

    Args:
        mha_path: Path to the deformation field MHA file.
        wsi_path: Path to the WSI to be warped.
        output_path: Output path. The extension selects the writer: ``.mha``
            (SimpleITK), ``.ome.tif(f)`` (tiled pyramid), ``.tif(f)`` (LZW)
            or any other pyvips-supported format.
        source_magnification: Magnification at which the deformation was
            computed (informational).
        target_magnification: Target WSI magnification (informational).
        interpolation: Interpolation method ('nearest', 'bilinear',
            'bicubic').
        background: Background value for samples mapped outside the WSI.
    """
    print("=" * 60)
    print("WSI Registration with MHA Deformation Field")
    print("=" * 60)

    # Calculate scale factor (reported only)
    scale_factor = target_magnification / source_magnification
    print(f"Scale factor: {scale_factor}x ({source_magnification}x -> {target_magnification}x)")
    print()

    # Load deformation field
    displacement, _mha_metadata = load_mha_displacement(mha_path)
    source_shape_rc = (displacement.shape[1], displacement.shape[2])
    registered_shape_rc = source_shape_rc
    print()

    # mapim samples the source at arbitrary displaced positions, so the WSI
    # must be opened for random access. 'sequential' access only supports
    # top-to-bottom reads with limited look-back, which a warp cannot honour.
    print(f"Loading WSI: {wsi_path}")
    vips_img = pyvips.Image.new_from_file(wsi_path, access='random')
    print(f" Size: {vips_img.width} x {vips_img.height}")
    print(f" Bands: {vips_img.bands}")
    print()

    # Output keeps the WSI's full-resolution size
    output_shape_rc = (vips_img.height, vips_img.width)

    # Apply deformation
    print("Applying deformation field...")
    warped = apply_displacement_field(
        img=vips_img,
        displacement=displacement,
        output_shape_rc=output_shape_rc,
        source_shape_rc=source_shape_rc,
        registered_shape_rc=registered_shape_rc,
        background=background,
        interpolation=interpolation
    )
    print(" Warping complete")
    print()

    # Save result
    print(f"Saving to: {output_path}")
    _save_warped_image(warped, output_path)
    print(" Save complete")
    print()

    print("=" * 60)
    print("Registration complete!")
    print("=" * 60)
def save_displacement_as_mha(
    displacement: np.ndarray,
    output_path: str,
    spacing: Tuple[float, float] = (1.0, 1.0),
    origin: Tuple[float, float] = (0.0, 0.0)
):
    """Write a (2, H, W) displacement field to disk as a compressed MHA.

    The components are stored as a 2-component float64 vector image.
    ``displacement[0]`` (dx) is the first component and ``displacement[1]``
    (dy) the second. ``spacing`` and ``origin`` are attached as image
    metadata only; the displacement values themselves are not rescaled.

    Args:
        displacement: Displacement field shaped (2, H, W).
        output_path: Destination file path.
        spacing: Pixel spacing recorded in the file.
        origin: Image origin recorded in the file.
    """
    # SimpleITK vector images are channels-last: (2, H, W) -> (H, W, 2)
    channels_last = np.stack(
        (displacement[0], displacement[1]), axis=-1
    ).astype(np.float64)

    vector_img = sitk.GetImageFromArray(channels_last, isVector=True)
    vector_img.SetSpacing(spacing)
    vector_img.SetOrigin(origin)

    sitk.WriteImage(vector_img, output_path, useCompression=True)
    print(f"Saved displacement field to: {output_path}")
def main(argv=None):
    """Command-line entry point: parse arguments and run ``register_wsi_with_mha``.

    Args:
        argv: Optional argument list. ``None`` means ``sys.argv[1:]``, which
            is argparse's default.
    """
    parser = argparse.ArgumentParser(
        description="Apply MHA deformation field to WSI for registration"
    )
    parser.add_argument(
        "--mha", "-m", required=True,
        help="Path to MHA deformation field"
    )
    parser.add_argument(
        "--wsi", "-w", required=True,
        help="Path to WSI to be registered"
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Output path"
    )
    parser.add_argument(
        "--source-mag", type=float, default=0.625,
        help="Source magnification (default: 0.625)"
    )
    parser.add_argument(
        "--target-mag", type=float, default=40.0,
        help="Target magnification (default: 40.0)"
    )
    parser.add_argument(
        "--interp", default="bicubic",
        choices=["nearest", "bilinear", "bicubic"],
        help="Interpolation method"
    )
    # Exposes register_wsi_with_mha's `background` parameter on the CLI.
    parser.add_argument(
        "--background", type=int, default=0,
        help="Background value for pixels mapped outside the WSI (default: 0)"
    )

    args = parser.parse_args(argv)

    register_wsi_with_mha(
        mha_path=args.mha,
        wsi_path=args.wsi,
        output_path=args.output,
        source_magnification=args.source_mag,
        target_magnification=args.target_mag,
        interpolation=args.interp,
        background=args.background
    )


if __name__ == "__main__":
    main()