"""
This module implements virtual high-framerate sequence reconstruction
methods for periodic signals.
[1] Mariani, O., Ernst, A., Mercader, N. and Liebling, M., 2019. Reconstruction
of image sequences from ungated and scanning-aberrated laser scanning
microscopy images of the beating heart. IEEE transactions on computational
imaging, 6, pp.385-395.
"""
# Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
# Written by François Marelli <francois.marelli@idiap.ch>
#
# This file is part of CBI Toolbox.
#
# CBI Toolbox is free software: you can redistribute it and/or modify
# it under the terms of the 3-Clause BSD License.
#
# CBI Toolbox is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 3-Clause BSD License for more details.
#
# You should have received a copy of the 3-Clause BSD License along
# with CBI Toolbox. If not, see https://opensource.org/licenses/BSD-3-Clause.
#
# SPDX-License-Identifier: BSD-3-Clause
import tempfile
import pathlib
import shutil
from functools import partial
import numpy as np
import scipy.spatial.distance as scidist
from scipy import sparse
import concorde.tsp as tsp
import scs
import jax
import jax.numpy as jnp
import optax
def phase_diff(phase_a, phase_b):
    """
    Subtract two unitless phases and wrap the result onto a single period.

    The returned difference lies in the interval [-0.5, 0.5[.

    Parameters
    ----------
    phase_a : array-like of floats (or float)
        Unitless phases (between 0 and 1), minuend of the difference.
    phase_b : array-like of floats (or float)
        Unitless phases (between 0 and 1), subtrahend of the difference.

    Returns
    -------
    type(phase_a)
        The wrapped phase difference ``phase_a - phase_b``.
    """
    # Fold the raw difference into [0, 1[ first ...
    folded = (phase_a - phase_b) % 1
    # ... then move the upper half of the period to the negative side.
    return folded - (folded >= 0.5)
def phase_error(phase_a, phase_b, floating_origin=True):
    """
    Compute the absolute phase error between unitless phases (between 0 and 1).

    If the starting phases of the two operands are not fixed with respect to
    one another, the operands are first shifted by a scalar bias to minimize
    this error (i.e. the zero phase of the second signal is set so that the
    mean phase difference between the two operands is zero).

    Parameters
    ----------
    phase_a : array-like of floats
        Reference unitless phases.
    phase_b : array-like of floats
        Target unitless phases.
    floating_origin : bool, optional
        If the zero phase of the target phase can be shifted with respect to the
        zero of the reference phases to minimize the error, by default True.

    Returns
    -------
    numpy.ndarray(float)
        The absolute phase error between the two.

    Raises
    ------
    RuntimeError
        If the iterative removal of the mean difference has not brought the
        mean below 1e-5 after the allowed number of iterations.
    """
    # Accept any array-like (lists, tuples, integer arrays): the arithmetic
    # and in-place updates below require float numpy arrays.
    phase_a = np.asarray(phase_a, dtype=float)
    phase_b = np.asarray(phase_b, dtype=float)

    if floating_origin:
        # Start from a configuration where the first samples coincide, so
        # that the wrapped differences are centered close to zero.
        phase_a = phase_a - phase_a[0] + phase_b[0]

    diffs = phase_diff(phase_a, phase_b)

    if floating_origin:
        # Removing the mean changes the wrapping of some samples, which in
        # turn changes the mean: iterate until it is negligible.
        mean = np.mean(diffs)
        nit = 0
        while np.abs(mean) > 1e-5:
            if nit > 10:
                raise RuntimeError("Floating origin did not converge")
            diffs -= mean
            diffs %= 1
            diffs -= diffs > 0.5
            mean = np.mean(diffs)
            nit += 1

    return np.abs(diffs)
