"""
Interface to the env_mach_pes.xml file. This class inherits from EntryID
"""
from CIME.XML.standard_module_setup import *
from CIME.XML.env_base import EnvBase
import math
logger = logging.getLogger(__name__)
class EnvMachPes(EnvBase):
    """
    Accessor for the entries of env_mach_pes.xml (NTASKS, NTHRDS, ROOTPE,
    PSTRID, NINST, ...) with some derived/validated values layered on top of
    the generic EnvBase get/set behavior.
    """

    def __init__(self, case_root=None, infile="env_mach_pes.xml", components=None, read_only=False):
        """
        initialize an object interface to file env_mach_pes.xml in the case directory

        case_root  -- case directory, passed through to EnvBase
        infile     -- name of the XML file to read
        components -- iterable of component class names (e.g. "CPL") used to
                      compute NINST_MAX and to validate MULTI_DRIVER; may be None,
                      in which case no per-component values are consulted
        read_only  -- passed through to EnvBase
        """
        self._components = components
        schema = os.path.join(get_cime_root(), "config", "xml_schemas", "env_mach_pes.xsd")
        EnvBase.__init__(self, case_root, infile, schema=schema, read_only=read_only)

    def get_value(self, vid, attribute=None, resolved=True, subgroup=None, max_mpitasks_per_node=None): # pylint: disable=arguments-differ
        """
        Return the value of entry ``vid`` with env_mach_pes-specific handling:

        * ``NINST_MAX`` is computed, not read: the largest NINST_<comp> over all
          non-CPL components, and at least 1. It is used to determine the
          number of drivers in multi-driver mode.
        * For NTASKS*/ROOTPE* entries a negative value is converted by
          multiplying its magnitude by MAX_MPITASKS_PER_NODE (or by the
          ``max_mpitasks_per_node`` argument when one is supplied).
        * A missing NINST_<comp> falls back to NINST, because the nuopc driver
          keeps only a single NINST value.
        """
        if vid == "NINST_MAX":
            # in the nuopc driver there is only a single NINST value
            ninst_max = 1
            # self._components may be None (the constructor default).
            for comp in self._components or ():
                if comp != "CPL":
                    ninst_max = max(ninst_max, self.get_value("NINST_{}".format(comp)))
            return ninst_max

        value = EnvBase.get_value(self, vid, attribute, resolved, subgroup)

        if ("NTASKS" in vid or "ROOTPE" in vid) and value is not None and value < 0:
            # The magnitude of a negative value is scaled by the per-node MPI
            # task count; only look that count up when it is actually needed.
            if max_mpitasks_per_node is None:
                max_mpitasks_per_node = self.get_value("MAX_MPITASKS_PER_NODE")
            value = -1 * value * max_mpitasks_per_node

        # in the nuopc driver there is only one NINST value
        # so that NINST_{comp} = NINST
        if "NINST_" in vid and value is None:
            value = self.get_value("NINST")
        return value

    def set_value(self, vid, value, subgroup=None, ignore_type=False):
        """
        Set the value of an entry-id field to value
        Returns the value or None if not found
        subgroup is ignored in the general routine and applied in specific methods

        Before delegating to EnvBase.set_value this validates that:
        * when MULTI_DRIVER is set to a truthy value, every non-CPL component
          has NINST_<comp> equal to NINST_MAX;
        * NTASKS*/NTHRDS* entries are not set to 0.
        Violations are reported through ``expect``.
        """
        if vid == "MULTI_DRIVER" and value:
            ninst_max = self.get_value("NINST_MAX")
            # self._components may be None (the constructor default).
            for comp in self._components or ():
                if comp == "CPL":
                    continue
                ninst = self.get_value("NINST_{}".format(comp))
                expect(ninst == ninst_max,
                       "All components must have the same NINST value in multi_driver mode. NINST_{}={} should be {}".format(comp, ninst, ninst_max))
        if "NTASKS" in vid or "NTHRDS" in vid:
            expect(value != 0, "Cannot set NTASKS or NTHRDS to 0")
        return EnvBase.set_value(self, vid, value, subgroup=subgroup, ignore_type=ignore_type)

    def get_max_thread_count(self, comp_classes):
        """
        Find the maximum number of openmp threads for any component in the case.

        comp_classes -- iterable of component class names; NTHRDS is read for
                        each with attribute {"compclass": comp}.
        Returns at least 1. Fails (via expect) if a class has no NTHRDS value.
        """
        max_threads = 1
        for comp in comp_classes:
            threads = self.get_value("NTHRDS", attribute={"compclass": comp})
            expect(threads is not None, "Error no thread count found for component class {}".format(comp))
            max_threads = max(max_threads, threads)
        return max_threads

    def get_total_tasks(self, comp_classes):
        """
        Return the total number of tasks spanned by the layout.

        For each component class the last task used is
        ROOTPE + (NTASKS - 1) * PSTRID; the total is one past the largest of
        these (0 if comp_classes is empty). When MULTI_DRIVER is set, the total
        is multiplied by the instance count.

        If a global NINST is set (nuopc-style, a single NINST value) it is used
        as the instance count; otherwise the per-component NINST values of the
        non-CPL classes are consulted.
        """
        total_tasks = 0
        maxinst = self.get_value("NINST")
        if maxinst:
            comp_interface = "nuopc"
        else:
            comp_interface = 'unknown'
            maxinst = 1
        for comp in comp_classes:
            ntasks = self.get_value("NTASKS", attribute={"compclass": comp})
            rootpe = self.get_value("ROOTPE", attribute={"compclass": comp})
            pstrid = self.get_value("PSTRID", attribute={"compclass": comp})
            if comp != "CPL" and comp_interface != "nuopc":
                ninst = self.get_value("NINST", attribute={"compclass": comp})
                maxinst = max(maxinst, ninst)
            # Index of the last task used by this component, plus one.
            tt = rootpe + (ntasks - 1) * pstrid + 1
            total_tasks = max(tt, total_tasks)
        if self.get_value("MULTI_DRIVER"):
            total_tasks *= maxinst
        return total_tasks

    def get_tasks_per_node(self, total_tasks, max_thread_count):
        """
        Return the number of tasks to place on each node.

        This is the smallest of MAX_TASKS_PER_NODE // max_thread_count,
        MAX_MPITASKS_PER_NODE and total_tasks, but never less than 1.
        Fails (via expect) unless total_tasks > 0.
        """
        expect(total_tasks > 0, "totaltasks > 0 expected, totaltasks = {}".format(total_tasks))
        tasks_per_node = min(self.get_value("MAX_TASKS_PER_NODE") // max_thread_count,
                             self.get_value("MAX_MPITASKS_PER_NODE"), total_tasks)
        return tasks_per_node if tasks_per_node > 0 else 1

    def get_total_nodes(self, total_tasks, max_thread_count):
        """
        Return (num_active_nodes, num_spare_nodes)

        num_active_nodes is total_tasks divided by get_tasks_per_node(),
        rounded up; num_spare_nodes comes from get_spare_nodes().
        """
        tasks_per_node = self.get_tasks_per_node(total_tasks, max_thread_count)
        num_nodes = int(math.ceil(float(total_tasks) / tasks_per_node))
        return num_nodes, self.get_spare_nodes(num_nodes)

    def get_spare_nodes(self, num_nodes):
        """
        Return the number of spare nodes to request in addition to num_nodes.

        If FORCE_SPARE_NODES is not the sentinel -999, that value is returned
        as-is. Otherwise, when ALLOCATE_SPARE_NODES is set, 10% of num_nodes
        (rounded up) clamped to the range [1, 10] is returned; else 0.
        """
        force_spare_nodes = self.get_value("FORCE_SPARE_NODES")
        if force_spare_nodes != -999:
            return force_spare_nodes
        if not self.get_value("ALLOCATE_SPARE_NODES"):
            return 0
        ten_pct = int(math.ceil(float(num_nodes) * 0.1))
        # Always provide at least one spare node, and never more than 10.
        return min(max(ten_pct, 1), 10)