diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index a48d8aa4a..75f4143aa 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import traceback
 from datetime import timedelta
 from typing import Any
@@ -6,6 +7,7 @@ from typing import cast
 
 import redis
 from celery import Celery
+from celery import current_task
 from celery import signals
 from celery import Task
 from celery.contrib.abortable import AbortableTask  # type: ignore
@@ -64,6 +66,8 @@ from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
 from danswer.redis.redis_pool import RedisPool
+from danswer.utils.logger import ColoredFormatter
+from danswer.utils.logger import PlainFormatter
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from danswer.utils.variable_functionality import (
@@ -136,8 +140,7 @@ def cleanup_connector_credential_pair_task(
             add_deletion_failure_message(db_session, cc_pair.id, error_message)
             task_logger.exception(
                 f"Failed to run connector_deletion. "
-                f"connector_id={connector_id} credential_id={credential_id}\n"
-                f"Stack Trace:\n{stack_trace}"
+                f"connector_id={connector_id} credential_id={credential_id}"
             )
             raise e
 
@@ -883,6 +886,96 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
         r.delete(key)
 
 
+def _task_prefixed_msg(record: logging.LogRecord) -> str:
+    """Return record.msg prefixed with the current celery task's name and id.
+
+    task_id and task_name are also stored on the record. When there is no current
+    task (or it has no request), record.msg is returned unchanged.
+    """
+    task = current_task
+    if task and task.request:
+        record.__dict__.update(task_id=task.request.id, task_name=task.name)
+        return f"[{task.name}({task.request.id})] {record.msg}"
+
+    return record.msg
+
+
+class CeleryTaskPlainFormatter(PlainFormatter):
+    """PlainFormatter that tags each message with the current celery task."""
+
+    def format(self, record: logging.LogRecord) -> str:
+        # every handler on a logger receives the same record object, so restore
+        # record.msg afterwards; otherwise the next handler adds the prefix again
+        original_msg = record.msg
+        record.msg = _task_prefixed_msg(record)
+        try:
+            return super().format(record)
+        finally:
+            record.msg = original_msg
+
+
+class CeleryTaskColoredFormatter(ColoredFormatter):
+    """ColoredFormatter that tags each message with the current celery task."""
+
+    def format(self, record: logging.LogRecord) -> str:
+        # every handler on a logger receives the same record object, so restore
+        # record.msg afterwards; otherwise the next handler adds the prefix again
+        original_msg = record.msg
+        record.msg = _task_prefixed_msg(record)
+        try:
+            return super().format(record)
+        finally:
+            record.msg = original_msg
+
+
+@signals.setup_logging.connect
+def on_setup_logging(
+    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
+) -> None:
+    """Attach danswer's formatters to the root logger and the celery task logger.
+
+    Runs when celery sends its setup_logging signal. Each logger gets a colored
+    stream handler, plus a plain file handler when logfile is set, and both are
+    set to loglevel. The format and colorize arguments are currently ignored.
+    """
+    # TODO: could unhardcode format and colorize and accept these as options from
+    # celery's config
+    fmt = "%(asctime)s %(filename)30s %(lineno)4s: %(message)s"
+    datefmt = "%m/%d/%Y %I:%M:%S %p"
+
+    # reformats celery's worker logger
+    root_logger = logging.getLogger()
+
+    root_handler = logging.StreamHandler()
+    root_formatter = ColoredFormatter(fmt, datefmt=datefmt)
+    root_handler.setFormatter(root_formatter)
+    root_logger.addHandler(root_handler)
+
+    if logfile:
+        root_file_handler = logging.FileHandler(logfile)
+        root_file_formatter = PlainFormatter(fmt, datefmt=datefmt)
+        root_file_handler.setFormatter(root_file_formatter)
+        root_logger.addHandler(root_file_handler)
+
+    root_logger.setLevel(loglevel)
+
+    # reformats celery's task logger
+    task_handler = logging.StreamHandler()
+    task_formatter = CeleryTaskColoredFormatter(fmt, datefmt=datefmt)
+    task_handler.setFormatter(task_formatter)
+    task_logger.addHandler(task_handler)
+
+    if logfile:
+        task_file_handler = logging.FileHandler(logfile)
+        task_file_formatter = CeleryTaskPlainFormatter(fmt, datefmt=datefmt)
+        task_file_handler.setFormatter(task_file_formatter)
+        task_logger.addHandler(task_file_handler)
+
+    task_logger.setLevel(loglevel)
+    # the task logger has its own handlers, so keep its records away from root's
+    task_logger.propagate = False
+
+
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py
index 9489a6244..5436b2328 100644
--- a/backend/danswer/utils/logger.py
+++ b/backend/danswer/utils/logger.py
@@ -80,6 +80,16 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
         )
 
 
+class PlainFormatter(logging.Formatter):
+    """Prefixes every formatted record with its level name (no color codes)."""
+
+    def format(self, record: logging.LogRecord) -> str:
+        level_display = f"{record.levelname}:"
+        formatted_message = super().format(record)
+        # pad to 9 columns ("CRITICAL:" is the longest standard level) for alignment
+        return f"{level_display.ljust(9)} {formatted_message}"
+
+
 class ColoredFormatter(logging.Formatter):
     """Custom formatter to add colors to log levels."""
 