def argsort_tsp(images):
    """
    Returns the indices that sort a sequence of images using TSP to obtain a
    virtual high-framerate signal.

    The direction of the produced sequence is undefined.

    Parameters
    ----------
    images : numpy.ndarray [N, ...]
        List of images to sort.

    Returns
    -------
    numpy.ndarray(int)
        Array of indices that sort the sequence according to the TSP solution.

    Raises
    ------
    RuntimeWarning
        When the TSP solver does not converge.
    """
    n_images = images.shape[0]

    # With three images or fewer, every closed tour has the same cost: the
    # acquisition order is already an optimal solution. This also avoids
    # calling the solver on degenerate (empty distance list) problems.
    if n_images <= 3:
        return np.arange(n_images)

    distances = scidist.pdist(
        images.reshape(n_images, -1), metric="minkowski", p=1
    )

    # TSP files expect integer numbers, this allows higher precision
    # It is unknown what the maximum allowed value is, this is good enough
    max_distance = distances.max()
    if max_distance > 0:
        # Skipped when all images are identical: dividing by zero would
        # produce NaNs that cannot be written as integers.
        distances /= max_distance
        distances *= 1e6

    # The context manager removes the directory even if writing or solving
    # fails, and never references a directory that could not be created.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ccfile = pathlib.Path(tmp_dir) / "data.tsp"

        with ccfile.open("w") as file_pointer:
            file_pointer.writelines(
                [
                    "NAME: PHASESORT\n",
                    "TYPE: TSP\n",
                    f"DIMENSION: {n_images}\n",
                    "EDGE_WEIGHT_TYPE: EXPLICIT\n",
                    "EDGE_WEIGHT_FORMAT: UPPER_ROW\n",
                    "EDGE_WEIGHT_SECTION\n",
                ]
            )
            # The condensed pdist vector is written one weight per line
            np.savetxt(file_pointer, distances, fmt="%d")
            file_pointer.write("EOF\n")

        solver = tsp.TSPSolver.from_tspfile(str(ccfile))
        solution = solver.solve(verbose=False)

    if not solution.success:
        raise RuntimeWarning("Solver did not converge.")

    return solution.tour
def unfold_phases(phases, threshold=-0.25, in_place=False):
    """
    Unfold consecutive unitless phases from the range [0, 1] to R based on a
    negative delta threshold.

    Each time the difference between two consecutive phases is below the
    threshold, all the following phases are moved to the next period.

    Parameters
    ----------
    phases : numpy.ndarray [N]
        The unitless phases to unfold, in [0, 1].
    threshold : float, optional
        The negative threshold to go to the next period, by default -0.25
    in_place : bool, optional
        Compute operations in-place, by default False

    Returns
    -------
    numpy.ndarray [N]
        The unfolded phases over multiple periods.
    """
    if not in_place:
        phases = np.array(phases)

    # A drop below the threshold between two neighbours marks the start of a
    # new period. The number of periods to add to a sample is the number of
    # such drops seen before it, i.e. a cumulative sum (linear time, instead
    # of incrementing the whole tail of the array at every drop).
    new_period = np.diff(phases) < threshold
    phases[1:] += np.cumsum(new_period)

    return phases
def orient_phases(phases, seq_lengths=None, in_place=False):
    """
    Detect the orientation of a sequence of phases sorted using TSP and flip it
    if necessary. Assumes that the phases have been acquired in continuous
    sequences under Shannon sampling assumption.

    The mean forward phase increment (modulo 1) between consecutive samples
    is computed over all sequences; if it is above 0.5, the phases are
    considered to run backwards and are mirrored.

    Parameters
    ----------
    phases : numpy.ndarray (float) [N]
        The phases to orient, given in the order of acquisition.
    seq_lengths : tuple (int)
        The lengths of sequences of images acquired consecutively, if multiple
        sequences have been acquired. The total sum of sequence lengths must be
        equal to the total number of images N. By default, None: considers the
        entire sequence to be consecutive.
    in_place : bool, optional
        Modify the input array instead of working on a copy, by default False.

    Returns
    -------
    c_phases : numpy.ndarray (float) [N]
        The oriented unitless phases. They are returned unchanged if no pair
        of consecutive samples is available to detect the orientation.

    Raises
    ------
    ValueError
        If the sum of the sequence lengths is different from N.
    """
    if not in_place:
        phases = np.array(phases)

    if seq_lengths is None:
        seq_lengths = (phases.size,)

    if sum(seq_lengths) != phases.size:
        raise ValueError("The sum of sequence lengths must be equal to N.")

    # Detect sequence direction under Shannon sampling hypothesis
    seq_start = 0
    delta_sum = 0.0
    n_deltas = 0
    for seq in seq_lengths:
        seq_end = seq_start + seq
        # Increments are only meaningful inside a sequence, never across
        # the boundary between two acquisitions
        seq_deltas = np.diff(phases[seq_start:seq_end]) % 1
        delta_sum += seq_deltas.sum()
        n_deltas += seq_deltas.size
        seq_start = seq_end

    # Without any consecutive pair there is nothing to average: leave the
    # phases untouched rather than dividing by zero
    if n_deltas == 0:
        return phases

    if delta_sum / n_deltas > 0.5:
        phases *= -1
        phases %= 1

    return phases
