diff --git a/skm_pyutils/py_log.py b/skm_pyutils/py_log.py index cff72cd..9d35b43 100644 --- a/skm_pyutils/py_log.py +++ b/skm_pyutils/py_log.py @@ -1,7 +1,5 @@ """ Logging related functions. - -TODO test the except hooks more. """ import datetime import logging @@ -32,9 +30,7 @@ def log_exception(ex, more_info="", location=None): """ if location is None: - default_loc = os.path.join( - os.path.expanduser("~"), ".skm_python", "caught_errors.txt" - ) + default_loc = get_default_log_loc("caught_errors.txt") else: default_loc = location @@ -53,13 +49,12 @@ def log_exception(ex, more_info="", location=None): def default_excepthook(exc_type, exc_value, exc_traceback): """Any uncaught exceptions will be logged from here.""" - default_loc = os.path.join( - os.path.expanduser("~"), ".skm_python", "uncaught_errors.txt" - ) + default_loc = get_default_log_loc("uncaught_errors.txt") - this_logger = logging.getLogger(__name__) + file_logger = logging.getLogger(__name__) + file_logger.propagate = False handler = logging.FileHandler(default_loc) - this_logger.addHandler(handler) + file_logger.addHandler(handler) # Don't catch CTRL+C exceptions if issubclass(exc_type, KeyboardInterrupt): @@ -67,13 +62,12 @@ def default_excepthook(exc_type, exc_value, exc_traceback): return now = datetime.datetime.now() - this_logger.critical( + file_logger.critical( "\n----------Uncaught Exception at {}----------".format(now), exc_info=(exc_type, exc_value, exc_traceback), ) - sys.stdout.write = default_write - print("A fatal error occurred in this Python program") + print("\nA fatal error occurred in this Python program") print( "The error info was: {}".format( "".join( @@ -83,6 +77,8 @@ def default_excepthook(exc_type, exc_value, exc_traceback): ) print("Please report this to {} and provide the file {}".format("us", default_loc)) + sys.exit(-1) + def override_excepthook(excepthook=None): """ @@ -104,3 +100,9 @@ def override_excepthook(excepthook=None): if excepthook is None: excepthook = default_excepthook sys.excepthook = excepthook + + +def get_default_log_loc(name): + default_loc = os.path.join(os.path.expanduser("~"), ".skm_python", name) + + return default_loc