from .timer import *
from .cache import *
from .c import *
from .json import *
import math
import sys
import os
import numpy as np
import inspect
import importlib
import importlib.util
def getenv(*names, default=None):
    """
    Look up several environment variable names and return the first one set.

    The names are tried in the order given. The value that is returned is
    also reported through ``displayln_info``. If none of the variables is
    set, ``default`` is returned and reported under the first name with a
    ``(default)`` suffix.
    """
    assert names
    for name in names:
        val = os.environ.get(name)
        if val is None:
            # Not set: try the next candidate name.
            continue
        displayln_info(0, f"{name}='{val}'")
        return val
    val = default
    displayln_info(0, f"{names[0]}='{val}' (default)")
    return val
def get_arg(option, default=None, *, argv=None, is_removing_from_argv=False):
    """
    Return the argument following the first occurrence of ``option``.

    ``argv`` defaults to ``sys.argv``. If ``option`` is absent, ``default``
    is returned. If ``option`` is the last element (no argument follows),
    the empty string is returned.
    When ``is_removing_from_argv`` is true, the option and its argument
    (if any) are removed from ``argv`` in place.
    """
    if argv is None:
        argv = sys.argv
    try:
        pos = argv.index(option)
    except ValueError:
        return default
    has_arg = pos + 1 < len(argv)
    arg = argv[pos + 1] if has_arg else ""
    if is_removing_from_argv:
        # Drops the option and, when present, the argument right after it.
        del argv[pos:pos + 2]
    return arg
def get_option(option, *, argv=None, is_removing_from_argv=False):
    """
    Tell whether ``option`` is present in ``argv`` (default: ``sys.argv``).

    When ``is_removing_from_argv`` is true and the option is present, its
    first occurrence is removed from ``argv`` in place.
    """
    if argv is None:
        argv = sys.argv
    is_present = option in argv
    if is_present and is_removing_from_argv:
        argv.remove(option)
    return is_present
def show_memory_usage():
    """
    Report the resident set size (RSS) of the current process in GB.

    The measurement relies on the optional ``psutil`` package and is
    best-effort: if it cannot be taken, a short notice is reported instead
    of raising. Output goes through ``displayln_info``.
    """
    try:
        import psutil
        rss = psutil.Process().memory_info().rss / (1024 * 1024 * 1024)
        displayln_info(f"show_memory_usage: rss = {rss:.6f} GB")
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed here.
        displayln_info(f"show_memory_usage: no psutil.")
def import_file(module_name, file_path):
    """
    Import the Python source at ``file_path`` under the name ``module_name``.

    The new module is registered in ``sys.modules`` before its code is
    executed. Returns the imported module object.
    """
    file_spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded = importlib.util.module_from_spec(file_spec)
    # Register first so the module is visible in ``sys.modules`` while its
    # own code is being executed.
    sys.modules[module_name] = loaded
    file_spec.loader.exec_module(loaded)
    return loaded
def displayln_info_malloc_stats():
    """
    Call ``displayln_malloc_stats`` only where ``get_id_node()`` is ``0``.

    Returns whatever ``displayln_malloc_stats`` returns on that node and
    ``None`` everywhere else.
    """
    if get_id_node() != 0:
        return None
    return displayln_malloc_stats()
def lazy_call(f, *args, **kwargs):
    """
    Wrap ``f(*args, **kwargs)`` into a zero-argument callable.

    The call is performed the first time the returned callable is invoked;
    its result is stored and returned by every later invocation. If ``f``
    raises, nothing is stored and the next invocation tries again.
    """
    # Empty while the call is still pending; holds exactly the result after.
    cell = []

    def get():
        if not cell:
            cell.append(f(*args, **kwargs))
        return cell[0]

    return get
[docs]
@timer
def get_fname():
    """
    Return the code name (``co_name``) of the frame one level above this
    function's own frame, i.e. ``fname`` of the calling function.
    """
    caller_frame = inspect.currentframe().f_back
    caller_code = caller_frame.f_code
    return caller_code.co_name
