"""
Aprun is far too complex to handle purely through XML. We need python
code to compute and assemble aprun commands.
"""
from CIME.XML.standard_module_setup import *

import math

# logging (used below) is provided by standard_module_setup's star import
logger = logging.getLogger(__name__)
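
# For orientation: an aprun command assembled by this module looks like
#
#   aprun  -S 4 -n 680 -N 8 -d 2  e3sm.exe : -S 2 -n 128 -N 4 -d 4  e3sm.exe
#
# where -n is the segment's MPI task count, -N tasks per node, -d threads per
# task (depth), -S tasks per NUMA domain, and ":" separates the segments of a
# heterogeneous (MPMD) launch. The example values are taken from the doctests
# below.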
###############################################################################
def _get_aprun_cmd_for_case_impl(
    ntasks,
    nthreads,
    rootpes,
    pstrids,
    max_tasks_per_node,
    max_mpitasks_per_node,
    pio_numtasks,
    pio_async_interface,
    compiler,
    machine,
    run_exe,
    extra_args,
):
    ###############################################################################
    """
    No one really understands this code, but we can at least test it.

    >>> ntasks = [512, 675, 168, 512, 128, 168, 168, 512, 1]
    >>> nthreads = [2, 2, 2, 2, 4, 2, 2, 2, 1]
    >>> rootpes = [0, 0, 512, 0, 680, 512, 512, 0, 0]
    >>> pstrids = [1, 1, 1, 1, 1, 1, 1, 1, 1]
    >>> max_tasks_per_node = 16
    >>> max_mpitasks_per_node = 16
    >>> pio_numtasks = -1
    >>> pio_async_interface = False
    >>> compiler = "pgi"
    >>> machine = "titan"
    >>> run_exe = "e3sm.exe"
    >>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
    ('  -S 4 -n 680 -N 8 -d 2  e3sm.exe : -S 2 -n 128 -N 4 -d 4  e3sm.exe ', 117, 808, 4, 4)

    >>> compiler = "intel"
    >>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
    ('  -S 4 -cc numa_node -n 680 -N 8 -d 2  e3sm.exe : -S 2 -cc numa_node -n 128 -N 4 -d 4  e3sm.exe ', 117, 808, 4, 4)

    >>> ntasks = [64, 64, 64, 64, 64, 64, 64, 64, 1]
    >>> nthreads = [1, 1, 1, 1, 1, 1, 1, 1, 1]
    >>> rootpes = [0, 0, 0, 0, 0, 0, 0, 0, 0]
    >>> pstrids = [1, 1, 1, 1, 1, 1, 1, 1, 1]
    >>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, None)
    ('  -S 8 -cc numa_node -n 64 -N 16 -d 1  e3sm.exe ', 4, 64, 16, 1)
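
    extra_args injects additional aprun flags: entries whose "position" is
    "global" are placed once at the front, while "per" entries are repeated in
    every segment. A small sketch (the flags below are illustrative only, not
    validated against aprun):

    >>> extra_args = {"-q": {"position": "global"}, "-j 1": {"position": "per"}}
    >>> _get_aprun_cmd_for_case_impl(ntasks, nthreads, rootpes, pstrids, max_tasks_per_node, max_mpitasks_per_node, pio_numtasks, pio_async_interface, compiler, machine, run_exe, extra_args)
    (' -q -S 8 -cc numa_node -n 64 -N 16 -d 1 -j 1 e3sm.exe ', 4, 64, 16, 1)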
"""
    if extra_args is None:
        extra_args = {}

    max_tasks_per_node = 1 if max_tasks_per_node < 1 else max_tasks_per_node

    total_tasks = 0
    for ntask, rootpe, pstrid in zip(ntasks, rootpes, pstrids):
        tt = rootpe + (ntask - 1) * pstrid + 1
        total_tasks = max(tt, total_tasks)
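
    # Example from the doctests above: a component with rootpe=512, ntask=168,
    # pstrid=1 occupies global ranks 512..679 and contributes
    # 512 + (168 - 1) * 1 + 1 = 680 to the running max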

    # Check if we need to add pio's tasks to the total task count
    if pio_async_interface:
        total_tasks += pio_numtasks if pio_numtasks > 0 else max_mpitasks_per_node

    # Compute max threads for each mpi task
    maxt = [0] * total_tasks
    for ntask, nthrd, rootpe, pstrid in zip(ntasks, nthreads, rootpes, pstrids):
        c2 = 0
        while c2 < ntask:
            s = rootpe + c2 * pstrid
            if nthrd > maxt[s]:
                maxt[s] = nthrd
            c2 += 1
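
    # With the first doctest's layout, ranks 0..679 end up with maxt == 2 and
    # ranks 680..807 with maxt == 4; that change in thread count is what splits
    # the command into two ":"-separated segments below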

    # Make sure all maxt values are at least 1
    for c1 in range(0, total_tasks):
        if maxt[c1] < 1:
            maxt[c1] = 1

    global_flags = " ".join(
        [x for x, y in extra_args.items() if y["position"] == "global"]
    )
    per_flags = " ".join([x for x, y in extra_args.items() if y["position"] == "per"])

    # Compute task and thread settings for batch commands
    (
        tasks_per_node,
        min_tasks_per_node,
        task_count,
        thread_count,
        max_thread_count,
        total_node_count,
        total_task_count,
        aprun_args,
    ) = (0, max_mpitasks_per_node, 1, maxt[0], maxt[0], 0, 0, f" {global_flags}")
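
    # The scan below starts at rank 1: rank 0 is consumed up front, which is
    # why task_count starts at 1 and thread_count/max_thread_count start at
    # maxt[0]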
    c1list = list(range(1, total_tasks))
    c1list.append(None)
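
    # Walk the per-rank thread counts and flush one aprun segment (joined by
    # ":") whenever the thread count changes; the trailing None forces the
    # final segment to be emitted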
    for c1 in c1list:
        if c1 is None or maxt[c1] != thread_count:
            tasks_per_node = min(
                max_mpitasks_per_node, int(max_tasks_per_node / thread_count)
            )
            tasks_per_node = min(task_count, tasks_per_node)

            # -S takes tasks per NUMA domain; the ceil-divide by two assumes
            # two NUMA domains per compute node
            task_per_numa = int(math.ceil(tasks_per_node / 2.0))

            # Options specific to Titan
            if machine == "titan" and tasks_per_node > 1:
                aprun_args += " -S {:d}".format(task_per_numa)
                if compiler == "intel":
                    aprun_args += " -cc numa_node"

            aprun_args += " -n {:d} -N {:d} -d {:d} {} {} {}".format(
                task_count,
                tasks_per_node,
                thread_count,
                per_flags,
                run_exe,
                "" if c1 is None else ":",
            )

            node_count = int(math.ceil(float(task_count) / tasks_per_node))
            total_node_count += node_count
            total_task_count += task_count
            if tasks_per_node < min_tasks_per_node:
                min_tasks_per_node = tasks_per_node

            if c1 is not None:
                thread_count = maxt[c1]
                max_thread_count = max(max_thread_count, maxt[c1])
                task_count = 1

        else:
            task_count += 1

    return (
        aprun_args,
        total_node_count,
        total_task_count,
        min_tasks_per_node,
        max_thread_count,
    )

###############################################################################
def get_aprun_cmd_for_case(case, run_exe, overrides=None, extra_args=None):
    ###############################################################################
"""
Given a case, construct and return the aprun command and optimized node count
"""
    models = case.get_values("COMP_CLASSES")
    ntasks, nthreads, rootpes, pstrids = [], [], [], []
    for model in models:
        model = "CPL" if model == "DRV" else model
        for the_list, item_name in zip(
            [ntasks, nthreads, rootpes, pstrids],
            ["NTASKS", "NTHRDS", "ROOTPE", "PSTRID"],
        ):
            the_list.append(case.get_value("_".join([item_name, model])))

    max_tasks_per_node = case.get_value("MAX_TASKS_PER_NODE")

    if overrides:
        # Override values may arrive as strings; coerce to int (None passes through)
        overrides = {
            x: y if isinstance(y, int) or y is None else int(y)
            for x, y in overrides.items()
        }

        if "max_tasks_per_node" in overrides:
            max_tasks_per_node = overrides["max_tasks_per_node"]

        if "total_tasks" in overrides:
            ntasks = [overrides["total_tasks"] if x > 1 else x for x in ntasks]

        if "thread_count" in overrides:
            nthreads = [overrides["thread_count"] if x > 1 else x for x in nthreads]
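
        # For example, overrides={"thread_count": "4"} coerces "4" to int and
        # forces every threaded component (original value > 1) to 4 threads,
        # leaving single-threaded components untouched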

    return _get_aprun_cmd_for_case_impl(
        ntasks,
        nthreads,
        rootpes,
        pstrids,
        max_tasks_per_node,
        case.get_value("MAX_MPITASKS_PER_NODE"),
        case.get_value("PIO_NUMTASKS"),
        case.get_value("PIO_ASYNC_INTERFACE"),
        case.get_value("COMPILER"),
        case.get_value("MACH"),
        run_exe,
        extra_args,
    )
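
# A minimal usage sketch (assumes an initialized CIME Case; the caseroot path
# and override value here are hypothetical):
#
#     from CIME.case import Case
#     with Case("/path/to/caseroot", read_only=True) as case:
#         cmd, nodes, tasks, min_tpn, max_threads = get_aprun_cmd_for_case(
#             case, "e3sm.exe", overrides={"max_tasks_per_node": 8}
#         )
#         full_cmd = "aprun" + cmd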