diff --git a/graphgen/generate.py b/graphgen/generate.py index eec168d6..b689133e 100644 --- a/graphgen/generate.py +++ b/graphgen/generate.py @@ -65,7 +65,7 @@ def main(): "GraphGen with unique ID %s logging to %s", unique_id, os.path.join( - working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log" + working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log" ), ) diff --git a/graphgen/utils/log.py b/graphgen/utils/log.py index 32b9bac6..b4e0e475 100644 --- a/graphgen/utils/log.py +++ b/graphgen/utils/log.py @@ -1,32 +1,55 @@ import logging +from logging.handlers import RotatingFileHandler + +from rich.logging import RichHandler logger = logging.getLogger("graphgen") -def set_logger(log_file: str, log_level: int = logging.INFO, if_stream: bool = True): - logger.setLevel(log_level) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) +def set_logger( + log_file: str, + log_level: int = logging.INFO, + *, + if_stream: bool = True, + max_bytes: int = 50 * 1024 * 1024, # 50 MB + backup_count: int = 5, + force: bool = False, +): - file_handler = logging.FileHandler(log_file, mode='w') - file_handler.setLevel(log_level) - file_handler.setFormatter(formatter) + if logger.hasHandlers() and not force: + return - stream_handler = None + if force: + logger.handlers.clear() - if if_stream: - stream_handler = logging.StreamHandler() - stream_handler.setLevel(log_level) - stream_handler.setFormatter(formatter) + logger.setLevel(log_level) + logger.propagate = False - if not logger.handlers: - logger.addHandler(file_handler) - if if_stream and stream_handler: - logger.addHandler(stream_handler) + if logger.handlers: + logger.handlers.clear() + + if if_stream: + console = RichHandler(level=log_level, show_path=False, rich_tracebacks=True) + console.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(console) + + file_handler = RotatingFileHandler( + log_file, + maxBytes=max_bytes, + backupCount=backup_count, + encoding="utf-8", + ) + file_handler.setLevel(log_level) + file_handler.setFormatter( + logging.Formatter( + "[%(asctime)s] %(levelname)s [%(name)s:%(filename)s:%(lineno)d] %(message)s", + datefmt="%y-%m-%d %H:%M:%S", + ) + ) + logger.addHandler(file_handler) def parse_log(log_file: str): - with open(log_file, "r", encoding='utf-8') as f: + with open(log_file, "r", encoding="utf-8") as f: lines = f.readlines() return lines