def sqr(x):
    """
    Return ``x`` multiplied by itself.
    """
    squared = x * x
    return squared
def set_zero(x):
    """
    Free-function form of ``x.set_zero()``; returns ``None``.
    """
    zero_method = x.set_zero
    zero_method()
def set_unit(x, coef=1.0):
    """
    Free-function form of ``x.set_unit(coef)``; returns ``None``.

    ``coef`` defaults to ``1.0``.
    """
    unit_method = x.set_unit
    unit_method(coef)
def show(x):
    """
    Free-function form of ``x.show()``; returns its result.
    """
    shown = x.show()
    return shown
def unitarize(x):
    """
    Free-function form of ``x.unitarize()``; returns ``None``.
    """
    unitarize_method = x.unitarize
    unitarize_method()
[docs]
def get_chunk_list(total_list, *, chunk_size=None, chunk_number=None, rng_state=None):
    """
    Split ``total_list`` into consecutive chunks and return them as a list.

    Exactly one of ``chunk_size`` and ``chunk_number`` must be given (not
    ``None``); the other one is derived from the length of ``total_list``.
    #
    The number of chunks is less or equal to ``chunk_number``.
    Chunk sizes are less or equal to ``chunk_size``.
    Empty chunks are never produced.
    If ``rng_state`` is not ``None`` it must be an ``RngState`` and the list
    is passed through ``random_permute`` before splitting; otherwise the
    original order is kept.
    """
    assert chunk_size is None or chunk_number is None
    assert chunk_size is not None or chunk_number is not None
    if rng_state is not None:
        assert isinstance(rng_state, RngState)
        total_list = random_permute(total_list, rng_state)
    total = len(total_list)
    if chunk_size is not None:
        assert isinstance(chunk_size, int)
        assert chunk_size >= 1
        chunk_number = (total - 1) // chunk_size + 1
    elif chunk_number is not None:
        assert isinstance(chunk_number, int)
        assert chunk_number >= 1
        chunk_size = (total - 1) // chunk_number + 1
    # Clamp both ends to ``total`` so trailing (empty) chunks are skipped.
    starts = (min(idx * chunk_size, total) for idx in range(chunk_number))
    spans = ((start, min(start + chunk_size, total)) for start in starts)
    return [ total_list[start:stop] for start, stop in spans if stop > start ]
[docs]
def rel_mod(x, size):
    """
    Reduce ``x`` modulo ``size`` to the representative closest to zero.

    Returns ``x % size`` when that is below ``size / 2`` and
    ``x % size - size`` otherwise (so exactly half maps to ``-size / 2``).
    """
    rem = x % size
    assert rem >= 0
    return rem - size if 2 * rem >= size else rem
[docs]
def rel_mod_sym(x, size):
    """
    Symmetric variant of the closest-to-zero reduction modulo ``size``.

    Returns ``x % size`` below half of ``size``, ``x % size - size`` above
    half, and ``0`` when ``x % size`` is exactly half of ``size``.
    """
    rem = x % size
    assert rem >= 0
    doubled = 2 * rem
    if doubled < size:
        return rem
    if doubled > size:
        return rem - size
    # Exactly on the midpoint: neither sign is preferred.
    assert doubled == size
    return 0
[docs]
def rel_mod_arr(x, size):
    """
    Element-wise closest-to-zero reduction of ``x`` modulo ``size``.

    For each element returns ``x % size`` or ``x % size - size``; the
    latter is used when twice the remainder is at least ``size``.
    ``x`` and ``size`` are meant to be np.array of the same shape (the
    shape is not checked here). A new array is returned.
    """
    # assert x.shape == size.shape
    reduced = x % size
    assert np.all(reduced >= 0)
    upper_half = 2 * reduced >= size
    shifted = reduced - size
    reduced[upper_half] = shifted[upper_half]
    return reduced
