# qlat.mat_mpi: Distributed NumPy Arrays over MPI

Source: `qlat/qlat/mat_mpi.py`

Note: Update this document when updating the source file.
## Overview

The `qlat.mat_mpi` module provides a lightweight distributed-array layer on top of NumPy and mpi4py. The central abstraction is `DistArray`, which partitions the first dimension of a NumPy array evenly across MPI ranks (with zero-padding when the size does not divide evenly).
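The partitioning rule can be illustrated without MPI. The following is a minimal sketch that assumes each rank stores `ceil(n / size)` rows and that the padding is appended after the last real row; `block_layout` is a hypothetical helper, not part of `qlat.mat_mpi`.

```python
import math

def block_layout(n, size):
    """Hypothetical helper: per-rank row ranges for an even block split of n rows.

    Assumes every rank holds ceil(n / size) rows and that the zero-padding
    needed to reach size * ceil(n / size) rows is appended at the end.
    """
    n_local = math.ceil(n / size)   # rows stored on every rank (incl. padding)
    n_padded = n_local * size       # total rows after zero-padding
    ranges = []
    for rank in range(size):
        start = rank * n_local
        stop = min(start + n_local, n)            # rows at or beyond n are padding
        ranges.append((start, max(start, stop)))  # real rows owned: [start, stop)
    return n_local, n_padded, ranges

# 6 rows over 4 ranks: each rank stores 2 rows, 2 padding rows in total
print(block_layout(6, 4))  # → (2, 8, [(0, 2), (2, 4), (4, 6), (6, 6)])
```

Under these assumptions, a 6-row array on 4 ranks gives every rank a 2-row block, and the block on rank 3 contains only padding.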
Features:

- Element-wise arithmetic (`+`, `-`, `*`, `/`) between `DistArray` objects and between `DistArray` and `np.ndarray`.
- Distributed `sum`, `transpose`, `conj`, `matmul`, and `trace` operations.
- `scatter_arr` / `gather_arr` / `all_gather_arr` for moving data between a single `np.ndarray` and a `DistArray`.
- A `use_reference_implementation` flag for switching between optimized and reference (gather-to-root) implementations for debugging.
## Module-Level Configuration

### set_mpi_comm

`set_mpi_comm(comm) -> None`

Set the default MPI communicator used when no `comm` argument is passed. Must be called before any `DistArray` is created if a communicator other than `MPI.COMM_WORLD` is desired.
### get_mpi_comm

`get_mpi_comm() -> mpi4py.MPI.Intracomm`

Return the default MPI communicator. Falls back to `MPI.COMM_WORLD` if `set_mpi_comm` has not been called.
### bcast_py

`bcast_py(x, root=0, comm=None) -> Any`

Broadcast a Python object `x` from `root` to all ranks. A thin wrapper around `comm.bcast`.
## The DistArray Class

### Constructor

`DistArray(*, comm=None)`

Create an empty distributed array. The array is initialized to a single `float64` zero. After construction, set `self.x` (the local NumPy array) and `self.n` (the total size of the first dimension) before use.
| Attribute | Type | Description |
|---|---|---|
| `n` | `int` | Total size of the distributed (first) dimension |
| `x` | `np.ndarray` | Local portion of the array (padded with zeros if needed) |
| `comm` | `mpi4py.MPI.Intracomm` | MPI communicator |
### Arithmetic Operators

All standard arithmetic operators are supported between two `DistArray` objects (which must have the same communicator, total size, and number of dimensions) and between a `DistArray` and scalars or `np.ndarray` values.
| Operator | Method | Description |
|---|---|---|
| `+` | `__add__` | Element-wise addition |
| `-` | `__sub__` | Element-wise subtraction |
| `*` | `__mul__` | Element-wise multiplication |
| `/` | `__truediv__` | Element-wise division |
| `@` | `__matmul__` | Distributed matrix-vector product (delegates to `d_matmul`) |

All arithmetic operations return a new `DistArray`.
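Because both operands share the same block layout, element-wise operations need no communication: each rank combines only its own blocks. The single-process sketch below illustrates this, assuming an even split with trailing zero padding; `split_padded` and `join_trimmed` are hypothetical helpers, and a Python list of blocks stands in for the ranks' local `x` arrays.

```python
import numpy as np

def split_padded(arr, size):
    """Hypothetical helper: zero-pad axis 0 to a multiple of size, then split evenly."""
    n_local = -(-arr.shape[0] // size)  # ceil(n / size)
    pad = n_local * size - arr.shape[0]
    padded = np.concatenate([arr, np.zeros((pad,) + arr.shape[1:], dtype=arr.dtype)])
    return [padded[r * n_local:(r + 1) * n_local] for r in range(size)]

def join_trimmed(blocks, n):
    """Hypothetical helper: concatenate per-rank blocks and drop trailing padding rows."""
    return np.concatenate(blocks)[:n]

size = 4
a = np.arange(6, dtype=np.float64)
b = np.linspace(1.0, 2.0, 6)

# Each simulated rank combines only its own blocks: no data moves between ranks.
a_blocks = split_padded(a, size)
b_blocks = split_padded(b, size)
sum_blocks = [xa + xb for xa, xb in zip(a_blocks, b_blocks)]
scaled_blocks = [xa * 2.0 for xa in a_blocks]

print(join_trimmed(sum_blocks, a.shape[0]))
```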
### DistArray.sum

`sum(axis=None, *, keepdims=False) -> np.ndarray`

Compute the sum across all ranks. This is a collective operation. The result is a regular `np.ndarray`, not a `DistArray`.

When `axis` does not include the distributed dimension (axis 0), each rank sums locally and the results are gathered. When axis 0 is included (as it is for the default `axis=None`), an `Allreduce` is performed.
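The two reduction paths can be mimicked in one process by treating a list of blocks as the ranks' local arrays. This NumPy-only sketch assumes an even block split with trailing zero padding; Python's built-in `sum` over per-block partial results stands in for `Allreduce`, and concatenation stands in for the gather. Trimming the padded rows in the second case is the sketch's own choice.

```python
import numpy as np

size = 3
full = np.arange(12, dtype=np.float64).reshape(4, 3)   # global (n=4, m=3) array

# Simulated per-rank blocks: ceil(4 / 3) = 2 rows each, zero-padded to 6 rows.
n_local = -(-full.shape[0] // size)
padded = np.zeros((n_local * size, full.shape[1]))
padded[:full.shape[0]] = full
blocks = [padded[r * n_local:(r + 1) * n_local] for r in range(size)]

# axis includes 0: each rank reduces locally, then the partial results are
# summed over ranks (the role played by Allreduce). Padding rows are zero,
# so they do not change the total.
total_axis0 = sum(blk.sum(axis=0) for blk in blocks)

# axis excludes 0: each rank reduces its own rows, and the per-rank results
# are concatenated (the gather); trailing padded rows are trimmed here.
row_sums = np.concatenate([blk.sum(axis=1) for blk in blocks])[:full.shape[0]]

print(total_axis0, row_sums)
```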
### DistArray.transpose / transpose2d

`transpose(axes=None) -> DistArray`

`transpose2d() -> DistArray`

Transpose the first two dimensions of the distributed array. This is a collective operation. For a 2D `DistArray` representing a global array of shape `(n, m)`, where each rank holds a block of rows, the result represents the global `(m, n)` array, again distributed by rows: each rank now holds a different slice of the `m` rows of the transpose, each of length `n`.

`transpose` delegates to `transpose2d` for 2D arrays. The optimized path uses an `Alltoall` internally; the reference implementation, `transpose2d_ref`, gathers to root.
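One common way to build an `Alltoall`-based block transpose is sketched below in a single process. It illustrates the general technique and is not the module's code: it assumes both dimensions are zero-padded to multiples of the rank count, that rank `r` sends the `j`-th column block of its rows to rank `j`, and that each receiver transposes and concatenates what it receives. `alltoall_transpose` is a hypothetical name.

```python
import numpy as np

def alltoall_transpose(full, size):
    """Simulate a row-distributed 2D transpose via an all-to-all block exchange."""
    n, m = full.shape
    nl = -(-n // size)          # rows per rank before the transpose
    ml = -(-m // size)          # rows per rank after the transpose
    padded = np.zeros((nl * size, ml * size), dtype=full.dtype)
    padded[:n, :m] = full
    local = [padded[r * nl:(r + 1) * nl] for r in range(size)]   # rank r's rows

    # send[r][j]: the block rank r sends to rank j (its j-th column block).
    send = [[local[r][:, j * ml:(j + 1) * ml] for j in range(size)]
            for r in range(size)]

    # After the exchange, rank j holds send[r][j] for every r. Transposing each
    # (nl, ml) block and stacking along axis 1 yields rows j*ml..(j+1)*ml of the
    # transposed padded matrix; padded columns beyond n are then dropped.
    return [np.concatenate([send[r][j].T for r in range(size)], axis=1)[:, :n]
            for j in range(size)]

full = np.arange(15).reshape(5, 3)
blocks = alltoall_transpose(full, 2)   # blocks[j] has shape (ml, n) = (2, 5)
print(np.concatenate(blocks)[:3])      # equals full.T
```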
### DistArray.conj

`conj() -> DistArray`

Return a new `DistArray` with the complex conjugate of the local data.
## Distributed Linear Algebra

### d_matmul

`d_matmul(d_mat: DistArray, d_vec: DistArray) -> DistArray`

Compute the distributed matrix-vector (or matrix-matrix) product `d_mat @ d_vec`. The full `d_vec` is gathered on each rank via `all_gather_arr`, then a local `np.matmul` is performed, so each rank produces the rows of the result that correspond to its local rows of `d_mat`.
| Parameter | Type | Description |
|---|---|---|
| `d_mat` | `DistArray` | Distributed matrix (first dimension is distributed) |
| `d_vec` | `DistArray` | Distributed vector / matrix |
| Returns | `DistArray` | Result of `d_mat @ d_vec` |
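The gather-then-multiply scheme can be checked in one process. In this sketch a list of blocks stands in for the ranks' local arrays, an even split with trailing zero padding is assumed, and `blocks_of` is a hypothetical helper; the result matches a plain `mat @ vec`.

```python
import numpy as np

size = 3
mat = np.arange(20, dtype=np.float64).reshape(5, 4)   # global (5, 4) matrix
vec = np.array([1.0, -1.0, 2.0, 0.5])                 # global length-4 vector

def blocks_of(arr, size):
    """Hypothetical helper: zero-pad axis 0 to a multiple of size and split evenly."""
    n_local = -(-arr.shape[0] // size)
    padded = np.zeros((n_local * size,) + arr.shape[1:], dtype=arr.dtype)
    padded[:arr.shape[0]] = arr
    return [padded[r * n_local:(r + 1) * n_local] for r in range(size)]

mat_blocks = blocks_of(mat, size)
vec_blocks = blocks_of(vec, size)

# Step 1 (all-gather): every rank rebuilds the full vector, dropping padding.
vec_full = np.concatenate(vec_blocks)[:vec.shape[0]]

# Step 2 (local matmul): each rank multiplies only its own rows of the matrix.
result_blocks = [np.matmul(blk, vec_full) for blk in mat_blocks]

# Collecting the per-rank rows (and trimming padding) gives the global product.
result = np.concatenate(result_blocks)[:mat.shape[0]]
print(result)
```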
### d_trace

`d_trace(d_mat: DistArray) -> float | np.ndarray`

Compute the distributed trace. Each rank contributes the trace of its local diagonal block (offset by `rank * d_vec_len`, where `d_vec_len` is the number of rows each rank stores), then the partial results are reduced via `Allreduce`.
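The block-offset bookkeeping is easiest to see for a 2D matrix whose size does not divide evenly. In this single-process sketch, `simulated_d_trace` is a hypothetical function, an even split with trailing zero padding is assumed, `n_local` plays the role of `d_vec_len`, and the final `sum` stands in for the `Allreduce`.

```python
import numpy as np

def simulated_d_trace(mat, size):
    """Sum per-rank diagonal-block traces of a 2D matrix, mimicking Allreduce."""
    n = mat.shape[0]
    n_local = -(-n // size)                      # rows stored per rank
    padded = np.zeros((n_local * size,) + mat.shape[1:], dtype=mat.dtype)
    padded[:n] = mat
    partials = []
    for rank in range(size):
        rows = padded[rank * n_local:(rank + 1) * n_local]
        offset = rank * n_local                  # global column of this block's first diagonal entry
        # Diagonal entries (i, offset + i) that fall inside the real columns.
        k = max(0, min(n_local, mat.shape[1] - offset))
        partials.append(sum(rows[i, offset + i] for i in range(k)))
    return sum(partials)                         # stands in for Allreduce(SUM)

m5 = np.arange(25.0).reshape(5, 5)
print(simulated_d_trace(m5, 2), np.trace(m5))    # same value
```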
## Scatter and Gather

### scatter_arr

`scatter_arr(vec: np.ndarray, root: int = 0, comm=None) -> DistArray`

Scatter a `np.ndarray` from `root` to all ranks, returning a `DistArray`. The array is zero-padded if its first dimension is not evenly divisible by the number of ranks. Only the root rank needs to supply valid data.
### gather_arr

`gather_arr(d_vec: DistArray, root: int = 0) -> np.ndarray | None`

Gather a `DistArray` onto `root`, returning a `np.ndarray` with the padding zeros removed. Returns `None` on non-root ranks. This is a collective operation.
### all_gather_arr

`all_gather_arr(d_vec: DistArray) -> np.ndarray`

Gather a `DistArray` onto all ranks, returning a `np.ndarray` with the padding zeros removed. This is a collective operation.
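The scatter/gather round trip can be simulated in one process. This sketch assumes an even split with trailing zero padding; `sim_scatter` and `sim_gather` are hypothetical helpers whose per-rank `(x, n)` pairs mirror the `x` and `n` attributes of `DistArray`, and `sim_gather` follows the documented `gather_arr` contract of returning `None` off the root.

```python
import numpy as np

def sim_scatter(vec, size):
    """Hypothetical helper: per-rank (x, n) pairs after a zero-padded even split."""
    n = vec.shape[0]
    n_local = -(-n // size)
    padded = np.zeros((n_local * size,) + vec.shape[1:], dtype=vec.dtype)
    padded[:n] = vec
    return [(padded[r * n_local:(r + 1) * n_local], n) for r in range(size)]

def sim_gather(parts, rank, root=0):
    """Hypothetical helper: root gets the trimmed array, other ranks get None."""
    if rank != root:
        return None
    n = parts[0][1]
    return np.concatenate([x for x, _ in parts])[:n]   # drop padding rows

parts = sim_scatter(np.arange(1.0, 7.0), 4)   # 6 elements over 4 ranks
print([x for x, _ in parts])                  # last block is all padding
print(sim_gather(parts, rank=0))
```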
## Examples

### Basic Distributed Array Operations

```python
import numpy as np
import qlat as q
from qlat.mat_mpi import scatter_arr, all_gather_arr, set_mpi_comm
from mpi4py import MPI

# Candidate node layouts for running with 1, 2, or 4 MPI ranks
size_node_list = [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 1, 4]]
q.begin_with_mpi(size_node_list)
set_mpi_comm(MPI.COMM_WORLD)

# Scatter a vector to all ranks
vec = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
d_vec = scatter_arr(vec)
print(f"Rank {d_vec.comm.Get_rank()}: local shape = {d_vec.x.shape}")

# Element-wise arithmetic; each operation returns a new DistArray
d_sum = d_vec + d_vec
d_prod = d_vec * 2.0

# all_gather_arr is collective: every rank must call it
full_sum = all_gather_arr(d_sum)
full_prod = all_gather_arr(d_prod)
print(f"Gathered sum: {full_sum}")
print(f"Gathered product: {full_prod}")

q.end_with_mpi()
```
### Distributed Matrix-Vector Product

```python
import numpy as np
import qlat as q
from qlat.mat_mpi import scatter_arr, all_gather_arr, set_mpi_comm
from mpi4py import MPI

# Candidate node layouts for running with 1, 2, or 4 MPI ranks
size_node_list = [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 1, 4]]
q.begin_with_mpi(size_node_list)
set_mpi_comm(MPI.COMM_WORLD)

# Create a 4x4 matrix and a vector
mat = np.arange(16, dtype=np.float64).reshape(4, 4)
vec = np.ones(4, dtype=np.float64)
d_mat = scatter_arr(mat)
d_vec = scatter_arr(vec)

# Distributed matrix-vector product
d_result = d_mat @ d_vec
result = all_gather_arr(d_result)
print(f"mat @ vec = {result}")  # should equal mat @ vec, i.e. [6, 22, 38, 54]

q.end_with_mpi()
```
### Distributed Trace

```python
import numpy as np
import qlat as q
from qlat.mat_mpi import scatter_arr, d_trace, set_mpi_comm
from mpi4py import MPI

# Candidate node layouts for running with 1, 2, or 4 MPI ranks
size_node_list = [[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 1, 4]]
q.begin_with_mpi(size_node_list)
set_mpi_comm(MPI.COMM_WORLD)

mat = np.arange(16, dtype=np.float64).reshape(4, 4)
d_mat = scatter_arr(mat)

# d_trace is collective; every rank receives the full trace
tr = d_trace(d_mat)
print(f"Trace = {tr}")  # should be 0 + 5 + 10 + 15 = 30

q.end_with_mpi()
```