Source code for core.registration.cpd


import numpy as np
import matplotlib.pyplot as plt

from scipy.spatial.distance import cdist
from matplotlib import rc
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

rc('animation', html='html5')


[docs] class CPD: ''' Class implementing the Coherent Point Drift (CPD) method for rigid, affine and non-rigid point set registration. '''
[docs] def __init__(self, method='rigid'): ''' Initialize CPD class with rigid, affine or non-rigid point set registration. Parameters ---------- method: str, optional (default: 'rigid') 'rigid', 'affine' or 'nonrigid'. ''' self.method = method if self.method == 'rigid': self.solve = self.solve_rigid # Y: np.ndarray of shape = (M, D), R: np.ndarray of shape = (D, D), t: np.ndarray of shape = (D,), # s: float # Returns np.ndarray of shape = (M, D) self.transform = lambda Y, R, t, s: s * Y @ R.T + t elif self.method == 'affine': self.solve = self.solve_affine # Y: np.ndarray of shape = (M, D), B: np.ndarray of shape = (D, D), t: np.ndarray of shape = (D,) # Returns np.ndarray of shape = (M, D) self.transform = lambda Y, B, t: Y @ B.T + t elif self.method == 'nonrigid': # Y: np.ndarray of shape = (M, D), G: np.ndarray of shape = (M, M), W: np.ndarray of shape = (M, D) # Returns np.ndarray of shape = (M, D) self.transform = lambda Y, G, W: Y + G @ W
def __call__(self, X, Y, w=0.1, beta=2, lmbda=2, max_iterations=100, save_parameters=False): ''' Applies the CPD method to the data in Y to match X. Parameters ---------- X: np.ndarray, shape = (N, D) Set containing N D-dimensional data points. Y: np.ndarray, shape = (M, D) Set containing M D-dimensional data points. w: float, optional (default: 0.1) Assumption on the amount of noise in the point sets. beta: float, optional (default: 2) Model of smoothness regularizer (width of Gaussian filter in equation 20). lmbda: float, optional (default: 2) Trade-off between goodness of maximum likelihood fit and regularization. max_iterations: int, optional (default: 2) Maximum number of iterations for the EM optimization. save_parameters: bool, optional (default: False) Whether to save the history of estimated transformation parameters. Returns ------- Y_aligned: np.ndarray, shape = (M, D) Aligned Y points. P: np.ndarray, shape = (M, N) Probabilities for point assignments from Y to X. ''' # Get shapes N, D = X.shape M = Y.shape[0] if self.method == 'rigid': # Initialize theta = (R, t, s) theta = (np.eye(D), np.zeros(D), 1) elif self.method == 'affine': # Initialize theta = (B, t) theta = (np.eye(D), np.zeros(D)) elif self.method == 'nonrigid': # Initialize W W = np.zeros((M, D)) # Construct G G = np.exp(-1 / (2 * beta**2) * cdist(Y, Y, metric='euclidean')**2) theta = (G, W) if save_parameters: self.parameters = [theta] # Compute isotropic covariances var = 1 / (D * N * M) * np.sum(cdist(X, Y, metric='euclidean')**2) # EM optimization for iteration in range(max_iterations): # E-step: Compute P (fully vectorized — avoids Python loop over N) T = self.transform(Y, *theta) # (M, D) diff = X[np.newaxis, :, :] - T[:, np.newaxis, :] # (M, N, D) P = np.exp(-1 / (2 * var) * np.sum(diff ** 2, axis=2)) # (M, N) P /= np.sum(P, axis=0) + (2 * np.pi * var)**(D / 2) * w / (1 - w) * M / N if self.method == 'nonrigid': # M-step: Solve for W, var W, var = self.solve_nonrigid(X, Y, P, G, var, lmbda) theta = (G, W) else: # M-step: Solve for (R, t, s), var or (B, t), var theta, var = self.solve(X, Y, P) if save_parameters: self.parameters.append(theta) if iteration > 0 and np.allclose(P, P_old, rtol=1e-5, atol=1e-4): break else: P_old = np.copy(P) print(f'Finished after {iteration + 1} iterations.') return self.transform(Y, *theta), P
[docs] def solve_rigid(self, X, Y, P): ''' Solves for the rigid point set registration. Parameters ---------- X: np.ndarray, shape = (N, D) Set containing N D-dimensional data points. Y: np.ndarray, shape = (M, D) Set containing M D-dimensional data points. P: np.ndarray, shape = (M, N) Probabilities for point assignments from Y to X. Returns ------- theta: tuple, length = 3 Rotation matrix R: np.ndarray of shape = (D, D) Translation vector t: np.ndarray of shape = (D,) Scale s: float var: float Isotropic covariances. ''' N_P = np.sum(P) mu_x = 1 / N_P * X.T @ np.sum(P, axis=0) mu_y = 1 / N_P * Y.T @ np.sum(P, axis=1) X_hat = X - mu_x Y_hat = Y - mu_y A = X_hat.T @ P.T @ Y_hat # Compute SVD of A U, _, V_T = np.linalg.svd(A) D = X.shape[1] C = np.diag(np.concatenate([np.ones(D-1), np.linalg.det(U @ V_T).reshape(1)])) R = U @ C @ V_T s = np.trace(A.T @ R) / np.trace(Y_hat.T @ np.diag(np.sum(P, axis=1)) @ Y_hat) t = mu_x - s * R @ mu_y var = 1 / (N_P * D) * (np.trace(X_hat.T @ np.diag(np.sum(P, axis=0)) @ X_hat) - s * np.trace(A.T @ R)) return (R, t, s), var
[docs] def solve_affine(self, X, Y, P): ''' Solves for the affine point set registration. Parameters ---------- X: np.ndarray, shape = (N, D) Set containing N D-dimensional data points. Y: np.ndarray, shape = (M, D) Set containing M D-dimensional data points. P: np.ndarray, shape = (M, N) Probabilities for point assignments from Y to X. Returns ------- theta: tuple, length = 2 Affine transformation matrix B: np.ndarray of shape = (D, D) Translation vector t: np.ndarray of shape = (D,) var: float Isotropic covariances. ''' N_P = np.sum(P) mu_x = 1 / N_P * X.T @ np.sum(P, axis=0) mu_y = 1 / N_P * Y.T @ np.sum(P, axis=1) X_hat = X - mu_x Y_hat = Y - mu_y B = (X_hat.T @ P.T @ Y_hat) @ np.linalg.inv(Y_hat.T @ np.diag(np.sum(P, axis=1)) @ Y_hat) t = mu_x - B @ mu_y D = X.shape[1] var = 1 / (N_P * D) * (np.trace(X_hat.T @ np.diag(np.sum(P, axis=0)) @ X_hat) - np.trace(X_hat.T @ P.T @ Y_hat @ B.T)) return (B, t), var
[docs] def solve_nonrigid(self, X, Y, P, G, var, lmbda): ''' Solves for the non-rigid point set registration. Parameters ---------- X: np.ndarray, shape = (N, D) Set containing N D-dimensional data points. Y: np.ndarray, shape = (M, D) Set containing M D-dimensional data points. P: np.ndarray, shape = (M, N) Probabilities for point assignments from Y to X. G: np.ndarray, shape = (M, M) Kernel matrix. var: float Isotropic covariances. lmbda: float Trade-off between goodness of maximum likelihood fit and regularization. Returns ------- W: np.ndarray, shape = (M, D) Coefficient matrix. var: float Isotropic covariances. ''' diag_vals = np.sum(P, axis=1) # (M,) — avoid O(M²) matrix allocation diag_P_inv = 1.0 / diag_vals W = np.linalg.inv(G + lmbda * var * np.diag(diag_P_inv)) @ (diag_P_inv[:, np.newaxis] * P @ X - Y) N_P = np.sum(P) T = self.transform(Y, G, W) D = X.shape[1] var = 1 / (N_P * D) * (np.trace(X.T @ np.diag(np.sum(P, axis=0)) @ X) - 2 * np.trace((P @ X).T @ T) + np.trace(T.T @ np.diag(np.sum(P, axis=1)) @ T)) return W, var
[docs] def play_animation(self, X, Y, w=0.1, beta=2, lmbda=2, max_iterations=100, step_interval=0.1, save_animation=False, save_path=None): ''' Applies the CPD method to the data in Y to match X. Parameters ---------- X: np.ndarray, shape = (N, D) Set containing N D-dimensional data points. Y: np.ndarray, shape = (M, D) Set containing M D-dimensional data points. w: float, optional (default: 0.1) Assumption on the amount of noise in the point sets. beta: float, optional (default: 2) Model of smoothness regularizer (width of Gaussian filter in equation 20). lmbda: float, optional (default: 2) Trade-off between goodness of maximum likelihood fit and regularization. max_iterations: int, optional (default: 2) Maximum number of iterations for the EM optimization. step_interval: float, optional (default: 0.1) Time between animation steps in seconds. save_animation: bool, optional (default: False) Whether to save the animation. save_path: str, optional (default: None) Path where to save the animation. Returns ------- animation: matplotlib.animation.FuncAnimation Animation of the EM optimization progress. ''' self.__call__(X, Y, w=w, beta=beta, lmbda=lmbda, max_iterations=max_iterations, save_parameters=True) fig, ax = plt.subplots() ax.set_title(('Non-rigid' if self.method == 'nonrigid' else self.method.capitalize()) + ' CPD') ax.set_xlabel('x') ax.set_ylabel('y') ax.axis('equal') ax.set_xlim((min(X[:, 0].min(), Y[:, 0].min()), max(X[:, 0].max(), Y[:, 0].max()))) ax.set_ylim((min(X[:, 1].min(), Y[:, 1].min()), max(X[:, 1].max(), Y[:, 1].max()))) scat1 = ax.scatter(X[:, 0], X[:, 1], c='b') scat2 = ax.scatter([], [], c='r') def animate(i): Y_aligned = self.transform(Y, *self.parameters[i]) scat2.set_offsets(Y_aligned) return scat1, scat2, animation = FuncAnimation(fig, animate, frames=len(self.parameters), init_func=None, interval=int(1000 * step_interval), blit=True, repeat=False) plt.close(animation._fig) if save_animation and save_path is not None: animation.save(save_path + '.mp4', writer='ffmpeg', fps=1 / step_interval, dpi=None) return animation