# Source code for aiida_fireworks_scheduler.awareness

"""
Runtime scheduler awareness
:noindex:
"""
import subprocess
import os
import re
import tempfile
import logging
from datetime import datetime, timedelta, timezone

# Module-level logger, named after this module through ``__name__``.
LOGGER = logging.getLogger(__name__)


class SchedulerAwareness:
    """Base interface for querying the scheduler about the current job.

    Subclasses supply ``job_id``, ``get_n_cpus`` and ``get_remaining_seconds``;
    the versions defined here raise ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        """Create the shared state; all arguments are accepted and discarded."""
        del args, kwargs
        self._job_id = None
        self._ncpus = None

    def get_n_cpus(self):
        """Return the number of CPUs of this job (subclass responsibility)."""
        raise NotImplementedError

    @property
    def user_name(self):
        """Name of the current user, taken from the ``USER`` environment variable."""
        return os.environ['USER']

    def get_remaining_seconds(self):
        """Return the time left before the job is killed (subclass responsibility)."""
        raise NotImplementedError

    @property
    def is_in_job(self):
        """``True`` when ``job_id`` is not ``None``, ``False`` otherwise."""
        return self.job_id is not None

    @property
    def job_id(self):
        """Identifier of the current job (subclass responsibility)."""
        raise NotImplementedError

    @classmethod
    def get_awareness(cls):
        """Return the first specialised awareness that reports being in a job.

        The candidates are instantiated in order - Slurm, SGE, then the dummy
        fallback - and ``None`` is returned if none of them is inside a job.
        """
        candidates = (SlurmAwareness, SGEAwareness, DummyAwareness)
        for candidate in candidates:
            instance = candidate()
            if not instance.is_in_job:
                continue
            return instance
        return None
class DummyAwareness(SchedulerAwareness):
    """Stand-in awareness used when jobs are run locally."""

    # Thirty days, expressed in seconds
    DEFAULT_REMAINING_TIME = 3600 * 24 * 30

    def __init__(self, *args, **kwargs):
        """Build a DummyAwareness - behaves as if there is plenty of time left."""
        super().__init__(*args, **kwargs)
        self._job_id = '0'

    def get_n_cpus(self):
        """Return the CPU count, fixed at 4 for the dummy."""
        return 4

    @property
    def job_id(self):
        """Identifier of the job as stored on the instance."""
        return self._job_id

    def get_remaining_seconds(self):
        """Return the remaining time, i.e. ``DEFAULT_REMAINING_TIME`` (30 days)."""
        return self.DEFAULT_REMAINING_TIME

    @property
    def is_in_job(self):
        """Whether we are inside a scheduler job - the dummy always says yes."""
        return True
class SGEAwareness(SchedulerAwareness):
    """Runtime awareness for jobs running under SGE.

    The job is identified through the ``JOB_ID`` / ``SGE_TASK_ID`` environment
    variables, the slot count through ``NSLOTS``, and the run-time limit and
    start time are obtained by calling ``qstat -j``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the SGEAwareness object.

        ``qstat`` is only queried when the environment indicates that we are
        inside a job.
        """
        super().__init__(*args, **kwargs)
        # Define every attribute up front so that the accessors below work
        # (and report "unknown") even when we are not inside an SGE job.
        self._task_info = {}
        self._start_time = None
        self._end_time = None
        if self.is_in_job:
            self._readtask_info()

    @property
    def job_id(self):
        """ID of the job, or ``None`` when ``JOB_ID`` is not set.

        For array tasks the ID is ``<JOB_ID>.<SGE_TASK_ID>``. The value is
        cached on the instance once it is not ``None``.
        """
        if self._job_id is None:
            job_id = os.environ.get('JOB_ID')
            task_id = os.environ.get('SGE_TASK_ID')
            # Only build the composite ID when JOB_ID exists; a stray
            # SGE_TASK_ID must not end up in ``None + '.'``.
            if job_id is not None and task_id and task_id != 'undefined':
                job_id = job_id + '.' + task_id
                LOGGER.warning(
                    'WARNING: REMAINING TIME IS NOT CORRECT FOR TASK ARRAY')
            self._job_id = job_id
        return self._job_id

    def _readtask_info(self):
        """Read the detailed task information reported by ``qstat -j``.

        The output is stored in ``self._task_info`` as a ``key -> value``
        dictionary of stripped strings.

        :raises subprocess.CalledProcessError: if ``qstat`` exits with an error
        """
        raw_data = subprocess.check_output(
            ['qstat', '-j', f'{self.job_id}'],  # pylint: disable=unexpected-keyword-arg
            universal_newlines=True)
        task_info = {}
        # The first line of the output is skipped
        for line in raw_data.split('\n')[1:]:
            key, separator, value = line.partition(':')
            # Ignore lines that are not in the ``key: value`` format
            if not separator:
                continue
            task_info[key.strip()] = value.strip()
        self._task_info = task_info

    def get_n_cpus(self):
        """Return the number of slots from ``NSLOTS`` as an int, or ``None``."""
        nslots = os.environ.get('NSLOTS')
        if nslots:
            return int(nslots)
        return None

    def get_max_run_seconds(self):
        """Return the maximum run time in seconds.

        The value is the ``h_rt`` entry of the ``hard resource_list`` field;
        ``None`` is returned when the field or the entry is not available.
        """
        rlist = self._task_info.get('hard resource_list', '')
        match = re.search(r'h_rt=(\d+)', rlist)
        if match:
            return int(match.group(1))
        return None

    def get_end_time(self, refresh=False):
        """Return the time when the job is expected to finish.

        :param refresh: forwarded to :meth:`get_start_time`
        :returns: a timezone-aware ``datetime``, or ``None`` when either the
            start time or the run-time limit is unknown
        """
        start_time = self.get_start_time(refresh=refresh)
        max_seconds = self.get_max_run_seconds()
        if start_time is None or max_seconds is None:
            return None
        return start_time + timedelta(seconds=max_seconds)

    def get_start_time(self, refresh=False):
        """Return the start time of this job.

        :param refresh: query ``qstat`` again even if a value is cached
        :returns: a timezone-aware (UTC) ``datetime``, or ``None`` when we are
            not inside a job or ``qstat`` reports no start time
        :raises subprocess.CalledProcessError: if ``qstat`` exits with an error
        """
        if not self.is_in_job:
            # Nothing to ask qstat about
            return None
        if self._start_time is None or refresh:
            output = subprocess.check_output(  # pylint: disable=unexpected-keyword-arg
                ['qstat', '-j', str(self.job_id), '-xml'],
                universal_newlines=True)
            match = re.search(r'<JAT_start_time>(.+)</JAT_start_time>', output)
            if match:
                # The value is interpreted as a UTC epoch timestamp.
                # NOTE(review): this may not be true for every installation.
                self._start_time = datetime.fromtimestamp(
                    int(match.group(1)), tz=timezone.utc)
        return self._start_time

    def get_remaining_seconds(self):
        """Return the remaining time in seconds.

        :raises TypeError: if the end time cannot be determined
        """
        # Everything must be timezone-aware to work with BST
        tdelta = self.get_end_time() - datetime.now().astimezone()
        return int(tdelta.total_seconds())
class SlurmAwareness(SchedulerAwareness):
    """SlurmAwareness object for storing and extracting information in slurm.

    The job is identified through the ``SLURM_JOB_ID`` environment variable
    and its details are read from ``scontrol show jobid=<id>``.
    """

    # Parsed ``scontrol`` output, cached on the class once it has been read
    _task_info = None
    # Number of times the "not started from slurm" message has been logged
    _warning = 0

    def __init__(self, *args, **kwargs):
        """Initialise a SlurmAwareness instance.

        Arguments are forwarded to the base class. The job information is
        taken from the class-level cache when available, and read from the
        scheduler otherwise.
        """
        super().__init__(*args, **kwargs)
        self.task_info = {}
        if self._task_info is None:
            self._readtask_info()
            self._task_info = self.task_info
        else:
            self.task_info = self._task_info

    @property
    def is_in_job(self):
        """Whether ``SLURM_JOB_ID`` is defined in the environment."""
        return os.environ.get('SLURM_JOB_ID') is not None

    @property
    def job_id(self):
        """ID of the job, read from ``SLURM_JOB_ID`` (``None`` if unset)."""
        if self._job_id is None:
            self._job_id = os.environ.get('SLURM_JOB_ID')
        return self._job_id

    @staticmethod
    def _parse_scontrol_output(output):
        """Turn the text printed by ``scontrol show`` into a dictionary.

        The text is treated as whitespace separated ``Key=Value`` tokens.
        Only the first ``=`` separates key and value, so values that contain
        ``=`` themselves (e.g. ``TRES=cpu=4,mem=1G``) are kept intact. Tokens
        without any ``=`` are stored with the value ``None``.
        """
        info = {}
        for token in output.split():
            key, separator, value = token.partition('=')
            info[key] = value if separator else None
        return info

    def _readtask_info(self):
        """Extract the job information for the job given by ``SLURM_JOB_ID``.

        The result is stored in ``self.task_info`` and cached on the class.
        When ``SLURM_JOB_ID`` is not set, ``self.task_info`` is set to an
        empty dictionary and a debug message is logged once per class.

        :raises subprocess.CalledProcessError: if ``scontrol`` exits with an
            error
        :raises OSError: if ``scontrol`` cannot be executed at all
        """
        cls = type(self)
        job_id = os.environ.get('SLURM_JOB_ID')
        if job_id is None:
            if cls._warning == 0:
                LOGGER.debug('NOT STARTED FROM SLURM')
                # Increment on the class: ``self._warning += 1`` would only
                # create an instance attribute and the guard would never trip.
                cls._warning += 1
            self.task_info = {}
            return

        # Argument list instead of a shell string: the job id is never
        # interpreted by a shell, and the output is captured directly.
        completed = subprocess.run(  # pylint: disable=unexpected-keyword-arg
            ['scontrol', 'show', f'jobid={job_id}'],
            check=True,
            stdout=subprocess.PIPE,
            universal_newlines=True)
        sinfo_dict = self._parse_scontrol_output(completed.stdout)
        cls._task_info = sinfo_dict
        self.task_info = sinfo_dict

    def get_end_time(self):
        """Query the end time of the job.

        :returns: a naive ``datetime`` parsed from the ``EndTime`` field, or
            ``None`` when no job information is available
        """
        if not self.task_info:
            return None
        return datetime.strptime(self.task_info['EndTime'],
                                 '%Y-%m-%dT%H:%M:%S')

    def get_remaining_seconds(self):
        """Return the remaining time in seconds.

        :raises TypeError: if the end time is not available
        """
        return int((self.get_end_time() - datetime.now()).total_seconds())

    def get_n_cpus(self):
        """Return the number of CPUs allocated, as an int.

        ``None`` is returned when the ``NumCPUs`` field is missing or is not
        a plain integer.
        """
        n_cpus = self.task_info.get('NumCPUs')
        try:
            return int(n_cpus)
        except (TypeError, ValueError):
            return None