def average_frequency(phases, acq_frequency, n_periods=2, unfold=True):
    """
    Computes the initial average frequency of a periodic signal from uniformly
    sampled phases by fitting a line using least squares.

    Parameters
    ----------
    phases : numpy.ndarray (float) [N]
        The unitless phases sampled in the signal.
    acq_frequency : float
        The frequency at which the signal was sampled.
    n_periods : float, optional
        The number of periods over which to compute the average, by default 2.
    unfold : bool, optional
        If the input phases are folded over a single period and must be
        unfolded first, by default True.

    Returns
    -------
    float
        The average frequency of the sampled signal over the given region. The
        units are the same as the acquisition frequency.

    Raises
    ------
    ValueError
        If fewer than two samples are available within the first
        ``n_periods`` periods.
    """
    if unfold:
        phases = unfold_phases(phases)

    if phases.size == 0 or phases[-1] <= n_periods:
        # The signal never goes beyond the requested region: use every
        # sample. The comparison is inclusive, otherwise a last phase exactly
        # equal to n_periods would select zero samples below.
        n_anchors = phases.size
    else:
        # Keep the samples located before the first one beyond the region
        n_anchors = int(np.argmax(phases > n_periods))

    if n_anchors < 2:
        raise ValueError(
            "Insufficient number of points, increase n_periods or provide more points."
        )

    # Use linear regression to find the slope of the curve
    x_anchors = np.stack((np.arange(n_anchors), np.ones(n_anchors)), axis=-1)
    slope = np.linalg.lstsq(x_anchors, phases[:n_anchors], rcond=None)[0][0]

    # The slope is in periods per sample, convert it to the acquisition units
    return slope * acq_frequency
def sample_video(images, phases, n_frames):
    """
    Build a uniformly sampled video from a non-uniformly sampled image
    sequence.

    The frame ``f`` of the video is assigned the target phase
    ``f / n_frames``; the image picked for it is found by walking forward
    through the phase-sorted images for as long as the next image is strictly
    closer to the target phase than the current one.

    Parameters
    ----------
    images : np.ndarray [N, ...]
        Sequence of images of size N.
    phases : np.ndarray [N]
        Sequence of phases corresponding to the images respectively.
    n_frames : int
        Number of frames in the generated video

    Returns
    -------
    video: np.ndarray [n_frames, ...]
        The resampled video, with the dtype of the input images.
    """
    sort_idx = np.argsort(phases)
    sorted_phases = phases[sort_idx]
    sorted_images = images[sort_idx]

    video = np.empty(
        shape=(n_frames, *sorted_images.shape[1:]), dtype=sorted_images.dtype
    )

    last = sorted_phases.size - 1
    # The cursor only moves forward: target phases increase with the frames
    cursor = 0
    for frame_idx in range(n_frames):
        target = frame_idx / n_frames
        best = np.abs(sorted_phases[cursor] - target)

        while cursor < last:
            candidate = np.abs(sorted_phases[cursor + 1] - target)
            if not candidate < best:
                # The next image is not closer: keep the current one
                break
            best = candidate
            cursor += 1

        video[frame_idx] = sorted_images[cursor]

    return video
def vhf_phase_naive(images):
    """
    Estimate the phases of a randomly sampled periodic signal with the naive
    vhf method: the images are sorted by solving a TSP, as exposed in [1],
    and assumed to be uniformly spread over one period.

    Parameters
    ----------
    images : numpy.ndarray [N, ...]
        The measured images.

    Returns
    -------
    numpy.ndarray (float) [N]
        The estimated unitless phases corresponding to each input image. The
        first image is given phase 0.
    """
    tour = argsort_tsp(images)

    # The rank of each image in the tour, scaled to one period, is its
    # phase under the uniform sampling assumption
    uniform = np.argsort(tour) / tour.size

    # Use the first acquired image as the phase origin
    return (uniform - uniform[0]) % 1
