-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #280 from HSF/dev
Dev
- Loading branch information
Showing
19 changed files
with
632 additions
and
120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,16 +6,19 @@ | |
# http://www.apache.org/licenses/LICENSE-2.0OA | ||
# | ||
# Authors: | ||
# - Wen Guan, <[email protected]>, 2019 - 2023 | ||
# - Wen Guan, <[email protected]>, 2019 - 2024 | ||
|
||
|
||
import base64 | ||
import errno | ||
import datetime | ||
import importlib | ||
import logging | ||
import json | ||
import os | ||
import re | ||
import requests | ||
import signal | ||
import subprocess | ||
import sys | ||
import tarfile | ||
|
@@ -27,6 +30,7 @@ | |
from itertools import groupby | ||
from operator import itemgetter | ||
from packaging import version as packaging_version | ||
from typing import Any, Callable | ||
|
||
from idds.common.config import (config_has_section, config_has_option, | ||
config_get, config_get_bool) | ||
|
@@ -234,15 +238,112 @@ def check_database(): | |
return False | ||
|
||
|
||
def run_process(cmd, stdout=None, stderr=None): | ||
def kill_process_group(pgrp, nap=10):
    """
    Kill a process group, first gracefully (SIGTERM), then forcefully (SIGKILL).

    DO NOT MOVE TO PROCESSES.PY - will lead to circular import since execute() needs it as well.

    :param pgrp: process group id (int).
    :param nap: napping time between the two kill signals in seconds (int).
                The nap is skipped when SIGTERM could not be sent or when nap is not positive.
    :return: boolean (True if the final SIGKILL could be sent to the process group).
    """
    status = False
    term_sent = True

    # kill the process group gracefully
    print(f"killing group process {pgrp}")
    try:
        os.killpg(pgrp, signal.SIGTERM)
    except Exception as error:
        print(f"exception thrown when killing child group process under SIGTERM: {error}")
        term_sent = False
    else:
        print(f"SIGTERM sent to process group {pgrp}")

    # only wait when there is something to wait for; a negative nap would make time.sleep() raise
    if term_sent and nap > 0:
        print(f"sleeping {nap} s to allow processes to exit")
        time.sleep(nap)

    # SIGKILL is always attempted, even when SIGTERM failed
    try:
        os.killpg(pgrp, signal.SIGKILL)
    except Exception as error:
        print(f"exception thrown when killing child group process with SIGKILL: {error}")
    else:
        print(f"SIGKILL sent to process group {pgrp}")
        status = True

    return status
|
||
|
||
def kill_all(process: Any) -> str:
    """
    Kill all processes after a time-out exception in process.communicate().

    First the whole process group of the subprocess is signalled through
    kill_process_group(), then the subprocess itself receives SIGTERM followed,
    ten seconds later, by SIGKILL. Failures are not raised; they are collected
    in the returned string.

    :param process: process object (only its `pid` attribute is used).
    :return: stderr (str), the accumulated error messages; empty if no step failed.
    """
    stderr = ''

    # step 1: the process group
    try:
        print('killing lingering subprocess and process group')
        time.sleep(1)
        kill_process_group(os.getpgid(process.pid))
    except ProcessLookupError as exc:
        stderr += f'\n(kill process group) ProcessLookupError={exc}'
    except Exception as exc:
        stderr += f'\n(kill_all 1) exception caught: {exc}'

    # step 2: the subprocess itself
    try:
        print('killing lingering process')
        time.sleep(1)
        os.kill(process.pid, signal.SIGTERM)
        print('sleeping a bit before sending SIGKILL')
        time.sleep(10)
        os.kill(process.pid, signal.SIGKILL)
    except ProcessLookupError as exc:
        stderr += f'\n(kill process) ProcessLookupError={exc}'
    except Exception as exc:
        stderr += f'\n(kill_all 2) exception caught: {exc}'

    print(f'sent soft kill signals - final stderr: {stderr}')
    return stderr
|
||
|
||
def run_process(cmd, stdout=None, stderr=None, wait=False, timeout=7 * 24 * 3600):
    """
    Runs a command in an out-of-process shell.

    The command is started in its own session (preexec_fn=os.setsid).

    :param cmd: the shell command line to execute (str).
    :param stdout: target for the standard output of the command; only used when stderr is given too.
    :param stderr: target for the standard error of the command; only used when stdout is given too.
    :param wait: if False (default), return immediately after starting the command.
    :param timeout: maximum time in seconds to wait for the command when wait is True.
    :return: the subprocess.Popen object when wait is False; otherwise the exit code of the
             command, or -1 if waiting timed out or failed and the command had to be killed.
    """
    print(f"To run command: {cmd}")
    if stdout and stderr:
        process = subprocess.Popen(cmd, shell=True, stdout=stdout, stderr=stderr,
                                   preexec_fn=os.setsid, encoding='utf-8')
    else:
        process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setsid, encoding='utf-8')
    if not wait:
        return process

    try:
        print(f'subprocess.communicate() will use timeout={timeout} s')
        process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired as ex:
        message = f'subprocess communicate sent TimeoutExpired: {ex}'
        print(message)
        message = kill_all(process)
        print(f'Killing process: {message}')
        exit_code = -1
    except Exception as ex:
        message = f'subprocess has an exception: {ex}'
        print(message)
        message = kill_all(process)
        print(f'Killing process: {message}')
        exit_code = -1
    else:
        exit_code = process.poll()

    # wait for the (possibly killed) process so that it is cleaned up
    try:
        process.wait(timeout=60)
    except subprocess.TimeoutExpired:
        print("process did not complete within the timeout of 60s - terminating")
        process.terminate()
    return exit_code
|
||
|
||
def run_command(cmd): | ||
|
@@ -630,3 +731,148 @@ def group_list(input_list, key): | |
update_groups[item_tuple] = {'keys': [], 'items': item} | ||
update_groups[item_tuple]['keys'].append(item_key) | ||
return update_groups | ||
|
||
|
||
def import_fun(name: str) -> Callable[..., Any]:
    """Returns a function from a dotted path name. Example: `path.to.module:func`.

    The part before the last ':' is the dotted module path, the part after it is
    the (possibly dotted) attribute path inside that module.
    When the attribute we look for is a staticmethod, module name in its
    dotted path is not the last-before-end word
    E.g.: package_a.package_b.module_a:ClassA.my_static_method
    Thus we remove the bits from the end of the module path until we can import it.
    A name without any importable module part is looked up among the builtins.

    Args:
        name (str): The name (reference) to the path.

    Raises:
        ValueError: If no module is found or invalid attribute name.

    Returns:
        Any: An attribute (normally a Callable)
    """
    # local import: `__builtins__` is a dict or a module depending on how this file is loaded
    import builtins

    module_path, _, attribute_path = name.rpartition(':')
    module_name_bits = module_path.split('.') if module_path else []
    attribute_bits = attribute_path.split('.')

    module = None
    while module_name_bits:
        try:
            module = importlib.import_module('.'.join(module_name_bits))
            break
        except ImportError:
            # the last bit is not a module: treat it as the start of the attribute path
            attribute_bits.insert(0, module_name_bits.pop())

    if module is None:
        # maybe it's a builtin
        try:
            return getattr(builtins, name)
        except AttributeError:
            raise ValueError('Invalid attribute name: %s' % name) from None

    # walk the attribute path (e.g. ClassA.my_static_method) one bit at a time
    target = module
    try:
        for bit in attribute_bits:
            target = getattr(target, bit)
    except AttributeError:
        raise ValueError('Invalid attribute name: %s' % name) from None
    return target
|
||
|
||
def import_attribute(name: str) -> Callable[..., Any]:
    """Returns an attribute from a dotted path name. Example: `path.to.func`.

    When the attribute we look for is a staticmethod, module name in its
    dotted path is not the last-before-end word
    E.g.: package_a.package_b.module_a.ClassA.my_static_method
    Thus we remove the bits from the end of the name until we can import it.
    A name without any importable module part is looked up among the builtins.

    Args:
        name (str): The name (reference) to the path.

    Raises:
        ValueError: If no module is found or invalid attribute name.

    Returns:
        Any: An attribute (normally a Callable)
    """
    # local import: `__builtins__` is a dict or a module depending on how this file is loaded
    import builtins

    name_bits = name.split('.')
    module_name_bits, attribute_bits = name_bits[:-1], [name_bits[-1]]

    module = None
    while module_name_bits:
        try:
            module = importlib.import_module('.'.join(module_name_bits))
            break
        except ImportError:
            # the last bit is not a module: treat it as the start of the attribute path
            attribute_bits.insert(0, module_name_bits.pop())

    if module is None:
        # maybe it's a builtin
        try:
            return getattr(builtins, name)
        except AttributeError:
            raise ValueError('Invalid attribute name: %s' % name) from None

    # walk the attribute path (e.g. ClassA.my_static_method) one bit at a time
    target = module
    try:
        for bit in attribute_bits:
            target = getattr(target, bit)
    except AttributeError:
        raise ValueError('Invalid attribute name: %s' % name) from None
    return target
|
||
|
||
def decode_base64(sb):
    """
    Decode a base64 encoded value into a UTF-8 string.

    :param sb: base64 data as str (ASCII) or bytes; any other type is returned unchanged.
    :return: the decoded string, or the unchanged input if it cannot be decoded
             (the failure is logged as an error).
    """
    if isinstance(sb, str):
        sb_bytes = None  # converted inside the try block, non-ASCII input must not raise
    elif isinstance(sb, bytes):
        sb_bytes = sb
    else:
        return sb

    try:
        if sb_bytes is None:
            sb_bytes = sb.encode('ascii')
        return base64.b64decode(sb_bytes).decode("utf-8")
    except Exception as ex:
        logging.error("decode_base64 %s: %s", sb, ex)
        return sb
|
||
|
||
def encode_base64(sb):
    """
    Encode a value as a base64 string.

    Strings are converted to bytes as UTF-8 (identical to ASCII for pure ASCII
    input), which matches the UTF-8 decoding done by decode_base64().

    :param sb: the data to encode as str or bytes; any other type is returned unchanged.
    :return: the base64 representation as str, or the unchanged input if it cannot be
             encoded (the failure is logged as an error).
    """
    if isinstance(sb, str):
        sb_bytes = sb.encode('utf-8')
    elif isinstance(sb, bytes):
        sb_bytes = sb
    else:
        # nothing to encode; previously this only worked through a caught UnboundLocalError
        return sb

    try:
        return base64.b64encode(sb_bytes).decode("utf-8")
    except Exception as ex:
        logging.error("encode_base64 %s: %s", sb, ex)
        return sb
|
||
|
||
def create_archive_file(work_dir, archive_filename, files):
    """
    Create a gzip compressed tar archive from a list of files.

    Every file is stored under its base name only, i.e. the archive is flat;
    files sharing a base name end up as entries with the same name.
    Symbolic links are followed (dereference=True), so the content of the
    link target is archived.

    :param work_dir: directory used as the location of the archive when
                     archive_filename is not an absolute path.
    :param archive_filename: name or absolute path of the archive to create.
    :param files: iterable of paths of the files to add.
    :return: the path of the created archive.
    """
    if not os.path.isabs(archive_filename):
        archive_filename = os.path.join(work_dir, archive_filename)

    with tarfile.open(archive_filename, "w:gz", dereference=True) as tar:
        for local_file in files:
            tar.add(local_file, arcname=os.path.basename(local_file))
    return archive_filename
Oops, something went wrong.