mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-18 05:41:58 +01:00
* Update text embedding model to version 005 and enhance embedding retrieval process * re * Fix formatting issues * Add support for Bedrock reranking provider and AWS credentials handling * fix: improve AWS key format validation and error messages * Fix vertex embedding model crash * feat: add environment template for local development setup * Add display name for Claude 3.7 Sonnet model * Add display names for Gemini 2.0 models and update Claude 3.7 Sonnet entry * Fix ruff errors by ensuring lines are within 130 characters * revert to currently default onyx browser settings * add / fix boto requirements --------- Co-authored-by: ferdinand loesch <f.loesch@sportradar.com> Co-authored-by: Ferdinand Loesch <ferdinandloesch@me.com> Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
import asyncio
|
|
import time
|
|
from collections.abc import Callable
|
|
from collections.abc import Generator
|
|
from collections.abc import Iterator
|
|
from functools import wraps
|
|
from typing import Any
|
|
from typing import cast
|
|
from typing import TypeVar
|
|
|
|
import torch
|
|
|
|
from model_server.constants import GPUStatus
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Module-level logger; simple_log_function_time emits its timing lines through
# it (logger.debug when debug_only=True, otherwise logger.notice).
logger = setup_logger()
|
|
|
|
# Type variable bound to any callable. simple_log_function_time uses it so its
# decorator is typed as returning the same callable type it receives.
F = TypeVar("F", bound=Callable)
|
|
# Type variable bound to callables whose return value is a Generator or Iterator.
# NOTE(review): not referenced in the code visible here; presumably intended for
# decorators that wrap generator functions elsewhere. Confirm before removing.
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
|
|
|
|
|
|
def simple_log_function_time(
    func_name: str | None = None,
    debug_only: bool = False,
    include_args: bool = False,
) -> Callable[[F], F]:
    """Build a decorator that logs how long each call to the wrapped function takes.

    Both plain functions and ``async def`` coroutine functions are supported.
    For coroutine functions the wrapper awaits the call, so the measured time
    covers the full coroutine execution rather than just its creation.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock.
    That makes durations immune to wall-clock adjustments, which
    ``time.time()`` is not.

    Args:
        func_name: Name shown in the log line. Defaults to the wrapped
            function's ``__name__``.
        debug_only: If True, log with ``logger.debug``; otherwise log with
            ``logger.notice``.
        include_args: If True, append the call's positional and keyword
            arguments (their string representations, verbatim) to the log
            line. Do not enable this for functions that receive secrets.

    Returns:
        A decorator that wraps a function (preserving its metadata via
        ``functools.wraps``) and logs "<name>[ args=... kwargs=...] took
        <seconds> seconds" after each successful call. If the wrapped call
        raises, the exception propagates and nothing is logged.
    """

    def decorator(func: F) -> F:
        # Resolved once at decoration time instead of on every call.
        log_name = func_name or func.__name__

        def _log_elapsed(
            start_time: float, args: tuple[Any, ...], kwargs: dict[str, Any]
        ) -> None:
            # Shared by the sync and async wrappers so both emit identical lines.
            elapsed_time_str = str(time.perf_counter() - start_time)
            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
            if debug_only:
                logger.debug(final_log)
            else:
                logger.notice(final_log)

        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
                start_time = time.perf_counter()
                result = await func(*args, **kwargs)
                _log_elapsed(start_time, args, kwargs)
                return result

            return cast(F, wrapped_async_func)

        @wraps(func)
        def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            _log_elapsed(start_time, args, kwargs)
            return result

        return cast(F, wrapped_sync_func)

    return decorator
|
|
|
|
|
|
def get_gpu_type() -> str:
    """Report which accelerator backend torch can use in this process.

    CUDA is checked first and wins over Apple MPS. MPS is only queried when
    CUDA is unavailable. If neither backend is usable, GPUStatus.NONE is
    returned.
    """
    if torch.cuda.is_available():
        detected = GPUStatus.CUDA
    elif torch.backends.mps.is_available():
        detected = GPUStatus.MAC_MPS
    else:
        detected = GPUStatus.NONE
    return detected
|
|
|
|
|
|
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
    """Parse an Onyx-formatted AWS credential string into its components.

    This is an Onyx-specific packing of the AWS secrets used for Bedrock:
    a single string of four underscore-separated fields.

    Args:
        api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'.

    Returns:
        Tuple of (access_key, secret_key, region).

    Raises:
        ValueError: If the key does not start with 'aws', or does not split on
            '_' into exactly four parts.
    """
    # NOTE(review): only the literal prefix "aws" is checked, so e.g.
    # "awsX_a_b_c" is accepted and its first field silently discarded.
    if not api_key.startswith("aws"):
        raise ValueError("API key must start with 'aws' prefix")

    # '_' is the field separator, so no component may itself contain '_'.
    # NOTE(review): presumably fine because AWS key IDs, secret keys and region
    # names don't use underscores; confirm if new credential types are added.
    parts = api_key.split("_")
    if len(parts) != 4:
        # len(parts) - 1 excludes the leading 'aws' prefix from the reported count.
        raise ValueError(
            f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts; "
            "this is an onyx specific format for formatting the aws secrets for bedrock"
        )

    # Cannot fail: the length check above guarantees exactly four parts.
    _, aws_access_key_id, aws_secret_access_key, aws_region = parts
    return aws_access_key_id, aws_secret_access_key, aws_region
|