def estimate_di(phases, distances, k_size, w_width, d_i_k=None):
    """
    Estimate the image function distance on phases based on phase neighbours.

    For each point, a distance function sampled on ``k_size + 1`` values is
    fitted to the pairwise image distances: a Gaussian-weighted linear least
    squares problem is solved first, and its solution is used both to locate
    the maximum of the function and as the starting point of a constrained
    quadratic problem solved with SCS. The first and last samples of every
    distance function are set to 0.

    Parameters
    ----------
    phases : numpy.ndarray (float) [N]
        The phase array.
    distances : numpy.ndarray (float) [N, N]
        The pairwise image distances.
    k_size : int
        The number of points of di to estimate.
    w_width : float
        The size of the weighting window.
    d_i_k : numpy.ndarray (float) [N, K+1], optional
        The array to store the output, by default None (a new array is
        allocated).

    Returns
    -------
    numpy.ndarray [N, K+1]
        The estimated distance functions for each point.
    float
        The primal residual (``res_pri``) reported by the solver for the last
        point processed.

    Raises
    ------
    ValueError
        If the provided ``d_i_k`` array does not have the shape
        ``(N, k_size + 1)``.
    RuntimeWarning
        If the conic solver does not report the ``"solved"`` status.
    """
    size = phases.size
    # This is a weird bug, sometimes %1 still contains a 1
    # (the modulo is applied a second time to fold such values back to 0)
    delta_c = ((phases[None, :] - phases[:, None]) % 1) % 1
    # Create the base condition matrix for increasing monotonic
    a_cone_template = sparse.diags(
        (-1.0, 1.0), offsets=(0, -1), shape=(k_size, k_size - 1), format="csc"
    )
    if d_i_k is None:
        d_i_k = np.empty((size, k_size + 1))
    else:
        if not np.array_equal(d_i_k.shape, (size, k_size + 1)):
            raise ValueError(
                f"Incorrect d_i_k array size, expected {(size, k_size+1)} and got {d_i_k.shape}"
            )
    # Both ends of every distance function are fixed, only the interior
    # samples are estimated below
    d_i_k[:, 0] = 0
    d_i_k[:, -1] = 0
    # pivots = np.empty(size, dtype=int)
    # Assign each point to an interval, and compute its position in it
    # (fractional part = position inside the interval, integer part = interval)
    modfs = np.modf((delta_c * k_size).ravel())
    modfs = zip(modfs[0], modfs[1].astype(int))
    # Create the X matrix for least squares
    # Each row holds the two linear interpolation weights of one pair
    rho = sparse.lil_array((size**2, k_size + 1), dtype=float)
    for index, (ratio, position) in enumerate(modfs):
        rho[index, position : position + 2] = (1 - ratio, ratio)
    # Drop the columns of the two fixed end samples
    rho = rho[:, 1:-1]
    rho = rho.tocsc()

    def gaussian(x_val, sigma):
        return np.exp(-0.5 * (x_val / sigma) ** 2)

    # Wrap the phase differences to [-0.5, 0.5[ to use them in the window
    delta_c -= delta_c >= 0.5
    for index, c_delta in enumerate(delta_c):
        # Use gaussian weighted errors
        # (square root because the weights multiply both sides of the
        # least squares system before it is squared)
        weights = np.sqrt(gaussian(c_delta, w_width))
        # One weight per row of the distance matrix, repeated for all its pairs
        weights = np.repeat(weights, size)
        w_rho = rho * weights[:, None]
        w_dist = distances.ravel() * weights
        # Solve the linear least squares unconstrained
        try:
            d_k = (sparse.linalg.inv(w_rho.T.tocsc() @ w_rho) @ w_rho.T) @ w_dist
        except RuntimeError:
            # Fall back to a dense solver when the sparse inversion fails
            d_k = np.linalg.lstsq(w_rho.toarray(), w_dist, rcond=None)[0]
        # Find the position of the maximum
        k_max = np.argmax(d_k)
        # Create the matrices for quadratic solver
        q_cone = 2 * w_rho.T.tocsc() @ w_rho
        c_cone = -2 * w_dist.T @ w_rho
        a_cone = a_cone_template.copy()
        # Negate the constraint rows located after the maximum
        # NOTE(review): presumably this turns the increasing condition into a
        # decreasing one past the maximum - confirm against the SCS convention
        a_cone[k_max + 1 :, :] *= -1
        b_cone = np.zeros(a_cone.shape[0])
        data = dict(P=q_cone, A=a_cone, b=b_cone, c=c_cone)
        cone = dict(l=a_cone.shape[0])
        # Solve the constrained problem
        # (the unconstrained solution is provided as the initial guess)
        solver = scs.SCS(data, cone, eps_abs=1e-7, eps_rel=1e-7, verbose=False)
        sol = solver.solve(x=d_k)
        if not sol["info"]["status"] == "solved":
            raise RuntimeWarning("Conic solver did not converge.")
        # The first distance is imposed to be 0
        d_i_k[index, 1:-1] = sol["x"]
        # pivots[index] = k_max + 1
    # Only the residual of the last solved problem is reported
    return d_i_k, sol["info"]["res_pri"]
