import contextvars
import logging
import os
from collections.abc import MutableMapping
from logging.handlers import RotatingFileHandler
from typing import Any

from shared_configs.configs import DEV_LOGGING_ENABLED
from shared_configs.configs import LOG_FILE_NAME
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logging.addLevelName(logging.INFO + 5, "NOTICE")


pruning_ctx: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
    "pruning_ctx", default=dict()
)

doc_permission_sync_ctx: contextvars.ContextVar[
    dict[str, Any]
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
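
# Illustrative sketch (not part of the original module): a pruning task could
# populate `pruning_ctx` before logging so that OnyxLoggingAdapter (defined below)
# prefixes every message it emits. The keys are the ones the adapter looks for;
# the values here are made up.
#
#   pruning_ctx.set({"request_id": "req-123", "cc_pair_id": 42})
#   logger.info("Starting prune")
#   # -> "[CC Pair: 42] [Prune: req-123] Starting prune"
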
class LoggerContextVars:
    @staticmethod
    def reset() -> None:
        pruning_ctx.set(dict())
        doc_permission_sync_ctx.set(dict())


class TaskAttemptSingleton:
    """Used to tell if this process is an indexing job, and if so what is the
    unique identifier for this indexing attempt. For things like the API server,
    main background job (scheduler), etc. this will not be used."""

    _INDEX_ATTEMPT_ID: None | int = None
    _CONNECTOR_CREDENTIAL_PAIR_ID: None | int = None

    @classmethod
    def get_index_attempt_id(cls) -> None | int:
        return cls._INDEX_ATTEMPT_ID

    @classmethod
    def get_connector_credential_pair_id(cls) -> None | int:
        return cls._CONNECTOR_CREDENTIAL_PAIR_ID

    @classmethod
    def set_cc_and_index_id(
        cls, index_attempt_id: int, connector_credential_pair_id: int
    ) -> None:
        cls._INDEX_ATTEMPT_ID = index_attempt_id
        cls._CONNECTOR_CREDENTIAL_PAIR_ID = connector_credential_pair_id
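
# Illustrative sketch (hypothetical caller, not part of this module): an indexing
# worker would record its IDs once at startup so that every log line from that
# process carries the attempt context added by OnyxLoggingAdapter below.
#
#   TaskAttemptSingleton.set_cc_and_index_id(
#       index_attempt_id=7, connector_credential_pair_id=3
#   )
#   logger.info("Fetched batch")
#   # -> "[CC Pair: 3] [Index Attempt: 7] Fetched batch"
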
def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
    log_level_dict = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "NOTICE": logging.getLevelName("NOTICE"),
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NOTSET": logging.NOTSET,
    }

    return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))
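
# Examples (the values follow from the mapping above):
#
#   get_log_level_from_str("debug")    # -> logging.DEBUG (10)
#   get_log_level_from_str("notice")   # -> logging.INFO + 5 (25)
#   get_log_level_from_str("unknown")  # -> 25, the NOTICE fallback
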
class OnyxLoggingAdapter(logging.LoggerAdapter):
    def process(
        self, msg: str, kwargs: MutableMapping[str, Any]
    ) -> tuple[str, MutableMapping[str, Any]]:
        # If this is an indexing job, add the attempt ID to the log message
        # This helps filter the logs for this specific indexing
        # `while True` is only a structured early-exit: exactly one of the three
        # context blocks below runs, and every path ends in `break`.
        while True:
            pruning_ctx_dict = pruning_ctx.get()
            if len(pruning_ctx_dict) > 0:
                if "request_id" in pruning_ctx_dict:
                    msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"

                if "cc_pair_id" in pruning_ctx_dict:
                    msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
                break

            doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
            if len(doc_permission_sync_ctx_dict) > 0:
                if "request_id" in doc_permission_sync_ctx_dict:
                    msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
                break

            index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
            cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()

            if index_attempt_id is not None:
                msg = f"[Index Attempt: {index_attempt_id}] {msg}"

            if cc_pair_id is not None:
                msg = f"[CC Pair: {cc_pair_id}] {msg}"

            break

        # Add tenant information if it differs from default
        # This will always be the case for authenticated API requests
        if MULTI_TENANT:
            tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
            if tenant_id != POSTGRES_DEFAULT_SCHEMA and tenant_id is not None:
                # Strip tenant_ prefix and take first 8 chars for cleaner logs
                tenant_display = tenant_id.removeprefix(TENANT_ID_PREFIX)
                short_tenant = (
                    tenant_display[:8] if len(tenant_display) > 8 else tenant_display
                )
                msg = f"[t:{short_tenant}] {msg}"

        # For Slack Bot, logs the channel relevant to the request
        channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
        if channel_id:
            msg = f"[Channel ID: {channel_id}] {msg}"

        return msg, kwargs

    def notice(self, msg: Any, *args: Any, **kwargs: Any) -> None:
        # Stacklevel is set to 2 to point to the actual caller of notice instead of here
        self.log(
            logging.getLevelName("NOTICE"), str(msg), *args, **kwargs, stacklevel=2
        )
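
# Illustrative examples of the prefixes process() produces (values are made up):
#
#   [CC Pair: 3] [Index Attempt: 7] Connector batch complete      <- indexing job
#   [Doc Permissions Sync: req-123] Synced 50 documents           <- permission sync
#   [Channel ID: C0123] [t:acme1234] Received Slack message       <- multi-tenant Slack bot
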
class PlainFormatter(logging.Formatter):
    """Prefixes each message with its (uncolored) log level."""

    def format(self, record: logging.LogRecord) -> str:
        levelname = record.levelname
        level_display = f"{levelname}:"
        formatted_message = super().format(record)
        return f"{level_display.ljust(9)} {formatted_message}"


class ColoredFormatter(logging.Formatter):
    """Custom formatter to add colors to log levels."""

    COLORS = {
        "CRITICAL": "\033[91m",  # Red
        "ERROR": "\033[91m",  # Red
        "WARNING": "\033[93m",  # Yellow
        "NOTICE": "\033[94m",  # Blue
        "INFO": "\033[92m",  # Green
        "DEBUG": "\033[96m",  # Cyan
        "NOTSET": "\033[91m",  # Red
    }

    def format(self, record: logging.LogRecord) -> str:
        levelname = record.levelname
        if levelname in self.COLORS:
            prefix = self.COLORS[levelname]
            suffix = "\033[0m"
            formatted_message = super().format(record)
            # Pad the levelname-with-colon to 9 visible characters;
            # ljust(18) accounts for the 9 invisible ANSI color characters.
            level_display = f"{prefix}{levelname}{suffix}:"
            return f"{level_display.ljust(18)} {formatted_message}"
        return super().format(record)


def get_standard_formatter() -> ColoredFormatter:
    """Returns a standard colored logging formatter."""
    return ColoredFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
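
# Illustrative example of a resulting log line (timestamp, filename, and message
# are made up; the level prefix is added by ColoredFormatter and is colorized on
# ANSI-capable terminals):
#
#   NOTICE:   03/14/2025 10:01:12 AM        indexing_pipeline.py  118: [Index Attempt: 7] Indexing 250 documents
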
DANSWER_DOCKER_ENV_STR = "DANSWER_RUNNING_IN_DOCKER"


def is_running_in_container() -> bool:
    return os.getenv(DANSWER_DOCKER_ENV_STR) == "true"


def setup_logger(
    name: str = __name__,
    log_level: int = get_log_level_from_str(),
    extra: MutableMapping[str, Any] | None = None,
) -> OnyxLoggingAdapter:
    logger = logging.getLogger(name)

    # If the logger already has handlers, assume it was already configured and return it.
    if logger.handlers:
        return OnyxLoggingAdapter(logger, extra=extra)

    logger.setLevel(log_level)

    formatter = get_standard_formatter()

    handler = logging.StreamHandler()
    handler.setLevel(log_level)
    handler.setFormatter(formatter)

    logger.addHandler(handler)

    uvicorn_logger = logging.getLogger("uvicorn.access")
    if uvicorn_logger:
        uvicorn_logger.handlers = []
        uvicorn_logger.addHandler(handler)
        uvicorn_logger.setLevel(log_level)

    is_containerized = is_running_in_container()
    if LOG_FILE_NAME and (is_containerized or DEV_LOGGING_ENABLED):
        log_levels = ["debug", "info", "notice"]
        for level in log_levels:
            file_name = (
                f"/var/log/{LOG_FILE_NAME}_{level}.log"
                if is_containerized
                else f"./log/{LOG_FILE_NAME}_{level}.log"
            )
            file_handler = RotatingFileHandler(
                file_name,
                maxBytes=25 * 1024 * 1024,  # 25 MB
                backupCount=5,  # Keep 5 backup files
            )
            file_handler.setLevel(get_log_level_from_str(level))
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

            if uvicorn_logger:
                uvicorn_logger.addHandler(file_handler)

    logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs)  # type: ignore

    return OnyxLoggingAdapter(logger, extra=extra)
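
# Illustrative usage sketch (module and message names are made up): callers
# typically create one module-level logger and use it throughout that module.
#
#   logger = setup_logger(__name__)
#   logger.info("Loaded %d connectors", 12)
#   logger.notice("Indexing finished")  # NOTICE level, registered above as INFO + 5
#
# A second call with the same name returns a new adapter around the
# already-configured logger, since setup_logger returns early when handlers exist.
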
def print_loggers() -> None:
    """Print information about all loggers. Use to debug logging issues."""
    root_logger = logging.getLogger()
    loggers: list[logging.Logger | logging.PlaceHolder] = [root_logger]
    loggers.extend(logging.Logger.manager.loggerDict.values())

    for logger in loggers:
        if isinstance(logger, logging.PlaceHolder):
            # Skip placeholders that aren't actual loggers
            continue

        print(f"Logger: '{logger.name}' (Level: {logging.getLevelName(logger.level)})")
        if logger.handlers:
            for handler in logger.handlers:
                print(f"  Handler: {handler}")
        else:
            print("  No handlers")

        print(f"  Propagate: {logger.propagate}")
        print()


def format_error_for_logging(e: Exception) -> str:
    """Clean error message by removing newlines for better logging."""
    return str(e).replace("\n", " ")
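
# Example (illustrative): multi-line exception messages collapse onto one line,
# e.g. format_error_for_logging(ValueError("bad\nvalue")) -> "bad value"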