Source code for qlat.mpi_utils

class q:
    from qlat_utils import (
        timer_verbose,
        get_num_node,
        get_id_node,
        Coordinate,
        get_chunk_list,
        displayln_info,
        displayln,
    )
    from .c import (
        get_size_node,
        get_coor_node,
        begin,
        end,
    )

default_size_node_list = list(
    map(
        q.Coordinate,
        [
            [
                1,
                1,
                1,
                1,
            ],
            [
                1,
                1,
                1,
                2,
            ],
            [
                1,
                1,
                2,
                2,
            ],
            [
                1,
                2,
                2,
                2,
            ],
            [
                1,
                2,
                2,
                2,
            ],
            [
                2,
                2,
                2,
                2,
            ],
            [
                2,
                2,
                2,
                4,
            ],
            [
                2,
                2,
                4,
                4,
            ],
            [
                2,
                4,
                4,
                4,
            ],
            [
                4,
                4,
                4,
                4,
            ],
            [
                4,
                4,
                4,
                8,
            ],
            [
                4,
                4,
                8,
                8,
            ],
            [
                4,
                8,
                8,
                8,
            ],
            [
                8,
                8,
                8,
                8,
            ],
            [
                8,
                8,
                8,
                16,
            ],
            [
                1,
                1,
                1,
                3,
            ],
            [
                1,
                1,
                2,
                3,
            ],
            [
                1,
                2,
                2,
                3,
            ],
            [
                2,
                2,
                2,
                3,
            ],
            [
                2,
                2,
                2,
                6,
            ],
            [
                2,
                2,
                4,
                6,
            ],
            [
                2,
                4,
                4,
                6,
            ],
            [
                4,
                4,
                4,
                6,
            ],
            [
                4,
                4,
                4,
                12,
            ],
            [
                4,
                4,
                8,
                12,
            ],
            [
                4,
                8,
                8,
                12,
            ],
            [
                8,
                8,
                8,
                12,
            ],
        ],
    )
)

comm = None

def set_comm(x):
    global comm
    comm = x

def get_comm():
    return comm

def begin_with_mpi(size_node_list=None):
    global comm
    from mpi4py import MPI
    #
    comm = MPI.COMM_WORLD
    num_node = comm.size
    id_node = comm.rank
    if size_node_list is None:
        size_node_list = []
    else:
        assert isinstance(size_node_list, list)
        size_node_list = list(map(q.Coordinate, size_node_list))
    size_node_list = size_node_list + default_size_node_list
    size_node = None
    for size_node_check in size_node_list:
        if size_node_check.volume() == num_node:
            size_node = size_node_check
            break
    if size_node is None:
        if id_node == 0:
            q.displayln(size_node_list)
        comm.barrier()
        raise Exception("begin_with_mpi: size_node_list not match num_node")
    q.begin(id_node, size_node)

def end_with_mpi(is_preserving_cache=False):
    q.end(is_preserving_cache)
    from mpi4py import MPI
    #
    MPI.Finalize()

@q.timer_verbose
def show_machine():
    q.displayln(
        f"id_node: {q.get_id_node():4} / {q.get_num_node()}"
        f" ; coor_node: {str(q.get_coor_node()):9}"
        f" / {str(q.get_size_node())}"
    )

[docs] def get_mpi_chunk(total_list, *, rng_state=None): """ rng_state has to be the same on all the nodes e.g. rng_state = q.RngState("get_mpi_chunk") """ chunk_number = q.get_num_node() chunk_id = q.get_id_node() chunk_list = q.get_chunk_list( total_list, chunk_number=chunk_number, rng_state=rng_state ) if chunk_id < len(chunk_list): return chunk_list[chunk_id] else: return []