@partial(jax.jit, static_argnames=["sequences"])
def _regu_loss(phases, sequences):
"""
Compute the regularization loss based on phase delta deviation.
Parameters
----------
phases : jax.ndarray (float) [N]
The input phases.
sequences : tuple(int)
The acquisition sequence lengths.
Returns
-------
float
The loss value.
"""
regu_loss = 0
seq_start = 0
n_measures = phases.size - 2 * len(sequences)
for seq_len in sequences:
seq_phases = phases[seq_start : seq_start + seq_len]
deltas = (seq_phases[1:] - seq_phases[:-1]) % 1
deltas = (deltas[1:] - deltas[:-1]) ** 2
regu_loss += deltas.sum()
seq_start += seq_len
regu_loss /= n_measures
return regu_loss
@partial(jax.jit, static_argnames=["w_width"])
# @jax.jit
def _distance_loss(phases, distances, d_i_k, w_width):
"""
Compute the loss based on the deviation between measured image distances and
average image distance functions.
Parameters
----------
phases : jax.ndarray (float) [N]
The input phases.
distances : jax.jndarray (float) [N, N]
The pairwise image distances.
d_i_k : jax.jndarray (float) [N, K]
The average image distance functions.
w_width : float
The gaussian window width.
"""
def gaussian(x_val, sigma):
return jnp.exp(-0.5 * (x_val / sigma) ** 2)
k_range = jnp.arange(d_i_k.shape[1]) / (d_i_k.shape[1] - 1)
delta_phases = (phases[None, :] - phases[:, None]) % 1
diff_phases = delta_phases.copy()
diff_phases -= diff_phases >= 0.5
w_i_j = gaussian(diff_phases, w_width)
def body_func(index, loss_value):
d_i_j = jnp.interp(delta_phases[index], k_range, d_i_k[index])
diffs = (d_i_j - distances[index]) ** 2
loss_value += jnp.sum(diffs * w_i_j[index])
return loss_value
dist_loss = jax.lax.fori_loop(0, delta_phases.shape[0], body_func, 0.0)
dist_loss /= jnp.sum(w_i_j)
return dist_loss
def _gradient_descent_phase(
    phases,
    regu_lambda,
    learning_rate,
    a_tol,
    grad_iter,
    distances,
    d_i_k,
    w_width,
    sequences,
    optim=None,
):
    """
    Minimize the combined distance and regularization loss over the phases
    with a gradient-based optimizer.

    The gradient used for the update is
    ``(1 - regu_lambda) * distance gradient + regu_lambda * regularization
    gradient``. Iterations stop early when the mean absolute update is below
    ``a_tol``.

    Parameters
    ----------
    phases : jax.ndarray (float) [N]
        The input phases.
    regu_lambda : float
        The regularization strength, in [0,1].
    learning_rate : float
        The learning rate for gradient descent. Only used when a new
        optimizer is created (``optim`` is None).
    a_tol : float
        The absolute tolerance for convergence.
    grad_iter : int
        The maximum number of iterations.
    distances : jax.ndarray (float) [N, N]
        The pairwise image distances.
    d_i_k : jax.ndarray (float) [N, K]
        The estimated distance functions.
    w_width : float
        The window width.
    sequences : tuple(int)
        The acquisition sequences length.
    optim : tuple (optax.GradientTransformation, optax state), optional
        The optimizer and its state. If left empty, adam is used (fresh init).

    Returns
    -------
    phases : jax.ndarray (float) [N]
        The estimated phases.
    losses : tuple
        The losses evaluated at the last iteration, before its update is
        applied, in order: (distance, regularization)
    optim : tuple (optax.GradientTransformation, optax state)
        The optimizer and its state. Provide it to the next iteration for
        smoother convergence.
    """
    regu_value_and_grad = jax.value_and_grad(_regu_loss)
    dist_value_and_grad = jax.value_and_grad(_distance_loss)

    if optim is not None:
        # Resume from the optimizer state of a previous call
        optim, opt_state = optim
    else:
        optim = optax.adam(learning_rate)
        opt_state = optim.init(phases)

    dist_weight = 1 - regu_lambda

    for _ in range(grad_iter):
        r_loss, r_grad = regu_value_and_grad(phases, sequences)
        d_loss, d_grad = dist_value_and_grad(phases, distances, d_i_k, w_width)

        mixed_grad = dist_weight * d_grad + regu_lambda * r_grad

        step, opt_state = optim.update(mixed_grad, opt_state, phases)
        phases = optax.apply_updates(phases, step)

        # Converged once the phases barely move anymore
        if jnp.abs(step).mean() < a_tol:
            break

    final_losses = (float(d_loss), float(r_loss))
    return phases, final_losses, (optim, opt_state)
