reformat celery logging to match danswer style logging across services (#2409)

* reformat celery logging to match danswer style logging across services

* mypy fixes

* handle logfile argument
This commit is contained in:
rkuo-danswer 2024-09-12 18:51:51 -07:00 committed by GitHub
parent e9a616e579
commit da8e68b320
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 86 additions and 2 deletions

View File

@ -1,4 +1,5 @@
import json
import logging
import traceback
from datetime import timedelta
from typing import Any
@ -6,6 +7,7 @@ from typing import cast
import redis
from celery import Celery
from celery import current_task
from celery import signals
from celery import Task
from celery.contrib.abortable import AbortableTask # type: ignore
@ -64,6 +66,8 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import RedisPool
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
@ -136,8 +140,7 @@ def cleanup_connector_credential_pair_task(
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={connector_id} credential_id={credential_id}\n"
f"Stack Trace:\n{stack_trace}"
f"connector_id={connector_id} credential_id={credential_id}"
)
raise e
@ -883,6 +886,77 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
r.delete(key)
class CeleryTaskPlainFormatter(PlainFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
class CeleryTaskColoredFormatter(ColoredFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats celery's worker logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(loglevel)
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
task_logger.addHandler(task_file_handler)
task_logger.setLevel(loglevel)
task_logger.propagate = False
#####
# Celery Beat (Periodic Tasks) Settings
#####

View File

@ -80,6 +80,16 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
)
class PlainFormatter(logging.Formatter):
"""Adds log levels."""
def format(self, record: logging.LogRecord) -> str:
levelname = record.levelname
level_display = f"{levelname}:"
formatted_message = super().format(record)
return f"{level_display.ljust(9)} {formatted_message}"
class ColoredFormatter(logging.Formatter):
"""Custom formatter to add colors to log levels."""