Source code for CIME.case.check_input_data

"""
API for checking input for testcase
"""
from CIME.XML.standard_module_setup import *
from CIME.utils import SharedArea, find_files, safe_copy, expect
from CIME.XML.inputdata import Inputdata
import CIME.Servers

import glob, hashlib, shutil

logger = logging.getLogger(__name__)
# The inputdata_checksum.dat file will be read into this hash if it's available
chksum_hash = dict()
local_chksum_file = "inputdata_checksum.dat"


def _download_checksum_file(rundir):
    """
    Download the checksum files from each server and merge them into rundir.
    """
    inputdata = Inputdata()
    protocol = "svn"
    chksum_found = False
    # download and merge all available chksum files.
    while protocol is not None:
        protocol, address, user, passwd, chksum_file, _, _ = inputdata.get_next_server()
        if protocol not in vars(CIME.Servers):
            logger.info("Client protocol {} not enabled".format(protocol))
            continue
        logger.info(
            "Using protocol {} with user {} and passwd {}".format(
                protocol, user, passwd
            )
        )
        if protocol == "svn":
            server = CIME.Servers.SVN(address, user, passwd)
        elif protocol == "gftp":
            server = CIME.Servers.GridFTP(address, user, passwd)
        elif protocol == "ftp":
            server = CIME.Servers.FTP.ftp_login(address, user, passwd)
        elif protocol == "wget":
            server = CIME.Servers.WGET.wget_login(address, user, passwd)
        else:
            expect(False, "Unsupported inputdata protocol: {}".format(protocol))
        if not server:
            continue

        if chksum_file:
            chksum_found = True
        else:
            continue

        success = False
        rel_path = chksum_file
        full_path = os.path.join(rundir, local_chksum_file)
        new_file = full_path + ".raw"
        protocol = type(server).__name__
        logger.info(
            "Trying to download file: '{}' to path '{}' using {} protocol.".format(
                rel_path, new_file, protocol
            )
        )
        tmpfile = None
        if os.path.isfile(full_path):
            tmpfile = full_path + ".tmp"
            os.rename(full_path, tmpfile)
        # Use umask to make sure files are group read/writable. As long as parent directories
        # have +s, then everything should work.
        success = server.getfile(rel_path, new_file)
        if success:
            _reformat_chksum_file(full_path, new_file)
            if tmpfile:
                _merge_chksum_files(full_path, tmpfile)
            chksum_hash.clear()
        else:
            if tmpfile and os.path.isfile(tmpfile):
                os.rename(tmpfile, full_path)
                logger.warning(
                    "Could not automatically download file "
                    + full_path
                    + " Restoring existing version."
                )
            else:
                logger.warning(
                    "Could not automatically download file {}".format(full_path)
                )
    return chksum_found


def _reformat_chksum_file(chksum_file, server_file):
    """
    The checksum file on the server has 8 space seperated columns, I need the first and last ones.
    This function gets the first and last column of server_file and saves it to chksum_file
    """
    with open(server_file) as fd, open(chksum_file, "w") as fout:
        lines = fd.readlines()
        for line in lines:
            lsplit = line.split()
            if len(lsplit) < 8 or " DIR " in line:
                continue

            # remove the first directory ('inputdata/') from the filename
            chksum = lsplit[0]
            fname = (lsplit[7]).split("/", 1)[1]
            fout.write(" ".join((chksum, fname)) + "\n")
    os.remove(server_file)


def _merge_chksum_files(new_file, old_file):
    """
    If more than one server checksum file is available, this merges the files and removes
    any duplicate lines
    """
    with open(old_file) as fin:
        lines = fin.readlines()
    with open(new_file) as fin:
        lines += fin.readlines()
    lines = set(lines)
    with open(new_file, "w") as fout:
        fout.write("".join(lines))
    os.remove(old_file)


def _download_if_in_repo(
    server, input_data_root, rel_path, isdirectory=False, ic_filepath=None
):
    """
    Return True if successfully downloaded
    server is an object handle of type CIME.Servers
    input_data_root is the local path to inputdata (DIN_LOC_ROOT)
    rel_path is the path to the file or directory relative to input_data_root
    user is the user name of the person running the script
    isdirectory indicates that this is a directory download rather than a single file
    """
    if not (rel_path or server.fileexists(rel_path)):
        return False
    full_path = os.path.join(input_data_root, rel_path)
    if ic_filepath:
        full_path = full_path.replace(ic_filepath, "/")
    logger.info(
        "Trying to download file: '{}' to path '{}' using {} protocol.".format(
            rel_path, full_path, type(server).__name__
        )
    )
    # Make sure local path exists, create if it does not
    if isdirectory or full_path.endswith(os.sep):
        if not os.path.exists(full_path):
            logger.info("Creating directory {}".format(full_path))
            os.makedirs(full_path + ".tmp")
        isdirectory = True
    elif not os.path.exists(os.path.dirname(full_path)):
        os.makedirs(os.path.dirname(full_path))

    # Use umask to make sure files are group read/writable. As long as parent directories
    # have +s, then everything should work.
    if isdirectory:
        success = server.getdirectory(rel_path, full_path + ".tmp")
        # this is intended to prevent a race condition in which
        # one case attempts to use a refdir before another one has
        # completed the download
        if success:
            os.rename(full_path + ".tmp", full_path)
        else:
            shutil.rmtree(full_path + ".tmp")
    else:
        success = server.getfile(rel_path, full_path)

    return success