[docs]
def rel_mod_sym_arr(x, size):
    """
    Element-wise symmetric closest-to-zero reduction of ``x`` modulo ``size``.

    For each element returns ``x % size`` below half of ``size``,
    ``x % size - size`` above half, and ``0`` exactly at half.
    ``x`` and ``size`` are meant to be np.array of the same shape (the
    shape is not checked here). A new array is returned.
    """
    # assert x.shape == size.shape
    reduced = x % size
    assert np.all(reduced >= 0)
    doubled = 2 * reduced
    # Both masks are taken before any element is overwritten.
    above_half = doubled > size
    at_half = doubled == size
    shifted = reduced - size
    reduced[above_half] = shifted[above_half]
    reduced[at_half] = 0
    return reduced
def c_sqr(x):
    """
    Return the sum of the squares of the components of ``x``.
    """
    return sum(map(sqr, x))
def c_rel_mod(x, size):
    """
    Apply ``rel_mod`` component by component.

    ``x`` and ``size`` must have the same length; returns a list with
    ``rel_mod(x[i], size[i])`` for every index ``i``.
    """
    ndim = len(size)
    assert ndim == len(x)
    reduced = []
    for mu in range(ndim):
        reduced.append(rel_mod(x[mu], size[mu]))
    return reduced
def c_rel_mod_sqr(x, size):
    """
    Return the sum over components of ``rel_mod(x[i], size[i])`` squared.

    ``x`` and ``size`` must have the same length.
    """
    ndim = len(size)
    assert ndim == len(x)
    total = 0
    for mu in range(ndim):
        total += sqr(rel_mod(x[mu], size[mu]))
    return total
def phat_sqr(q, size):
    """
    Return ``4 * sum_i sin(pi * (q[i] % size[i]) / size[i])**2``.

    ``q`` and ``size`` must have the same length.
    """
    ndim = len(size)
    assert ndim == len(q)
    total = 0
    for mu in range(ndim):
        phase = math.pi * (q[mu] % size[mu]) / size[mu]
        total += sqr(math.sin(phase))
    return 4 * total
def get_r_sq(x_rel):
    """
    Return the squared spatial distance: the sum of the squares of the
    first three components of ``x_rel`` (an int for int components).
    """
    spatial = x_rel[:3]
    return sum(comp * comp for comp in spatial)
def get_r_limit(total_site):
    """
    Return the limit for spatial ``r`` as float.\n
    It is the square root of the sum of ``(l / 2)**2`` over the first three
    entries ``l`` of ``total_site.to_list()``.

    :params total_site: must be Coordinate type
    """
    spatial_extents = total_site.to_list()[:3]
    half_sq = [ (extent / 2)**2 for extent in spatial_extents ]
    return math.sqrt(sum(half_sq))
def mk_r_sq_list_3d(r_sq_limit):
    """
    Return the sorted list of all values ``x**2 + y**2 + z**2`` with
    non-negative integers ``x, y, z`` that do not exceed ``r_sq_limit``
    (the limit itself is included when it is such a sum).
    """
    r_limit = int(math.sqrt(r_sq_limit))
    # Only ordered triples x >= y >= z are needed: the sum is symmetric.
    candidates = {
        x * x + y * y + z * z
        for x in range(r_limit + 1)
        for y in range(x + 1)
        for z in range(y + 1)
    }
    return sorted(r_sq for r_sq in candidates if not r_sq > r_sq_limit)
[docs]
def mk_r_sq_list(r_sq_limit, dimension='3D'):
    """
    Return the list of possible integer ``r_sq`` values for ``dimension``.

    ``'3D'``: delegates to ``mk_r_sq_list_3d(r_sq_limit)``.
    ``'4D'``: every non-negative integer is a sum of four squares, so the
    list is ``0, 1, ..., r_sq_limit - 1``. Note that here ``r_sq_limit``
    itself is excluded, while the 3D list can include it.
    Any other ``dimension`` raises ``Exception``.
    """
    if dimension == '3D':
        return mk_r_sq_list_3d(r_sq_limit)
    if dimension == '4D':
        # Lagrange's four-square theorem
        # https://en.wikipedia.org/wiki/Lagrange%27s_four-square_theorem
        return list(range(0, r_sq_limit))
    raise Exception(f"mk_r_sq_list: dimension='{dimension}' not recognized.")
