"""
The mpi module allows to distribute operations in MPI communicators.
It is an optional feature of cbi_toolbox that requires a working implementation
of MPI.
"""
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by François Marelli <francois.marelli@idiap.ch>
#
# This file is part of CBI Toolbox.
#
# CBI Toolbox is free software: you can redistribute it and/or modify
# it under the terms of the 3-Clause BSD License.
#
# CBI Toolbox is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# 3-Clause BSD License for more details.
#
# You should have received a copy of the 3-Clause BSD License along
# with CBI Toolbox. If not, see https://opensource.org/licenses/BSD-3-Clause.
#
# SPDX-License-Identifier: BSD-3-Clause
import mpi4py.MPI as MPI
import numpy as np
import numpy.lib.format as npformat
from cbi_toolbox import utils
from . import distribute_bin, distribute_bin_all
_MPI_dtypes = {"float64": MPI.DOUBLE}
[docs]def get_size(mpi_comm=MPI.COMM_WORLD):
"""
Get the process count in the communicator.
Parameters
----------
mpi_comm : mpi4py.MPI.Comm, optional
The MPI communicator, by default MPI.COMM_WORLD.
Returns
-------
int
The size of the MPI communicator.
"""
return mpi_comm.Get_size()
[docs]def is_root_process(mpi_comm=MPI.COMM_WORLD):
"""
Check if the current process is root.
Parameters
----------
mpi_comm : mpi4py.MPI.Comm, optional
The MPI communicator, by default MPI.COMM_WORLD.
Returns
-------
bool
True if the current process is the root of the communicator.
"""
return mpi_comm.Get_rank() == 0
[docs]def get_rank(mpi_comm=MPI.COMM_WORLD):
"""
Get this process number in the communicator.
Parameters
----------
mpi_comm : mpi4py.MPI.Comm, optional
The communicator, by default MPI.COMM_WORLD.
Returns
-------
int
The rank of the process.
"""
return mpi_comm.Get_rank()
[docs]def wait_all(mpi_comm=MPI.COMM_WORLD):
"""
Wait for all processes to reach this line (MPI barrier)
This is just a wrapper for ease.
Parameters
----------
mpi_comm : mpi4py.MPI.Comm, optional
The communicator, by default MPI.COMM_WORLD.
"""
mpi_comm.Barrier()
[docs]def distribute_mpi(dimension, mpi_comm=MPI.COMM_WORLD):
"""
Computes the start index and bin size to evenly split array-like data into
multiple bins on an MPI communicator.
Parameters
----------
dimension : int
The size of the array to distribute.
mpi_comm : mpi4py.MPI.Comm, optional
The communicator, by default MPI.COMM_WORLD.
Returns
-------
(int, int)
The start index of this bin, and its size.
The distributed data should be array[start:start + bin_size].
"""
return distribute_bin(dimension, mpi_comm.Get_rank(), mpi_comm.Get_size())
[docs]def distribute_mpi_all(dimension, mpi_comm=MPI.COMM_WORLD):
"""
Computes the start indexes and bin sizes of all splits to distribute
computations across an MPI communicator.
Parameters
----------
dimension : int
the size of the array to be distributed
mpi_comm : mpi4py.MPI.Comm, optional
the communicator, by default MPI.COMM_WORLD
Returns
-------
([int], [int])
The list of start indexes and the list of bin sizes to distribute data.
"""
return distribute_bin_all(dimension, mpi_comm.Get_size())
[docs]def to_mpi_datatype(np_datatype):
"""
Returns the MPI datatype corresponding to the numpy dtype provided.
Parameters
----------
np_datatype : numpy.dtype or str
The numpy datatype, or name.
Returns
-------
mpi4py.MPI.Datatype
The corresponding MPI datatype.
Raises
------
NotImplementedError
If the numpy datatype is not listed in the conversion table.
"""
if isinstance(np_datatype, np.dtype):
dtype = np_datatype.name
else:
dtype = np_datatype
try:
return _MPI_dtypes[dtype]
except KeyError:
raise NotImplementedError("Type not in conversion table: {}".format(dtype))
[docs]def create_slice_view(axis, n_slices, array=None, shape=None, dtype=None):
"""
Create a MPI vector datatype to access given slices of a non distributed
array. If the array is not provided, its shape and dtype must be
specified.
Parameters
----------
axis : int
The axis on which to slice.
n_slices : int
How many contiguous slices to take.
array : numpy.ndarray, optional
The array to slice, by default None (then shape and dtype must be given).
shape : the shape of the array to slice, optional
The shape of the array, by default None.
dtype : numpy.dtype or str, optional
The datatype of the array, by default None.
Returns
-------
mpi4py.MPI.Datatype
The strided datatype allowing to access slices in the array.
Raises
------
ValueError
If array, shape and dtype are all None.
"""
if array is not None:
shape = array.shape
dtype = array.dtype
elif shape is None or dtype is None:
raise ValueError("array, or shape and dtype must be not None")
axis = utils.positive_index(axis, len(shape))
base_type = to_mpi_datatype(dtype)
stride = np.prod(shape[axis:], dtype=int)
block = np.prod(shape[axis + 1 :], dtype=int) * n_slices
count = np.prod(shape[:axis], dtype=int)
extent = block * base_type.extent
return base_type.Create_vector(count, block, stride).Create_resized(0, extent)
[docs]def compute_vector_extent(axis, array=None, shape=None, dtype=None):
"""
Compute the extent in bytes of a sliced view of a given array.
Parameters
----------
axis : int
Axis on which the slices are taken.
array : numpy.ndarray, optional
The array to slice, by default None (then shape and dtype must be given).
shape : the shape of the array to slice, optional
The shape of the array, by default None.
dtype : numpy.dtype or str, optional
The datatype of the array, by default None.
Returns
-------
int
The extent of the slices underlying data.
Raises
------
ValueError
If array, shape and dtype are all None.
"""
if array is not None:
shape = array.shape
dtype = array.dtype
elif shape is None or dtype is None:
raise ValueError("array, or shape and dtype must be not None")
ndims = len(shape)
axis = utils.positive_index(axis, ndims)
base_type = to_mpi_datatype(dtype)
return np.prod(shape[axis + 1 :], dtype=int) * base_type.extent
[docs]def create_vector_type(
src_axis, tgt_axis, array=None, shape=None, dtype=None, block_size=1
):
"""
Create a MPI vector datatype to communicate a distributed array and split it
along a different axis.
Parameters
----------
src_axis : int
The original axis on which the array is distributed.
tgt_axis : int
The axis on which the array is to be distributed.
array : numpy.ndarray, optional
The array to slice, by default None (then shape and dtype must be given).
shape : the shape of the array to slice, optional
The shape of the array, by default None.
dtype : numpy.dtype or str, optional
The datatype of the array, by default None.
block_size : int, optional
The size of the distributed bin, by default 1.
Returns
-------
mpi4py.MPI.Datatype
The vector datatype used for transmission/reception of the data.
Raises
------
ValueError
If array, shape and dtype are all None.
ValueError
If the source and destination axes are the same.
NotImplementedError
If the array has more than 4 axes (should work, but tests needed).
ValueError
If the block size is bigger than the source axis.
"""
if array is not None:
shape = array.shape
dtype = array.dtype
elif shape is None or dtype is None:
raise ValueError("array, or shape and dtype must be not None")
ndims = len(shape)
src_axis = utils.positive_index(src_axis, ndims)
tgt_axis = utils.positive_index(tgt_axis, ndims)
if src_axis == tgt_axis:
raise ValueError(
"Source and target are identical, no communication should be " "performed"
)
if len(shape) > 4:
raise NotImplementedError(
"This has never been tested for arrays with more than 4 axes.\n"
"It will probably work, but please run a test before"
"(and if works, tell me!)"
)
if block_size > shape[src_axis]:
raise ValueError(
"Block size cannot be bigger than the dimension of the source axis"
)
base_type = to_mpi_datatype(dtype)
min_axis = min(src_axis, tgt_axis)
max_axis = max(src_axis, tgt_axis)
i_count = np.prod(shape[min_axis + 1 : max_axis], dtype=int)
i_block = np.prod(shape[max_axis + 1 :], dtype=int)
i_stride = np.prod(shape[max_axis:], dtype=int)
i_extent = np.prod(shape[src_axis + 1 :], dtype=int) * base_type.extent
# only happens if the array is empty, avoid division by zero warnings
if i_extent == 0:
i_extent = 1
inner_stride = base_type.Create_vector(i_count, i_block, i_stride).Create_resized(
0, i_extent
)
o_count = np.prod(shape[:min_axis], dtype=int)
o_block = block_size
o_stride = (np.prod(shape[min_axis:], dtype=int) * base_type.extent) // i_extent
o_extent = np.prod(shape[tgt_axis + 1 :], dtype=int) * base_type.extent
outer_stride = inner_stride.Create_vector(
o_count, o_block, o_stride
).Create_resized(0, o_extent)
return outer_stride
[docs]def gather_full_shape(array, axis, mpi_comm=MPI.COMM_WORLD):
"""
Gather the full shape of an array distributed across an MPI communicator
along a given axis.
Parameters
----------
array : numpy.ndarray
The distributed array.
axis : int
The axis on which the array is distributed.
mpi_comm : mpi4py.MPI.Comm, optional
The communicator, by default MPI.COMM_WORLD.
Raises
------
NotImplementedError
This is not implemented yet.
"""
raise NotImplementedError
[docs]def load(file_name, axis, mpi_comm=MPI.COMM_WORLD):
"""
Load a numpy array across parallel jobs in the MPI communicator.
The array is sliced along the chosen dimension, with minimal bandwidth.
Parameters
----------
file_name : str
The numpy array file to load.
axis : int
The axis on which to distribute the array.
mpi_comm : mpi4py.MPI.Comm, optional
The MPI communicator used to distribute, by default MPI.COMM_WORLD.
Returns
-------
(numpy.ndarray, tuple(int))
The distributed array, and the size of the full array.
Raises
------
ValueError
If the numpy version used to save the file is not supported.
NotImplementedError
If the array is saved in Fortran order.
"""
header = None
if is_root_process(mpi_comm):
with open(file_name, "rb") as fp:
version, _ = npformat.read_magic(fp)
if version == 1:
header = npformat.read_array_header_1_0(fp)
elif version == 2:
header = npformat.read_array_header_2_0(fp)
else:
raise ValueError("Invalid numpy format version: {}".format(version))
header = *header, fp.tell()
header = mpi_comm.bcast(header, root=0)
full_shape, fortran, dtype, header_offset = header
if fortran:
raise NotImplementedError(
"Fortran-ordered (column-major) arrays are not supported"
)
ndims = len(full_shape)
axis = utils.positive_index(axis, ndims)
i_start, bin_size = distribute_mpi(full_shape[axis], mpi_comm)
l_shape = list(full_shape)
l_shape[axis] = bin_size
l_array = np.empty(l_shape, dtype=dtype)
slice_type = create_slice_view(axis, bin_size, shape=full_shape, dtype=dtype)
slice_type.Commit()
single_slice_extent = slice_type.extent
if bin_size != 0:
single_slice_extent /= bin_size
displacement = header_offset + i_start * single_slice_extent
base_type = to_mpi_datatype(l_array.dtype)
fh = MPI.File.Open(mpi_comm, file_name, MPI.MODE_RDONLY)
fh.Set_view(displacement, filetype=slice_type)
fh.Read_all([l_array, l_array.size, base_type])
fh.Close()
slice_type.Free()
return l_array, full_shape
[docs]def save(file_name, array, axis, full_shape=None, mpi_comm=MPI.COMM_WORLD):
"""
Save a numpy array from parallel jobs in the MPI communicator.
The array is gathered along the chosen dimension.
Parameters
----------
file_name : str
The numpy array file to load.
array : numpy.ndarray
The distributed array.
axis : int
The axis on which to distribute the array.
full_shape : tuple(int), optional
The size of the full array, by default None.
mpi_comm : mpi4py.MPI.Comm, optional
The MPI communicator used to distribute, by default MPI.COMM_WORLD.
"""
if full_shape is None:
full_shape = gather_full_shape(array, axis, mpi_comm)
axis = utils.positive_index(axis, len(full_shape))
header_offset = None
if is_root_process(mpi_comm):
header_dict = {
"shape": full_shape,
"fortran_order": False,
"descr": npformat.dtype_to_descr(array.dtype),
}
with open(file_name, "wb") as fp:
try:
npformat.write_array_header_1_0(fp, header_dict)
except ValueError:
npformat.write_array_header_2_0(fp, header_dict)
header_offset = fp.tell()
header_offset = mpi_comm.bcast(header_offset, root=0)
i_start, bin_size = distribute_mpi(full_shape[axis], mpi_comm)
slice_type = create_slice_view(axis, bin_size, shape=full_shape, dtype=array.dtype)
slice_type.Commit()
single_slice_extent = slice_type.extent
if bin_size != 0:
single_slice_extent /= bin_size
displacement = header_offset + i_start * single_slice_extent
base_type = to_mpi_datatype(array.dtype)
fh = MPI.File.Open(mpi_comm, file_name, MPI.MODE_WRONLY | MPI.MODE_APPEND)
fh.Set_view(displacement, filetype=slice_type)
fh.Write_all([array, array.size, base_type])
fh.Close()
slice_type.Free()
[docs]def redistribute(array, src_axis, tgt_axis, full_shape=None, mpi_comm=MPI.COMM_WORLD):
"""
Redistribute an array along a different dimension.
Parameters
----------
array : numpy.ndarray
The distributed array.
src_axis : int
The original axis on which the array is distributed.
tgt_axis : int
The axis on which the array is to be distributed.
full_shape : tuple(int), optional
The full shape of the array, by default None.
mpi_comm : mpi4py.MPI.Comm, optional
The MPI communicator used to distribute, by default MPI.COMM_WORLD.
Returns
-------
np.ndarray
The array distributed along the new axis.
"""
if full_shape is None:
full_shape = gather_full_shape(array, src_axis, mpi_comm)
ndims = len(full_shape)
src_axis = utils.positive_index(src_axis, ndims)
tgt_axis = utils.positive_index(tgt_axis, ndims)
if src_axis == tgt_axis:
return array
rank = mpi_comm.Get_rank()
size = mpi_comm.Get_size()
src_starts, src_bins = distribute_mpi_all(full_shape[src_axis], mpi_comm)
tgt_starts, tgt_bins = distribute_mpi_all(full_shape[tgt_axis], mpi_comm)
src_has_data = np.atleast_1d(src_bins)
src_has_data[src_has_data > 0] = 1
tgt_has_data = np.atleast_1d(tgt_bins)
tgt_has_data[tgt_has_data > 0] = 1
n_shape = list(full_shape)
n_shape[tgt_axis] = tgt_bins[rank]
n_array = np.empty(n_shape, dtype=array.dtype)
send_datatypes = []
recv_datatypes = []
for ji in range(size):
send_datatypes.append(
create_vector_type(src_axis, tgt_axis, array, block_size=src_bins[rank])
)
recv_datatypes.append(
create_vector_type(src_axis, tgt_axis, n_array, block_size=src_bins[ji])
)
send_extent = compute_vector_extent(tgt_axis, array)
recv_extent = compute_vector_extent(src_axis, n_array)
send_counts = np.multiply(tgt_bins, src_has_data[rank])
send_displs = np.multiply(tgt_starts, send_extent)
sendbuf = [array, send_counts, send_displs, send_datatypes]
recv_counts = np.multiply(src_has_data, tgt_bins[rank])
recv_displs = np.multiply(src_starts, recv_extent)
recvbuf = [n_array, recv_counts, recv_displs, recv_datatypes]
for ji in range(size):
send_datatypes[ji].Commit()
recv_datatypes[ji].Commit()
mpi_comm.Alltoallw(sendbuf, recvbuf)
for ji in range(size):
send_datatypes[ji].Free()
recv_datatypes[ji].Free()
return n_array