qlat.mat_mpi — Distributed NumPy Arrays over MPI

Source: qlat/qlat/mat_mpi.py

Note: Update this document when updating the source file.

Outline

  1. Overview

  2. Module-Level Configuration

  3. The DistArray Class

  4. Distributed Linear Algebra

  5. Scatter and Gather

  6. Examples


Overview

The qlat.mat_mpi module provides a lightweight distributed-array layer on top of NumPy and mpi4py. The central abstraction is DistArray, which partitions the first dimension of a NumPy array evenly across MPI ranks (with zero-padding when the size does not divide evenly).
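
To make the partitioning concrete, the following single-process sketch lists which global rows each rank would own. It is not taken from the module: the helper name block_layout is hypothetical, and the layout rule (every rank holds ceil(n / size) rows, with rows at or beyond n being zero padding) is an assumption consistent with the description above.

```python
def block_layout(n, size):
    """Return (local_len, [(start, stop, n_pad), ...]) with one entry per rank.

    Assumption: every rank holds local_len = ceil(n / size) rows, and rows
    at or beyond global index n are zero padding.
    """
    local_len = -(-n // size)  # ceiling division
    layout = []
    for rank in range(size):
        start = min(rank * local_len, n)
        stop = min(start + local_len, n)
        n_pad = local_len - (stop - start)
        layout.append((start, stop, n_pad))
    return local_len, layout


local_len, layout = block_layout(n=6, size=4)
for rank, (start, stop, n_pad) in enumerate(layout):
    print(f"rank {rank}: global rows [{start}, {stop}), {n_pad} padded")
```

Under this assumption, with 6 rows on 4 ranks the last rank holds only padding, which is why the gather functions described later remove padded zeros.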

Features:

  • Element-wise arithmetic (+, -, *, /) between DistArray objects and between DistArray and np.ndarray.

  • Distributed sum, transpose, conj, matmul, and trace operations.

  • scatter_arr / gather_arr / all_gather_arr for moving data between a single np.ndarray and a DistArray.

  • A use_reference_implementation flag for switching between optimized and reference (gather-to-root) implementations for debugging.


Module-Level Configuration

set_mpi_comm

set_mpi_comm(comm) -> None

Set the default MPI communicator used when no comm argument is passed. Must be called before any DistArray creation if a communicator other than MPI.COMM_WORLD is desired.


get_mpi_comm

get_mpi_comm() -> mpi4py.MPI.Intracomm

Return the default MPI communicator. Falls back to MPI.COMM_WORLD if set_mpi_comm has not been called.
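
The fallback behaviour can be pictured with the stand-alone sketch below. It is not the module's code: the _sketch-suffixed functions and the WORLD placeholder (standing in for MPI.COMM_WORLD) are hypothetical names chosen so that the example runs without mpi4py.

```python
WORLD = "COMM_WORLD placeholder"   # stand-in for MPI.COMM_WORLD (assumption)
_state = {"comm": None}            # no default communicator configured yet


def set_mpi_comm_sketch(comm):
    """Remember `comm` as the default communicator."""
    _state["comm"] = comm


def get_mpi_comm_sketch():
    """Return the configured communicator, or WORLD if none was set."""
    return WORLD if _state["comm"] is None else _state["comm"]


before = get_mpi_comm_sketch()     # falls back to the world placeholder
set_mpi_comm_sketch("sub-communicator placeholder")
after = get_mpi_comm_sketch()      # now returns the configured value
```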


bcast_py

bcast_py(x, root=0, comm=None) -> Any

Broadcast a Python object x from root to all ranks. A thin wrapper around comm.bcast.


The DistArray Class

Constructor

DistArray(*, comm=None)

Create an empty distributed array. The array is initialized to a single float64 zero. After construction, set self.x (local NumPy array) and self.n (total first-dimension size) before use.

Attribute   Type            Description
n           int             Total size of the distributed (first) dimension
x           np.ndarray      Local portion of the array (padded with zeros if needed)
comm        MPI.Intracomm   MPI communicator


Arithmetic Operators

The operators listed below are supported between two DistArray objects (which must share the same communicator, total size, and number of dimensions) and between a DistArray and a scalar or np.ndarray.

Operator   Method                      Description
+          __add__, __radd__           Element-wise addition
-          __sub__, __rsub__           Element-wise subtraction
*          __mul__, __rmul__           Element-wise multiplication
/          __truediv__, __rtruediv__   Element-wise division
@          __matmul__                  Distributed matrix-vector product (delegates to d_matmul)

All arithmetic operations return a new DistArray.
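
Because the data is split along the first dimension only, element-wise operators can act on each local block independently. The sketch below simulates three ranks in one process with plain NumPy. It does not use DistArray, the variable names are illustrative, and the array length is chosen to divide evenly so that no padding is involved.

```python
import numpy as np

size = 3                                # simulated number of ranks
a = np.arange(6, dtype=np.float64)
b = np.full(6, 2.0)

a_blocks = np.split(a, size)            # local block of `a` on each rank
b_blocks = np.split(b, size)            # local block of `b` on each rank

# Each simulated rank combines only its own blocks; no communication is needed.
sum_blocks = [xa + xb for xa, xb in zip(a_blocks, b_blocks)]
scaled_blocks = [xa * 0.5 for xa in a_blocks]

# Reassembling the blocks reproduces the ordinary NumPy result.
full_sum = np.concatenate(sum_blocks)
full_scaled = np.concatenate(scaled_blocks)
```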


DistArray.sum

sum(axis=None, *, keepdims=False) -> np.ndarray

Compute the sum across all ranks. Collective operation. The result is a regular np.ndarray, not a DistArray.

When axis does not include the distributed dimension (axis 0), each rank sums locally and the result is gathered. When axis 0 is included, an Allreduce is performed.
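
The two cases can be illustrated with a single-process NumPy simulation. This is a sketch of the idea rather than the module's implementation: a Python list stands in for the ranks, np.concatenate for the gather, and np.sum over the partial results for the Allreduce. The row count divides evenly, so no padding needs trimming.

```python
import numpy as np

size = 2                                   # simulated number of ranks
full = np.arange(12, dtype=np.float64).reshape(4, 3)
blocks = np.split(full, size)              # rows 0-1 on rank 0, rows 2-3 on rank 1

# axis=1 leaves the distributed axis alone: sum locally, then gather the pieces.
sum_axis1 = np.concatenate([blk.sum(axis=1) for blk in blocks])

# axis=0 crosses rank boundaries: local partial sums, then a reduction.
partial = [blk.sum(axis=0) for blk in blocks]
sum_axis0 = np.sum(partial, axis=0)        # stands in for Allreduce with SUM
```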


DistArray.transpose / transpose2d

transpose(axes=None) -> DistArray
transpose2d() -> DistArray

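Transpose the first two dimensions of the distributed array. Collective operation. For a 2D DistArray representing a global (n, m) matrix, with each rank holding a block of rows, the result represents the global (m, n) matrix, again distributed along its first dimension so that each rank holds a different row slice of the transpose.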

transpose delegates to transpose2d for 2D arrays. An Alltoall is used internally for the optimized path; transpose2d_ref gathers to root for the reference implementation.
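
One way to picture an all-to-all transpose is the single-process simulation below. It is only a sketch under stated assumptions, not the code in mat_mpi.py: both matrix dimensions divide evenly by the number of ranks, and nested Python lists play the role of the send and receive buffers.

```python
import numpy as np

size = 2                                    # simulated number of ranks
full = np.arange(24, dtype=np.float64).reshape(4, 6)   # global (n, m) = (4, 6)
row_blocks = np.split(full, size, axis=0)   # rank r holds a (2, 6) row block

# Step 1: every rank cuts its row block into `size` column chunks;
# chunk j is destined for rank j.
send = [np.split(blk, size, axis=1) for blk in row_blocks]

# Step 2 (the all-to-all exchange): rank j receives chunk j from every rank.
recv = [[send[src][dst] for src in range(size)] for dst in range(size)]

# Step 3: stack the received chunks along the row axis and transpose locally,
# which gives rank j its row slice of the transposed matrix.
t_blocks = [np.concatenate(chunks, axis=0).T for chunks in recv]
```

Concatenating t_blocks along axis 0 reproduces full.T, with each simulated rank holding a (3, 4) block.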


DistArray.conj

conj() -> DistArray

Return a new DistArray with the complex conjugate of the local data.


Distributed Linear Algebra

d_matmul

d_matmul(d_mat: DistArray, d_vec: DistArray) -> DistArray

Compute the distributed matrix-vector (or matrix-matrix) product d_mat @ d_vec. The full d_vec is gathered on each rank via all_gather_arr, then local np.matmul is performed.

Parameter   Type        Description
d_mat       DistArray   Distributed matrix (first dim is distributed)
d_vec       DistArray   Distributed vector / matrix
Returns     DistArray   Result of d_mat @ d_vec
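
The gather-then-multiply strategy described above can be sketched in a single process with NumPy. The block lists stand in for the per-rank data and np.concatenate for all_gather_arr. This illustrates the data flow under the assumption of evenly divisible sizes and is not the module's code.

```python
import numpy as np

size = 2                                   # simulated number of ranks
mat = np.arange(16, dtype=np.float64).reshape(4, 4)
vec = np.arange(4, dtype=np.float64)

mat_blocks = np.split(mat, size)           # row blocks of the matrix, one per rank
vec_blocks = np.split(vec, size)           # the vector is distributed as well

# All-gather: every rank reassembles the complete right-hand operand.
vec_full = np.concatenate(vec_blocks)

# Local product: rank r obtains the result rows matching its rows of `mat`.
result_blocks = [np.matmul(blk, vec_full) for blk in mat_blocks]
```

The result is itself distributed along the first dimension, which matches d_matmul returning a DistArray rather than a full array.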


d_trace

d_trace(d_mat: DistArray) -> float | np.ndarray

Compute the distributed trace. Each rank contributes the trace of its local diagonal block (offset by rank * d_vec_len), then results are reduced via Allreduce.
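
The offset arithmetic can be checked with a small single-process simulation. The sketch assumes that d_vec_len is the number of rows held by each rank (called local_len here) and that the matrix size divides evenly. Summing the per-rank values stands in for the Allreduce.

```python
import numpy as np

size = 2                                   # simulated number of ranks
mat = np.arange(16, dtype=np.float64).reshape(4, 4)
local_len = mat.shape[0] // size           # rows held by each rank
blocks = np.split(mat, size)

# Rank r holds global rows [r * local_len, (r + 1) * local_len), so the
# global diagonal sits at column offset r * local_len inside its block.
partial = [np.trace(blk, offset=r * local_len) for r, blk in enumerate(blocks)]

total = sum(partial)                       # stands in for Allreduce with SUM
```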


Scatter and Gather

scatter_arr

scatter_arr(vec: np.ndarray, root: int = 0, comm=None) -> DistArray

Scatter a np.ndarray from root to all ranks, returning a DistArray. The array is zero-padded if its first dimension is not evenly divisible by the number of ranks. Only the root rank needs to supply valid data.


gather_arr

gather_arr(d_vec: DistArray, root: int = 0) -> np.ndarray | None

Gather a DistArray onto root, returning a np.ndarray with padded zeros removed. Returns None on non-root ranks. Collective operation.


all_gather_arr

all_gather_arr(d_vec: DistArray) -> np.ndarray

Gather a DistArray onto all ranks, returning a np.ndarray with padded zeros removed. Collective operation.
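
The padding and trimming performed by the scatter and gather functions can be sketched as a round trip in one process. The helpers scatter_sim and gather_sim are hypothetical stand-ins written for this illustration. They assume that padding goes at the end of the first dimension and that the original length n is kept alongside the blocks, as the n attribute of DistArray suggests.

```python
import numpy as np


def scatter_sim(arr, size):
    """Zero-pad the first dimension to a multiple of `size`, then split."""
    n = arr.shape[0]
    local_len = -(-n // size)              # ceiling division
    pad = [(0, local_len * size - n)] + [(0, 0)] * (arr.ndim - 1)
    return n, np.split(np.pad(arr, pad), size)


def gather_sim(n, blocks):
    """Concatenate the per-rank blocks and drop the trailing padding."""
    return np.concatenate(blocks)[:n]


vec = np.arange(1.0, 8.0)                  # 7 elements: 1.0 ... 7.0
n, blocks = scatter_sim(vec, size=3)       # 3 simulated ranks, 3 elements each
restored = gather_sim(n, blocks)           # padding removed again
```

Here the last simulated rank holds one real element followed by two padded zeros, and the gathered array has the original seven elements.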


Examples

Basic Distributed Array Operations

import numpy as np
import qlat as q
from qlat.mat_mpi import DistArray, scatter_arr, all_gather_arr, set_mpi_comm

from mpi4py import MPI

size_node_list = [[1, 1, 1, 1]]
q.begin_with_mpi(size_node_list)

set_mpi_comm(MPI.COMM_WORLD)

# Scatter a vector to all ranks
vec = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
d_vec = scatter_arr(vec)
print(f"Rank {d_vec.comm.Get_rank()}: local shape = {d_vec.x.shape}")

# Arithmetic
d_sum = d_vec + d_vec
d_prod = d_vec * 2.0
full_sum = all_gather_arr(d_sum)
print(f"Gathered sum: {full_sum}")

q.end_with_mpi()

Distributed Matrix-Vector Product

import numpy as np
import qlat as q
from qlat.mat_mpi import DistArray, scatter_arr, all_gather_arr, set_mpi_comm

from mpi4py import MPI

size_node_list = [[1, 1, 1, 1]]
q.begin_with_mpi(size_node_list)

set_mpi_comm(MPI.COMM_WORLD)

# Create a 4x4 matrix and a vector
mat = np.arange(16, dtype=np.float64).reshape(4, 4)
vec = np.ones(4, dtype=np.float64)

d_mat = scatter_arr(mat)
d_vec = scatter_arr(vec)

# Distributed matmul
d_result = d_mat @ d_vec
result = all_gather_arr(d_result)
print(f"mat @ vec = {result}")  # should match mat @ vec

q.end_with_mpi()

Distributed Trace

import numpy as np
import qlat as q
from qlat.mat_mpi import DistArray, scatter_arr, d_trace, set_mpi_comm

from mpi4py import MPI

size_node_list = [[1, 1, 1, 1]]
q.begin_with_mpi(size_node_list)

set_mpi_comm(MPI.COMM_WORLD)

mat = np.arange(16, dtype=np.float64).reshape(4, 4)
d_mat = scatter_arr(mat)

tr = d_trace(d_mat)
print(f"Trace = {tr}")  # should be 0 + 5 + 10 + 15 = 30

q.end_with_mpi()