def _correction_pass(
    c_phases,
    distances,
    seq_lengths,
    k_size,
    w1_width,
    w2_width,
    regu_lambda,
    max_iter=100,
    a_tol=1e-4,
    grad_iter=100,
    learning_rate=1e-3,
):
    """
    Inner implementation of the phase correction algorithm.

    Alternates between estimating the distance functions from the current
    phases (`estimate_di`) and updating the phases by gradient descent
    (`_gradient_descent_phase`), until the mean absolute change of the phases
    over one outer iteration is below ``a_tol`` or ``max_iter`` is reached.
    The optimizer state is carried from one outer iteration to the next.

    Parameters
    ----------
    c_phases : jax.ndarray (float) [N]
        The initial estimate for the phases, given in acquisition order.
    distances : jax.ndarray (float) [N, N]
        The image-to-image distance pairs.
    seq_lengths : tuple (int)
        The lengths of sequences of images acquired consecutively, if multiple
        sequences have been acquired. The total sum of sequence lengths must be
        equal to the total number of images N.
    k_size : int
        Number of points used to sample distance functions.
    w1_width : float
        Gaussian window width (standard deviation) used to weight contributions
        of neighbours for distance function estimation.
    w2_width : float
        Gaussian window width (standard deviation) used to weight contributions
        of neighbours for gradient descent.
    regu_lambda : float
        The regularization strength.
    max_iter : int, optional
        Maximum iterations of the algorithm, by default 100
    a_tol : float, optional
        Absolute tolerance used to define convergence, by default 1e-4
    grad_iter : int, optional
        Max number of iterations of the gradient descent, by default 100
    learning_rate : float, optional
        The learning rate of the gradient descent, by default 1e-3

    Returns
    -------
    phases : array (float) [N]
        The corrected phases, as returned by the gradient descent.
    status : tuple
        The convergence status, containing:
        converged : bool
            If the algorithm converged.
        n_iter : int
            The zero-based index of the last outer iteration executed.
        losses : tuple(float)
            The final distance and regularization losses, in that order.
    """
    # Buffer reused by estimate_di at every outer iteration
    d_i_k = np.empty((c_phases.size, k_size + 1))

    optim = None
    converged = False

    for out_iter in range(max_iter):
        d_i_k, _ = estimate_di(c_phases, distances, k_size, w1_width, d_i_k=d_i_k)

        previous_phases = c_phases
        c_phases, losses, optim = _gradient_descent_phase(
            c_phases,
            regu_lambda,
            learning_rate,
            a_tol,
            grad_iter,
            distances,
            jnp.asarray(d_i_k),
            w2_width,
            seq_lengths,
            optim=optim,
        )

        mean_change = jnp.abs(previous_phases - c_phases).mean()
        if mean_change < a_tol:
            converged = True
            break

    return c_phases, (converged, out_iter, losses)