def _check_all_input_data_impl(
    self,
    protocol,
    address,
    input_data_root,
    data_list_dir,
    download,
    chksum,
):
    success = False
    if protocol is not None and address is not None:
        success = self.check_input_data(
            protocol=protocol,
            address=address,
            download=download,
            input_data_root=input_data_root,
            data_list_dir=data_list_dir,
            chksum=chksum,
        )
    else:
        if chksum:
            chksum_found = _download_checksum_file(self.get_value("RUNDIR"))

        clm_usrdat_name = self.get_value("CLM_USRDAT_NAME")
        if clm_usrdat_name and clm_usrdat_name == "UNSET":
            clm_usrdat_name = None

        if download and clm_usrdat_name:
            success = _downloadfromserver(
                self,
                input_data_root,
                data_list_dir,
                attributes={"CLM_USRDAT_NAME": clm_usrdat_name},
            )
        if not success:
            success = self.check_input_data(
                protocol=protocol,
                address=address,
                download=False,
                input_data_root=input_data_root,
                data_list_dir=data_list_dir,
                chksum=chksum and chksum_found,
            )
        if download and not success:
            if chksum:
                chksum_found = _download_checksum_file(self.get_value("RUNDIR"))
            success = _downloadfromserver(self, input_data_root, data_list_dir)

    expect(
        not download or (download and success),
        "Could not find all inputdata on any server",
    )
    self.stage_refcase(input_data_root=input_data_root, data_list_dir=data_list_dir)
    return success


