import contextvars
import logging
import os
from collections.abc import MutableMapping
from logging.handlers import RotatingFileHandler
from typing import Any

from shared_configs.configs import DEV_LOGGING_ENABLED
from shared_configs.configs import LOG_FILE_NAME
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
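
# Register a custom NOTICE level at numeric value 25, between INFO (20) and WARNING (30).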
logging.addLevelName(logging.INFO + 5, "NOTICE")

pruning_ctx: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
    "pruning_ctx", default=dict()
)

doc_permission_sync_ctx: contextvars.ContextVar[
    dict[str, Any]
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())


class LoggerContextVars:
    @staticmethod
    def reset() -> None:
        pruning_ctx.set(dict())
        doc_permission_sync_ctx.set(dict())


class TaskAttemptSingleton:
    """Used to tell if this process is an indexing job, and if so, what the
    unique identifier for this indexing attempt is. For things like the API server,
    the main background job (scheduler), etc., this will not be used."""

    _INDEX_ATTEMPT_ID: None | int = None
    _CONNECTOR_CREDENTIAL_PAIR_ID: None | int = None

    @classmethod
    def get_index_attempt_id(cls) -> None | int:
        return cls._INDEX_ATTEMPT_ID

    @classmethod
    def get_connector_credential_pair_id(cls) -> None | int:
        return cls._CONNECTOR_CREDENTIAL_PAIR_ID

    @classmethod
    def set_cc_and_index_id(
        cls, index_attempt_id: int, connector_credential_pair_id: int
    ) -> None:
        cls._INDEX_ATTEMPT_ID = index_attempt_id
        cls._CONNECTOR_CREDENTIAL_PAIR_ID = connector_credential_pair_id
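
# Illustrative usage (hypothetical call site at the start of an indexing task;
# `attempt` and `cc_pair` are placeholder names):
#     TaskAttemptSingleton.set_cc_and_index_id(
#         index_attempt_id=attempt.id, connector_credential_pair_id=cc_pair.id
#     )
# Afterwards, OnyxLoggingAdapter prefixes log messages from this process with the
# CC pair and index attempt IDs.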


def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
    """Map a log level name (e.g. "info", "NOTICE") to its numeric value.

    Unrecognized names fall back to the custom NOTICE level."""
    log_level_dict = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "NOTICE": logging.getLevelName("NOTICE"),
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NOTSET": logging.NOTSET,
    }

    return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))


class OnyxLoggingAdapter(logging.LoggerAdapter):
    def process(
        self, msg: str, kwargs: MutableMapping[str, Any]
    ) -> tuple[str, MutableMapping[str, Any]]:
        # If this is an indexing job, add the attempt ID to the log message
        # This helps filter the logs for this specific indexing attempt
        # The single-pass `while True` exists only so `break` can short-circuit
        # once one of the mutually exclusive context prefixes has been applied
        while True:
            pruning_ctx_dict = pruning_ctx.get()
            if len(pruning_ctx_dict) > 0:
                if "request_id" in pruning_ctx_dict:
                    msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"

                if "cc_pair_id" in pruning_ctx_dict:
                    msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
                break

            doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
            if len(doc_permission_sync_ctx_dict) > 0:
                if "request_id" in doc_permission_sync_ctx_dict:
                    msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
                break

            index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
            cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()

            if index_attempt_id is not None:
                msg = f"[Index Attempt: {index_attempt_id}] {msg}"

            if cc_pair_id is not None:
                msg = f"[CC Pair: {cc_pair_id}] {msg}"
            break

        # Add tenant information if it differs from the default
        # This will always be the case for authenticated API requests
        if MULTI_TENANT:
            tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
            if tenant_id != POSTGRES_DEFAULT_SCHEMA and tenant_id is not None:
                # Strip the tenant_ prefix and take the first 8 chars for cleaner logs
                tenant_display = tenant_id.removeprefix(TENANT_ID_PREFIX)
                short_tenant = (
                    tenant_display[:8] if len(tenant_display) > 8 else tenant_display
                )
                msg = f"[t:{short_tenant}] {msg}"

        # For the Slack bot, log the channel relevant to the request
        channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
        if channel_id:
            msg = f"[Channel ID: {channel_id}] {msg}"

        return msg, kwargs

    def notice(self, msg: Any, *args: Any, **kwargs: Any) -> None:
        # Stacklevel is set to 2 so the record points at the actual caller of notice, not this wrapper
        self.log(
            logging.getLevelName("NOTICE"), str(msg), *args, **kwargs, stacklevel=2
        )


class PlainFormatter(logging.Formatter):
    """Prepends the log level name to each formatted message."""

    def format(self, record: logging.LogRecord) -> str:
        levelname = record.levelname
        level_display = f"{levelname}:"
        formatted_message = super().format(record)
        return f"{level_display.ljust(9)} {formatted_message}"


class ColoredFormatter(logging.Formatter):
    """Custom formatter to add colors to log levels."""

    COLORS = {
        "CRITICAL": "\033[91m",  # Red
        "ERROR": "\033[91m",  # Red
        "WARNING": "\033[93m",  # Yellow
        "NOTICE": "\033[94m",  # Blue
        "INFO": "\033[92m",  # Green
        "DEBUG": "\033[96m",  # Cyan
        "NOTSET": "\033[91m",  # Red
    }

    def format(self, record: logging.LogRecord) -> str:
        levelname = record.levelname
        if levelname in self.COLORS:
            prefix = self.COLORS[levelname]
            suffix = "\033[0m"
            formatted_message = super().format(record)
            # Ensure the levelname with colon is 9 characters long
            # (accounts for the extra characters added by the color codes)
            level_display = f"{prefix}{levelname}{suffix}:"
            return f"{level_display.ljust(18)} {formatted_message}"
        return super().format(record)


def get_standard_formatter() -> ColoredFormatter:
    """Returns a standard colored logging formatter."""
    return ColoredFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )


DANSWER_DOCKER_ENV_STR = "DANSWER_RUNNING_IN_DOCKER"


def is_running_in_container() -> bool:
    return os.getenv(DANSWER_DOCKER_ENV_STR) == "true"


def setup_logger(
    name: str = __name__,
    log_level: int = get_log_level_from_str(),
    extra: MutableMapping[str, Any] | None = None,
) -> OnyxLoggingAdapter:
    logger = logging.getLogger(name)

    # If the logger already has handlers, assume it was already configured and return it.
    if logger.handlers:
        return OnyxLoggingAdapter(logger, extra=extra)

    logger.setLevel(log_level)

    formatter = get_standard_formatter()

    handler = logging.StreamHandler()
    handler.setLevel(log_level)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    uvicorn_logger = logging.getLogger("uvicorn.access")
    if uvicorn_logger:
        uvicorn_logger.handlers = []
        uvicorn_logger.addHandler(handler)
        uvicorn_logger.setLevel(log_level)

    is_containerized = is_running_in_container()
    if LOG_FILE_NAME and (is_containerized or DEV_LOGGING_ENABLED):
        # One rotating log file per level threshold (debug/info/notice)
        log_levels = ["debug", "info", "notice"]
        for level in log_levels:
            file_name = (
                f"/var/log/{LOG_FILE_NAME}_{level}.log"
                if is_containerized
                else f"./log/{LOG_FILE_NAME}_{level}.log"
            )
            file_handler = RotatingFileHandler(
                file_name,
                maxBytes=25 * 1024 * 1024,  # 25 MB
                backupCount=5,  # Keep 5 backup files
            )
            file_handler.setLevel(get_log_level_from_str(level))
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

            if uvicorn_logger:
                uvicorn_logger.addHandler(file_handler)

    # Also expose notice() on the underlying Logger so callers holding the bare
    # Logger (rather than the adapter) can emit NOTICE-level messages.
    logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs)  # type: ignore

    return OnyxLoggingAdapter(logger, extra=extra)


def print_loggers() -> None:
    """Print information about all loggers. Use to debug logging issues."""
    root_logger = logging.getLogger()
    loggers: list[logging.Logger | logging.PlaceHolder] = [root_logger]
    loggers.extend(logging.Logger.manager.loggerDict.values())

    for logger in loggers:
        if isinstance(logger, logging.PlaceHolder):
            # Skip placeholders that aren't actual loggers
            continue

        print(f"Logger: '{logger.name}' (Level: {logging.getLevelName(logger.level)})")
        if logger.handlers:
            for handler in logger.handlers:
                print(f" Handler: {handler}")
        else:
            print(" No handlers")

        print(f" Propagate: {logger.propagate}")
        print()


def format_error_for_logging(e: Exception) -> str:
    """Clean error message by removing newlines for better logging."""
    return str(e).replace("\n", " ")
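

# Minimal usage sketch (illustrative only; the logger name below is arbitrary and the
# output depends on LOG_LEVEL and the other shared_configs settings being available):
if __name__ == "__main__":
    demo_logger = setup_logger("logger_demo")

    demo_logger.info("plain INFO message")
    demo_logger.notice("custom NOTICE-level message (INFO + 5)")

    # Simulate an indexing process: later messages gain index-attempt/CC-pair prefixes.
    TaskAttemptSingleton.set_cc_and_index_id(
        index_attempt_id=1, connector_credential_pair_id=2
    )
    demo_logger.info("message emitted while 'indexing'")

    print_loggers()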