def phase_correction(
    phases,
    images,
    seq_lengths=None,
    w_width=3e-2,
    k_size=None,
    regu_lambda=0.5,
    precision_passes=0,
    **kwargs,
):
    """
    Correct a sequence of phases corresponding to a periodic signal using both
    image-to-image and phase distances to find non uniform samplings. This
    method assumes that the images have been acquired at regular time
    intervals in one or multiple consecutive sequences.

    The image distances are the pairwise L1 distances between flattened
    images, normalized by their maximum.

    Parameters
    ----------
    phases : numpy.ndarray (float) [N]
        The initial estimate for the phases, given in acquisition order.
    images : numpy.ndarray (float) [N, ...]
        The acquired images, in order of acquisition.
        phases[i] should correspond to images[i].
    seq_lengths : tuple (int)
        The lengths of sequences of images acquired consecutively, if multiple
        sequences have been acquired. The total sum of sequence lengths must be
        equal to the total number of images N. By default, None: considers the
        entire sequence to be consecutive.
    w_width : float, optional
        Gaussian window width (standard deviation) used to weight contributions
        of neighbours, by default 3e-2
    k_size : int, optional
        Number of points used to sample distance functions, by default None
        (uses N // 2)
    regu_lambda : float, optional
        The regularization strength, by default 0.5
    precision_passes : int, optional
        The number of precision passes to perform after the first one. At each
        pass, window width is halved and k_size is doubled. By default 0.
    ** kwargs : named arguments passed to the `.vhf._correction_pass` function.
        max_iter : int, optional
            Maximum iterations of the algorithm, by default 100
        a_tol : float, optional
            Absolute tolerance used to define convergence, by default 1e-4
        grad_iter : int, optional
            Max number of iterations of the gradient descent, by default 100
        learning_rate : float, optional
            The learning rate of the gradient descent, by default 1e-3

    Returns
    -------
    phases : numpy.ndarray (float) [N]
        The corrected phases
    status : tuple
        The status of the last pass, as returned by `_correction_pass`:
        converged : bool. If the algorithm converged.
        n_iter : int. The zero-based index of the last iteration executed.
        losses : tuple(float). The final distance and regularization losses, in that order.

    Raises
    ------
    ValueError
        If the sum of the sequence lengths is different from N.
    """
    n_images = phases.size

    if seq_lengths is None:
        seq_lengths = (n_images,)
    if sum(seq_lengths) != n_images:
        raise ValueError(
            "The sum of sequence lengths must be equal to the total number of images."
        )

    # Compute all pairwise distances between images
    flat_images = images.reshape(images.shape[0], -1)
    pairwise = scidist.pdist(flat_images, metric="minkowski", p=1)
    pairwise /= pairwise.max()
    distances = jnp.asarray(scidist.squareform(pairwise))

    if k_size is None:
        k_size = n_images // 2

    c_phases = jnp.asarray(phases)
    for p_pass in range(precision_passes + 1):
        # Every pass after the first one refines the resolution
        if p_pass > 0:
            w_width /= 2
            k_size *= 2

        c_phases, status = _correction_pass(
            c_phases,
            distances,
            seq_lengths,
            k_size=k_size,
            w1_width=w_width,
            w2_width=w_width,
            regu_lambda=regu_lambda,
            **kwargs,
        )

    return np.asarray(c_phases), status