[docs] def check_all_input_data( self, protocol=None, address=None, input_data_root=None, data_list_dir="Buildconf", download=True, chksum=False, ): """ Read through all files of the form *.input_data_list in the data_list_dir directory. These files contain a list of input and boundary files needed by each model component. For each file in the list confirm that it is available in input_data_root and if not (optionally download it from a server at address using protocol. Perform a chksum of the downloaded file. """ # Run the entire impl in a SharedArea to help avoid permission problems with SharedArea(): return _check_all_input_data_impl( self, protocol, address, input_data_root, data_list_dir, download, chksum )
def _downloadfromserver(case, input_data_root, data_list_dir, attributes=None): """ Download files """ success = False protocol = "svn" inputdata = Inputdata() if not input_data_root: input_data_root = case.get_value("DIN_LOC_ROOT") while not success and protocol is not None: protocol, address, user, passwd, _, ic_filepath, _ = inputdata.get_next_server( attributes=attributes ) logger.info("Checking server {} with protocol {}".format(address, protocol)) success = case.check_input_data( protocol=protocol, address=address, download=True, input_data_root=input_data_root, data_list_dir=data_list_dir, user=user, passwd=passwd, ic_filepath=ic_filepath, ) return success
[docs] def stage_refcase(self, input_data_root=None, data_list_dir=None): """ Get a REFCASE for a hybrid or branch run This is the only case in which we are downloading an entire directory instead of a single file at a time. """ get_refcase = self.get_value("GET_REFCASE") run_type = self.get_value("RUN_TYPE") continue_run = self.get_value("CONTINUE_RUN") # We do not fully populate the inputdata directory on every # machine and do not expect every user to download the 3TB+ of # data in our inputdata repository. This code checks for the # existence of inputdata in the local inputdata directory and # attempts to download data from the server if it's needed and # missing. if get_refcase and run_type != "startup" and not continue_run: din_loc_root = self.get_value("DIN_LOC_ROOT") run_refdate = self.get_value("RUN_REFDATE") run_refcase = self.get_value("RUN_REFCASE") run_refdir = self.get_value("RUN_REFDIR") rundir = self.get_value("RUNDIR") if os.path.isabs(run_refdir): refdir = run_refdir expect( os.path.isdir(refdir), "Reference case directory {} does not exist or is not readable".format( refdir ), ) else: refdir = os.path.join(din_loc_root, run_refdir, run_refcase, run_refdate) if not os.path.isdir(refdir): logger.warning( "Refcase not found in {}, will attempt to download from inputdata".format( refdir ) ) with open( os.path.join("Buildconf", "refcase.input_data_list"), "w" ) as fd: fd.write("refdir = {}{}".format(refdir, os.sep)) if input_data_root is None: input_data_root = din_loc_root if data_list_dir is None: data_list_dir = "Buildconf" success = _downloadfromserver( self, input_data_root=input_data_root, data_list_dir=data_list_dir ) expect(success, "Could not download refcase from any server") logger.info(" - Prestaging REFCASE ({}) to {}".format(refdir, rundir)) # prestage the reference case's files. if not os.path.exists(rundir): logger.debug("Creating run directory: {}".format(rundir)) os.makedirs(rundir) rpointerfile = None # copy the refcases' rpointer files to the run directory for rpointerfile in glob.iglob(os.path.join("{}", "*rpointer*").format(refdir)): logger.info("Copy rpointer {}".format(rpointerfile)) safe_copy(rpointerfile, rundir) os.chmod(os.path.join(rundir, os.path.basename(rpointerfile)), 0o644) expect( rpointerfile, "Reference case directory {} does not contain any rpointer files".format( refdir ), ) # link everything else for rcfile in glob.iglob(os.path.join(refdir, "*")): rcbaseline = os.path.basename(rcfile) if not os.path.exists("{}/{}".format(rundir, rcbaseline)): logger.info("Staging file {}".format(rcfile)) os.symlink(rcfile, "{}/{}".format(rundir, rcbaseline)) # Backward compatibility, some old refcases have cam2 in the name # link to local cam file. for cam2file in glob.iglob(os.path.join("{}", "*.cam2.*").format(rundir)): camfile = cam2file.replace("cam2", "cam") os.symlink(cam2file, camfile) elif not get_refcase and run_type != "startup": logger.info( "GET_REFCASE is false, the user is expected to stage the refcase to the run directory." ) if os.path.exists(os.path.join("Buildconf", "refcase.input_data_list")): os.remove(os.path.join("Buildconf", "refcase.input_data_list")) return True
def _check_input_data_impl( case, protocol, address, input_data_root, data_list_dir, download, user, passwd, chksum, ic_filepath, ): case.load_env(reset=True) rundir = case.get_value("RUNDIR") # Fill in defaults as needed input_data_root = ( case.get_value("DIN_LOC_ROOT") if input_data_root is None else input_data_root ) input_ic_root = case.get_value("DIN_LOC_IC", resolved=True) expect( os.path.isdir(data_list_dir), "Invalid data_list_dir directory: '{}'".format(data_list_dir), ) data_list_files = find_files(data_list_dir, "*.input_data_list") if not data_list_files: logger.warning( "WARNING: No .input_data_list files found in dir '{}'".format(data_list_dir) ) no_files_missing = True if download: if protocol not in vars(CIME.Servers): logger.info("Client protocol {} not enabled".format(protocol)) return False logger.info( "Using protocol {} with user {} and passwd {}".format( protocol, user, passwd ) ) if protocol == "svn": server = CIME.Servers.SVN(address, user, passwd) elif protocol == "gftp": server = CIME.Servers.GridFTP(address, user, passwd) elif protocol == "ftp": server = CIME.Servers.FTP.ftp_login(address, user, passwd) elif protocol == "wget": server = CIME.Servers.WGET.wget_login(address, user, passwd) else: expect(False, "Unsupported inputdata protocol: {}".format(protocol)) if not server: return None for data_list_file in data_list_files: logger.info("Loading input file list: '{}'".format(data_list_file)) with open(data_list_file, "r") as fd: lines = fd.readlines() for line in lines: line = line.strip() use_ic_path = False if line and not line.startswith("#"): tokens = line.split("=") description, full_path = tokens[0].strip(), tokens[1].strip() if ( description.endswith("datapath") or description.endswith("data_path") or full_path.endswith("/dev/null") ): continue if description.endswith("file") or description.endswith("filename"): # There are required input data with key, or 'description' entries # that specify in their names whether they are files or filenames # rather than 'datapath's or 'data_path's so we check to make sure # the input data list has correct non-path values for input files. # This check happens whether or not a file already exists locally. expect( (not full_path.endswith(os.sep)), "Unsupported directory path in input_data_list named {}. Line entry is '{} = {}'.".format( data_list_file, description, full_path ), ) if full_path: # expand xml variables full_path = case.get_resolved_value(full_path) rel_path = full_path if input_ic_root and input_ic_root in full_path and ic_filepath: rel_path = full_path.replace(input_ic_root, ic_filepath) use_ic_path = True elif input_data_root in full_path: rel_path = full_path.replace(input_data_root, "") elif input_ic_root and ( input_ic_root not in input_data_root and input_ic_root in full_path ): if ic_filepath: rel_path = full_path.replace(input_ic_root, ic_filepath) use_ic_path = True model = os.path.basename(data_list_file).split(".")[0] isdirectory = rel_path.endswith(os.sep) if ( "/" in rel_path and rel_path == full_path and not full_path.startswith("unknown") ): # User pointing to a file outside of input_data_root, we cannot determine # rel_path, and so cannot download the file. If it already exists, we can # proceed if not os.path.exists(full_path): print( "Model {} missing file {} = '{}'".format( model, description, full_path ) ) # Data download path must be DIN_LOC_ROOT, DIN_LOC_IC or RUNDIR rundir = case.get_value("RUNDIR") if download: if full_path.startswith(rundir): filepath = os.path.dirname(full_path) if not os.path.exists(filepath): logger.info( "Creating directory {}".format(filepath) ) os.makedirs(filepath) tmppath = full_path[len(rundir) + 1 :] success = _download_if_in_repo( server, os.path.join(rundir, "inputdata"), tmppath[10:], isdirectory=isdirectory, ic_filepath="/", ) no_files_missing = success else: logger.warning( " Cannot download file since it lives outside of the input_data_root '{}'".format( input_data_root ) ) else: no_files_missing = False else: logger.debug(" Found input file: '{}'".format(full_path)) else: # There are some special values of rel_path that # we need to ignore - some of the component models # set things like 'NULL' or 'same_as_TS' - # basically if rel_path does not contain '/' (a # directory tree) you can assume it's a special # value and ignore it (perhaps with a warning) if ( "/" in rel_path and not os.path.exists(full_path) and not full_path.startswith("unknown") ): print( "Model {} missing file {} = '{}'".format( model, description, full_path ) ) if download: if use_ic_path: success = _download_if_in_repo( server, input_ic_root, rel_path.strip(os.sep), isdirectory=isdirectory, ic_filepath=ic_filepath, ) else: success = _download_if_in_repo( server, input_data_root, rel_path.strip(os.sep), isdirectory=isdirectory, ic_filepath=ic_filepath, ) if not success: no_files_missing = False if success and chksum: verify_chksum( input_data_root, rundir, rel_path.strip(os.sep), isdirectory, ) else: no_files_missing = False else: if chksum: verify_chksum( input_data_root, rundir, rel_path.strip(os.sep), isdirectory, ) logger.info( "Chksum passed for file {}".format( os.path.join(input_data_root, rel_path) ) ) logger.debug( " Already had input file: '{}'".format(full_path) ) else: model = os.path.basename(data_list_file).split(".")[0] logger.warning( "Model {} no file specified for {}".format(model, description) ) return no_files_missing
[docs] def check_input_data( case, protocol="svn", address=None, input_data_root=None, data_list_dir="Buildconf", download=False, user=None, passwd=None, chksum=False, ic_filepath=None, ): """ For a given case check for the relevant input data as specified in data_list_dir/*.input_data_list in the directory input_data_root, if not found optionally download it using the servers specified in config_inputdata.xml. If a chksum file is available compute the chksum and compare it to that in the file. Return True if no files missing """ # Run the entire impl in a SharedArea to help avoid permission problems with SharedArea(): return _check_input_data_impl( case, protocol, address, input_data_root, data_list_dir, download, user, passwd, chksum, ic_filepath, )
[docs] def verify_chksum(input_data_root, rundir, filename, isdirectory): """ For file in filename perform a chksum and compare the result to that stored in the local checksumfile, if isdirectory chksum all files in the directory of form *.* """ hashfile = os.path.join(rundir, local_chksum_file) if not chksum_hash: if not os.path.isfile(hashfile): logger.warning("Failed to find or download file {}".format(hashfile)) return with open(hashfile) as fd: lines = fd.readlines() for line in lines: fchksum, fname = line.split() if fname in chksum_hash: expect( chksum_hash[fname] == fchksum, " Inconsistent hashes in chksum for file {}".format(fname), ) else: chksum_hash[fname] = fchksum if isdirectory: filenames = glob.glob(os.path.join(filename, "*.*")) else: filenames = [filename] for fname in filenames: if not os.sep in fname: continue chksum = md5(os.path.join(input_data_root, fname)) if chksum_hash: if not fname in chksum_hash: logger.warning( "Did not find hash for file {} in chksum file {}".format( filename, hashfile ) ) else: expect( chksum == chksum_hash[fname], "chksum mismatch for file {} expected {} found {}".format( os.path.join(input_data_root, fname), chksum, chksum_hash[fname] ), )
[docs] def md5(fname): """ performs an md5 sum one chunk at a time to avoid memory issues with large files. """ hash_md5 = hashlib.md5() with open(fname, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest()