Source code for qlat.mpi_utils

import numpy as np

class q:
    from qlat_utils import (
            timer_verbose,
            get_num_node,
            get_id_node,
            Coordinate,
            get_chunk_list,
            displayln_info,
            displayln,
            )
    from .c import (
            get_size_node,
            get_coor_node,
            begin,
            end,
            )

default_size_node_list = list(map(q.Coordinate, [
    [ 1, 1, 1, 1, ],
    [ 1, 1, 1, 2, ],
    [ 1, 1, 2, 2, ],
    [ 1, 2, 2, 2, ],
    [ 1, 2, 2, 2, ],
    [ 2, 2, 2, 2, ],
    [ 2, 2, 2, 4, ],
    [ 2, 2, 4, 4, ],
    [ 2, 4, 4, 4, ],
    [ 4, 4, 4, 4, ],
    [ 4, 4, 4, 8, ],
    [ 4, 4, 8, 8, ],
    [ 4, 8, 8, 8, ],
    [ 8, 8, 8, 8, ],
    [ 8, 8, 8, 16, ],
    [ 1, 1, 1, 3, ],
    [ 1, 1, 2, 3, ],
    [ 1, 2, 2, 3, ],
    [ 2, 2, 2, 3, ],
    [ 2, 2, 2, 6, ],
    [ 2, 2, 4, 6, ],
    [ 2, 4, 4, 6, ],
    [ 4, 4, 4, 6, ],
    [ 4, 4, 4, 12, ],
    [ 4, 4, 8, 12, ],
    [ 4, 8, 8, 12, ],
    [ 8, 8, 8, 12, ],
    ]))

comm = None

def set_comm(x):
    global comm
    comm = x

def get_comm():
    return comm

def begin_with_mpi(size_node_list=None):
    global comm
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    num_node = comm.size
    id_node = comm.rank
    if size_node_list is None:
        size_node_list = []
    else:
        assert isinstance(size_node_list, list)
        size_node_list = list(map(q.Coordinate, size_node_list))
    size_node_list = size_node_list + default_size_node_list
    size_node = None
    for size_node_check in size_node_list:
        if size_node_check.volume() == num_node:
            size_node = size_node_check
            break
    if size_node is None:
        if id_node == 0:
            q.displayln(size_node_list)
        comm.barrier()
        raise Exception("begin_with_mpi: size_node_list not match num_node")
    q.begin(id_node, size_node)

def end_with_mpi(is_preserving_cache=False):
    q.end(is_preserving_cache)
    from mpi4py import MPI
    MPI.Finalize()

@q.timer_verbose
def show_machine():
    q.displayln(f"id_node: {q.get_id_node():4} / {q.get_num_node()}"
            f" ; coor_node: {str(q.get_coor_node()):9}"
            f" / {str(q.get_size_node())}")

[docs] def get_mpi_chunk(total_list, *, rng_state=None): """ rng_state has to be the same on all the nodes e.g. rng_state = q.RngState("get_mpi_chunk") """ chunk_number = q.get_num_node() chunk_id = q.get_id_node() chunk_list = q.get_chunk_list(total_list, chunk_number=chunk_number, rng_state=rng_state) if chunk_id < len(chunk_list): return chunk_list[chunk_id] else: return []