[docs]
def mk_r_list(r_limit, *, r_all_limit=28.0, r_scaling_factor=5.0, dimension='3D'):
    """
    Make a list of `r` values from `0` up to `r_limit`.\n
    The list has two parts: first the square roots of all `r_sq` values
    returned by `mk_r_sq_list` up to `min(r_limit, r_all_limit)`, then
    evenly spaced values `i / r_scaling_factor` (integer `i`) continuing
    after the end of the first part up to `r_limit`.

    Parameters
    ----------
    r_limit: the limit for the generated `r` list.
    r_scaling_factor: After `r_all_limit`, include `r` values of the form
        integer divided by `r_scaling_factor`.
    r_all_limit: include all possible `r` values up to (include) this limit.
    dimension: '3D' or '4D'; forwarded to `mk_r_sq_list`.
    """
    r_sq_limit = int(min(r_limit, r_all_limit)**2 + 0.5)
    # ``dimension`` has to be forwarded, otherwise '4D' silently yields the
    # 3D list.
    r_list = [ math.sqrt(r_sq) for r_sq in mk_r_sq_list(r_sq_limit, dimension=dimension) ]
    r_second_start = r_all_limit
    if r_list:
        # Continue right after the last value that is actually present.
        r_second_start = min(r_second_start, r_list[-1])
    r_second_start_idx = int(r_second_start * r_scaling_factor + 1.5)
    r_second_stop_idx = math.ceil(r_limit * r_scaling_factor + 1.5)
    r_list.extend(i / r_scaling_factor for i in range(r_second_start_idx, r_second_stop_idx))
    return r_list
[docs]
def mk_interp_tuple(x, x0, x1, x_idx):
    """
    Linear-interpolation weights for ``x`` inside the interval ``[x0, x1]``.

    Returns `(x_idx_low, x_idx_high, coef_low, coef_high,)`\n
    `x_idx` corresponds to `x0`
    `x_idx + 1` corresponds to `x1`
    The coefficients are ``(x1 - x) / (x1 - x0)`` and
    ``(x - x0) / (x1 - x0)``.
    """
    assert x0 <= x <= x1
    width = x1 - x0
    weight_low = (x1 - x) / width
    weight_high = (x - x0) / width
    return (x_idx, x_idx + 1, weight_low, weight_high,)
[docs]
def mk_r_sq_interp_idx_coef_list(r_list):
    """
    Return a list of tuples:\n
    ``r_sq_interp_idx_coef_list = [ (r_idx_low, r_idx_high, coef_low, coef_high,), ... ]``
    where:
    ``r_sq_interp_idx_coef_list[r_sq] = (r_idx_low, r_idx_high, coef_low, coef_high,)``
    `r_sq` ranges from `0` up to (excluding) `int(r_list[-1]**2 + 1.5)`

    ``r_list`` must have at least two entries and start with ``0.0``.
    For an ``r = sqrt(r_sq)`` that is not below the last entry of
    ``r_list`` the tuple is ``(last_idx, last_idx, 1.0, 0.0,)``.
    """
    n_r = len(r_list)
    assert n_r >= 2
    assert r_list[0] == 0.0
    r_sq_stop = int(r_list[-1]**2 + 1.5)
    interp_list = []
    # ``idx`` is shared across iterations and only ever moves forward.
    idx = 0
    for r_sq in range(0, r_sq_stop):
        r = math.sqrt(r_sq)
        # Advance until ``r`` lies in [r_list[idx], r_list[idx + 1]) or the
        # end of ``r_list`` is reached.
        while idx + 1 < n_r and not (r_list[idx] <= r and r < r_list[idx + 1]):
            idx += 1
        if idx + 1 >= n_r:
            interp_list.append((idx, idx, 1.0, 0.0,))
        else:
            interp_list.append(mk_interp_tuple(r, r_list[idx], r_list[idx + 1], idx))
    return interp_list
