qlat.mpi — Low-Level MPI Utilities¶
Source: qlat/qlat/mpi.pyx
Note: Update this document when updating the source file.
Overview¶
mpi provides the low-level MPI communication primitives used throughout qlat:
communicator setup and teardown, node layout queries, global sum reductions, and
broadcasts. These are the building blocks that higher-level functions
(begin_with_mpi, end_with_mpi, bcast_py, etc.) rely on.
Most users should use the high-level wrappers q.begin_with_mpi(size_node_list)
and q.end_with_mpi() from mpi_utils.py rather than calling begin/end
directly. The glb_sum generic dispatcher is the recommended interface for
global sums.
MPI Communicator Lifecycle¶
begin¶
begin(id_node, size_node, color=0)
Initialize an MPI communicator for this node. Low-level; prefer
begin_with_mpi(size_node_list) for user code.
| Parameter | Type | Description |
|---|---|---|
| id_node | | Node index within the MPI grid |
| size_node | | Dimensions of the MPI grid |
| color | | MPI communicator color (default 0) |
end¶
end(is_preserving_cache=False)
Finalize the currently active MPI communicator. Low-level; prefer end_with_mpi()
for user code. Clears cached data unless is_preserving_cache is True.
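Because every begin must be paired with an end, one way to keep the two together is a small context manager. The following is only a sketch: `mpi_session` is a hypothetical helper, not part of qlat, and it takes the begin and end functions as arguments so that it can be tried without an MPI environment. Whether calling end after an exception is the right policy for a given job is left open here.

```python
from contextlib import contextmanager


@contextmanager
def mpi_session(begin, end, size_node_list):
    """Hypothetical helper: call begin(size_node_list), then always call end()."""
    begin(size_node_list)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        end()


# Stand-in functions that only record the call order (no MPI involved).
calls = []


def fake_begin(size_node_list):
    calls.append("begin")


def fake_end():
    calls.append("end")


with mpi_session(fake_begin, fake_end, [[1, 1, 1, 1], [1, 1, 1, 2]]):
    calls.append("work")

print(calls)
```

In user code the stand-ins would presumably be replaced by q.begin_with_mpi and q.end_with_mpi, the wrappers recommended in the overview.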
Node Layout Queries¶
get_size_node¶
get_size_node() -> Coordinate
Return the dimensions of the MPI grid as a Coordinate. For example, with 2
processes along the last dimension, get_size_node() returns
Coordinate([1, 1, 1, 2]).
get_coor_node¶
get_coor_node() -> Coordinate
Return this MPI process’s position within the grid as a Coordinate. For
example, process 0 returns Coordinate([0, 0, 0, 0]) and process 1 returns
Coordinate([0, 0, 0, 1]) in a 1×1×1×2 grid.
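The relationship between a flat node index and its grid position can be sketched in plain Python. This is an illustration only, not qlat's implementation: the helper names are hypothetical, plain lists stand in for Coordinate, and the ordering (first dimension varying fastest) is an assumption. The 1×1×1×2 example above is consistent with it but does not confirm it.

```python
def coor_from_id_node(id_node, size_node):
    """Convert a flat node index to grid coordinates.

    Assumption: the first dimension varies fastest.
    """
    coor = []
    for size in size_node:
        coor.append(id_node % size)
        id_node //= size
    return coor


def id_node_from_coor(coor, size_node):
    """Inverse of coor_from_id_node under the same ordering assumption."""
    id_node = 0
    for c, size in zip(reversed(coor), reversed(size_node)):
        id_node = id_node * size + c
    return id_node


size_node = [1, 1, 1, 2]
for id_node in range(2):
    print(id_node, coor_from_id_node(id_node, size_node))
```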
Global Sum Reductions¶
glb_sum¶
glb_sum(x) -> same_type_as_x
Compute the global MPI sum of x across all processes. The input is not
modified; the result is returned as a new value.
Dispatches by type:
| Type | Behavior |
|---|---|
| | Calls |
| | Calls |
| | Calls |
| | Calls |
| | Recursively sums each element |
| | Recursively sums each element |
| | Calls |
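The type-based dispatch with recursion into containers can be illustrated without MPI by summing a list that holds one value per simulated node. This is a sketch under stated assumptions, not qlat's code: `simulated_glb_sum` is a hypothetical name, and treating both lists and tuples as the recursive container types is an assumption of this sketch.

```python
def simulated_glb_sum(per_node_values):
    """Sum one value per simulated node, recursing into lists and tuples.

    per_node_values[i] plays the role of x on node i. Inputs are not modified.
    """
    first = per_node_values[0]
    if isinstance(first, list):
        # Sum element by element across nodes, keeping the list structure.
        return [simulated_glb_sum(list(column)) for column in zip(*per_node_values)]
    if isinstance(first, tuple):
        return tuple(simulated_glb_sum(list(column)) for column in zip(*per_node_values))
    # Scalars (int, float, complex) are summed directly.
    return sum(per_node_values)


# Two simulated nodes with id_node = 0 and 1.
ints = [12 + id_node for id_node in range(2)]
nested = [[1.0 + id_node, 2.0 + 1.0j * id_node] for id_node in range(2)]

print(simulated_glb_sum(ints))
print(simulated_glb_sum(nested))
```

Every simulated node would receive the same returned value, which mirrors the all-reduce character of a global sum.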
Type-Specific Global Sums¶
glb_sum_long(x) -> int¶
Global sum of a long integer. Modifies x in place (for C++ interop); prefer
glb_sum(x) for a non-mutating interface.
glb_sum_double(x) -> float¶
Global sum of a double. Like glb_sum_long, it modifies x in place.
glb_sum_complex(x) -> complex¶
Global sum of a complex number. Like glb_sum_long, it modifies x in place.
glb_sum_np(x) -> np.ndarray¶
Global sum of a NumPy array of any shape whose dtype is float64, int64, or
complex128. Returns a new array; x is not modified.
glb_sum_lat_data(ld) -> LatData¶
Global sum of LatData. Returns a copy with the sum applied.
Broadcast Operations¶
bcast_long, bcast_double, bcast_complex¶
bcast_long(x, root=0) -> int
bcast_double(x, root=0) -> float
bcast_complex(x, root=0) -> complex
Broadcast a scalar value from the root process to all processes. The input
x is modified in place on non-root processes and returned.
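The effect of a broadcast can be pictured with a list holding one value per simulated node. The function below is a hypothetical stand-in used only to show the semantics (every node ends up with the root's value); it does not call MPI or qlat.

```python
def simulated_bcast(per_node_values, root=0):
    """Return the per-node values after a broadcast from `root`.

    per_node_values[i] plays the role of x on node i before the call.
    """
    root_value = per_node_values[root]
    return [root_value for _ in per_node_values]


before = [42, 0]  # node 0 (root) holds 42, node 1 holds a placeholder
after = simulated_bcast(before, root=0)
print(after)
```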
bcast_lat_data¶
bcast_lat_data(ld, root=0) -> LatData
Broadcast a LatData object from the root process. Returns a copy of ld
with the broadcast applied (the original ld is not modified).
IO Shuffle Utilities¶
get_id_node_list_for_shuffle¶
get_id_node_list_for_shuffle() -> list[int]
Return the list of node indices used for distributed IO shuffling. The returned
list maps id_node_in_shuffle → id_node, i.e., the entry at index
id_node_in_shuffle is the corresponding id_node.
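When the reverse lookup (id_node → id_node_in_shuffle) is needed, the list can be inverted. The sketch below assumes the list is a permutation of range(num_node), which the text above does not state explicitly; `invert_shuffle_list` and the sample list are hypothetical and chosen only for illustration.

```python
def invert_shuffle_list(id_node_list):
    """Build the inverse mapping id_node -> id_node_in_shuffle.

    Assumption: id_node_list is a permutation of range(len(id_node_list)).
    """
    inverse = [None] * len(id_node_list)
    for id_node_in_shuffle, id_node in enumerate(id_node_list):
        inverse[id_node] = id_node_in_shuffle
    return inverse


# Hypothetical shuffle list for 4 nodes (not taken from a real run).
id_node_list = [2, 0, 3, 1]
inverse = invert_shuffle_list(id_node_list)
print(inverse)
```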
Examples¶
Global Sum of Scalars and Arrays¶
import qlat as q
import numpy as np
size_node_list = [
    [1, 1, 1, 1],
    [1, 1, 1, 2],
]
q.begin_with_mpi(size_node_list)
id_node = q.get_id_node()
# Scalar global sums
a = 12 + id_node
total = q.glb_sum(a) # e.g., 25 with 2 nodes
b = 12.4 + id_node
total_b = q.glb_sum(b) # e.g., 25.8 with 2 nodes
# NumPy array global sum
arr = np.arange(3.0) + 1.0 + id_node
total_arr = q.glb_sum(arr) # element-wise sum across nodes
# Nested list global sum
lst = [1.0 + id_node, 2.0 + 1.0j * id_node]
total_lst = q.glb_sum(lst) # recursive element-wise sum
q.end_with_mpi()
Broadcast from Root¶
import qlat as q
size_node_list = [
    [1, 1, 1, 1],
    [1, 1, 1, 2],
]
q.begin_with_mpi(size_node_list)
id_node = q.get_id_node()
# Root prepares the value; all nodes receive it
if id_node == 0:
    val = 42
else:
    val = 0
val = q.bcast_long(val)
# val == 42 on all nodes
q.end_with_mpi()
Node Layout Inspection¶
import qlat as q
size_node_list = [
    [1, 1, 1, 1],
    [1, 1, 1, 2],
]
q.begin_with_mpi(size_node_list)
print(q.get_id_node()) # 0 or 1
print(q.get_num_node()) # 2
print(q.get_coor_node()) # Coordinate([0, 0, 0, 0]) or Coordinate([0, 0, 0, 1])
print(q.get_size_node()) # Coordinate([1, 1, 1, 2])
q.end_with_mpi()