def l_curve(
    phases,
    images,
    seq_lengths=None,
    w_width=3e-2,
    k_size=None,
    n_points=21,
    iterations=50,
    **kwargs,
):
    """
    Compute the points of the L-curve of the phase correction algorithm.

    The correction is run for ``n_points`` values of the regularization
    strength lambda, linearly spaced between 0 and 1, and the final distance
    and regularization losses are collected for each of them. The phases
    obtained for one value of lambda are used as the starting point for the
    next one.

    Parameters
    ----------
    phases : numpy.ndarray (float) [N]
        The initial estimate for the phases, given in acquisition order.
    images : numpy.ndarray (float) [N, ...]
        The acquired images, in order of acquisition.
        phases[i] should correspond to images[i].
    seq_lengths : tuple (int)
        The lengths of sequences of images acquired consecutively, if multiple
        sequences have been acquired. The total sum of sequence lengths must be
        equal to the total number of images N. By default, None: considers the
        entire sequence to be consecutive.
    w_width : float, optional
        Gaussian window width (standard deviation) used to weight contributions
        of neighbours, by default 3e-2
    k_size : int, optional
        Number of points used to sample distance functions, by default None
        (uses N // 2)
    n_points : int, optional
        Number of values of lambda to compute for the L-curve. By default 21.
    iterations : int, optional
        Maximum number of iterations, used both for the outer loop of the
        phase correction and for its gradient descent. By default 50.
    ** kwargs : named arguments passed to the `.vhf._correction_pass` function.
        a_tol : float, optional. Absolute tolerance used to define convergence, by default 1e-4
        learning_rate : float, optional. The learning rate of the gradient descent, by default 1e-3

    Returns
    -------
    losses: numpy.ndarray [2, n_points]
        The values of the losses for each lambda, given in the order
        (distance cost, regularization cost)
    lambdas: numpy.ndarray [n_points]
        The lambda values

    Raises
    ------
    ValueError
        If the sum of the sequence lengths is different from N.
    """
    n_images = phases.size

    if seq_lengths is None:
        seq_lengths = (n_images,)
    if sum(seq_lengths) != n_images:
        raise ValueError(
            "The sum of sequence lengths must be equal to the total number of images."
        )

    # Compute all pairwise distances between images
    flat_images = images.reshape(images.shape[0], -1)
    pairwise = scidist.pdist(flat_images, metric="minkowski", p=1)
    pairwise /= pairwise.max()
    distances = jnp.asarray(scidist.squareform(pairwise))

    if k_size is None:
        k_size = n_images // 2

    lambdas = np.linspace(0, 1, n_points)
    c_phases = jnp.asarray(phases)

    curve = []
    for l_val in lambdas:
        # Warm start: continue from the phases of the previous lambda
        c_phases, status = _correction_pass(
            c_phases,
            distances,
            seq_lengths,
            k_size=k_size,
            w1_width=w_width,
            w2_width=w_width,
            regu_lambda=l_val,
            max_iter=iterations,
            grad_iter=iterations,
            **kwargs,
        )
        # The losses are the last entry of the status tuple
        curve.append(status[-1])

    return np.array(curve).T, lambdas
def vhf_phase_uncorrected(images, seq_lengths=None):
    """
    Use naive vhf to estimate the phases corresponding to a periodic signal
    sampled at regular time intervals, then orient them using the order of
    acquisition. The phases are estimated to be uniformly sampled, which can
    be improved by running the `.vhf.phase_correction` algorithm.

    Parameters
    ----------
    images : numpy.ndarray (float) [N, ...]
        The acquired images, in acquisition order.
    seq_lengths : tuple (int)
        The lengths of sequences of images acquired consecutively, if multiple
        sequences have been acquired. The total sum of sequence lengths must be
        equal to the total number of images N. By default, None: considers the
        entire sequence to be consecutive.

    Returns
    -------
    uni_phases : numpy.ndarray (float) [N]
        The estimated unitless phases corresponding to the input images.

    Raises
    ------
    ValueError
        If the sum of the sequence lengths is different from N.
    """
    n_images = images.shape[0]
    lengths = (n_images,) if seq_lengths is None else seq_lengths

    if sum(lengths) != n_images:
        raise ValueError(
            "The sum of sequence lengths must be equal to the total number of images."
        )

    # The TSP ordering has no defined direction: fix it afterwards, directly
    # in the freshly created phase array
    return orient_phases(vhf_phase_naive(images), lengths, in_place=True)