@timer
def check_log_json(script_file, json_results, *, check_eps=1e-5):
    """
    Compare ``json_results`` with the reference stored next to ``script_file``.

    json_results = [ (name, value, check_eps,), (name, value,), (name,), ... ]

    Only acts where ``get_id_node()`` is ``0``. The reference file name is
    ``script_file`` with its extension replaced by ``.log.json``. The
    current results are always passed to ``qtouch`` under that name plus
    ``.new``. If the reference exists (``does_file_exist_qar``), it is
    loaded and compared entry by entry; every mismatch is reported through
    ``displayln`` with a ``CHECK:`` prefix. Nothing is returned or raised
    by the comparison itself.

    Entries without their own eps use the keyword ``check_eps``. Values are
    compared by the symmetric relative difference
    ``2 * |v - vl| / (|v| + |vl|)``.
    """
    fname = get_fname()
    if 0 == get_id_node():
        json_fn_name = os.path.splitext(script_file)[0] + ".log.json"
        qtouch(json_fn_name + ".new", json_dumps(json_results, indent=1))
        if does_file_exist_qar(json_fn_name):
            json_results_load = json_loads(qcat(json_fn_name))
            # ``zip`` stops at the shorter list; a length mismatch between
            # the two lists is reported separately after the loop.
            for i, (p, pl,) in enumerate(zip(json_results, json_results_load)):
                if len(p) != len(pl):
                    displayln(-1, f"CHECK: {i} {p} load:{pl}")
                    displayln(-1, f"CHECK: ERROR: JSON results length does not match.")
                    continue
                if len(p) == 1:
                    # Name-only entry: values are set to 0, so only the
                    # name is effectively compared.
                    eps = check_eps
                    epsl = check_eps
                    v = 0
                    vl = 0
                    n, = p
                    nl, = pl
                elif len(p) == 2:
                    eps = check_eps
                    epsl = check_eps
                    n, v = p
                    nl, vl = pl
                elif len(p) == 3:
                    n, v, eps = p
                    nl, vl, epsl = pl
                else:
                    # NOTE(review): the message below says "not 2 or 3"
                    # although entries of length 1 are accepted above.
                    displayln(-1, f"CHECK: {i} {p} load:{pl}")
                    displayln(-1, f"CHECK: ERROR: JSON results length not 2 or 3.")
                    continue
                if n != nl:
                    displayln(-1, f"CHECK: {i} {p} load:{pl}")
                    displayln(-1, f"CHECK: ERROR: JSON results item does not match.")
                    continue
                if eps != epsl:
                    displayln(-1, f"CHECK: {i} {p} load:{pl}")
                    displayln(-1, f"CHECK: ERROR: JSON results eps does not match.")
                    continue
                # Relative difference; left at 0.0 when both values are
                # zero to avoid dividing by zero.
                actual_eps = 0.0
                if (abs(v) + abs(vl)) > 0:
                    actual_eps = 2 * abs(v - vl) / (abs(v) + abs(vl))
                if actual_eps > eps:
                    displayln(-1, f"CHECK: {i} '{n}' actual: {v} ; load: {vl} .")
                    displayln(-1, f"CHECK: target eps: {eps} ; actual eps: {actual_eps} .")
                    displayln(-1, f"CHECK: ERROR: JSON results value does not match.")
                elif actual_eps != 0.0:
                    # Within tolerance but not identical: informational only.
                    displayln(-1, f"INFO: {fname}: {i} '{n}'")
                    displayln(-1, f"INFO: {fname}: target eps: {eps} ; actual eps: {actual_eps} .")
            if len(json_results) != len(json_results_load):
                displayln(-1, f"CHECK: len(json_results)={len(json_results)} load:{len(json_results_load)}")
                displayln(-1, f"CHECK: ERROR: JSON results len does not match.")