Fix EE Import backoff Logic (#1959)

This commit is contained in:
Yuhong Sun 2024-07-27 11:06:11 -07:00 committed by GitHub
parent 6c32821ad4
commit f2f60c9cc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 22 deletions

View File

@ -37,9 +37,24 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any:
module_full = f"ee.{module}" if is_ee else module
try:
return getattr(importlib.import_module(module_full), attribute)
except ModuleNotFoundError:
# try the non-ee version as a fallback
except ModuleNotFoundError as e:
logger.warning(
"Failed to fetch versioned implementation for %s.%s: %s",
module_full,
attribute,
e,
)
if is_ee:
if "ee.danswer" not in str(e):
# If it's a non Danswer related import failure, this is likely because
# a dependent library has not been installed. Should raise this failure
# instead of letting the server start up
raise e
# Use the MIT version as a fallback, this allows us to develop MIT
# versions independently and later add additional EE functionality
# similar to feature flagging
return getattr(importlib.import_module(module), attribute)
raise

View File

@ -6,9 +6,9 @@ from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from ee.danswer.db.analytics import fetch_danswerbot_analytics
from ee.danswer.db.analytics import fetch_per_user_query_analytics
from ee.danswer.db.analytics import fetch_query_analytics
@ -27,7 +27,7 @@ class QueryAnalyticsResponse(BaseModel):
def get_query_analytics(
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[QueryAnalyticsResponse]:
daily_query_usage_info = fetch_query_analytics(
@ -58,7 +58,7 @@ class UserAnalyticsResponse(BaseModel):
def get_user_analytics(
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserAnalyticsResponse]:
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
@ -92,7 +92,7 @@ class DanswerbotAnalyticsResponse(BaseModel):
def get_danswerbot_analytics(
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[DanswerbotAnalyticsResponse]:
daily_danswerbot_info = fetch_danswerbot_analytics(

View File

@ -2,9 +2,9 @@ from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from ee.danswer.db.api_key import ApiKeyDescriptor
from ee.danswer.db.api_key import fetch_api_keys
from ee.danswer.db.api_key import insert_api_key
@ -13,12 +13,13 @@ from ee.danswer.db.api_key import remove_api_key
from ee.danswer.db.api_key import update_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
router = APIRouter(prefix="/admin/api-key")
@router.get("")
def list_api_keys(
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[ApiKeyDescriptor]:
return fetch_api_keys(db_session)
@ -27,7 +28,7 @@ def list_api_keys(
@router.post("")
def create_api_key(
api_key_args: APIKeyArgs,
user: db_models.User | None = Depends(current_admin_user),
user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
return insert_api_key(db_session, api_key_args, user.id if user else None)
@ -36,7 +37,7 @@ def create_api_key(
@router.post("/{api_key_id}/regenerate")
def regenerate_existing_api_key(
api_key_id: int,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
return regenerate_api_key(db_session, api_key_id)
@ -46,7 +47,7 @@ def regenerate_existing_api_key(
def update_existing_api_key(
api_key_id: int,
api_key_args: APIKeyArgs,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ApiKeyDescriptor:
return update_api_key(db_session, api_key_id, api_key_args)
@ -55,7 +56,7 @@ def update_existing_api_key(
@router.delete("/{api_key_id}")
def delete_api_key(
api_key_id: int,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
remove_api_key(db_session, api_key_id)

View File

@ -12,7 +12,6 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_display_email
from danswer.chat.chat_utils import create_chat_chain
@ -22,9 +21,9 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import User
from ee.danswer.db.query_history import fetch_chat_sessions_eagerly_by_time
router = APIRouter()
@ -303,7 +302,7 @@ def get_chat_session_history(
feedback_type: QAFeedbackType | None = None,
start: datetime | None = None,
end: datetime | None = None,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[ChatSessionMinimal]:
return fetch_and_process_chat_session_history_minimal(
@ -320,7 +319,7 @@ def get_chat_session_history(
@router.get("/admin/chat-session-history/{chat_session_id}")
def get_chat_session_admin(
chat_session_id: int,
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
try:
@ -349,7 +348,7 @@ def get_chat_session_admin(
@router.get("/admin/query-history-csv")
def get_query_history_as_csv(
_: db_models.User | None = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
complete_chat_session_history = fetch_and_process_chat_session_history(

View File

@ -4,9 +4,9 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.db.user_group import insert_user_group
from ee.danswer.db.user_group import prepare_user_group_for_deletion
@ -20,7 +20,7 @@ router = APIRouter(prefix="/manage")
@router.get("/admin/user-group")
def list_user_groups(
_: db_models.User = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
user_groups = fetch_user_groups(db_session, only_current=False)
@ -30,7 +30,7 @@ def list_user_groups(
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
_: db_models.User = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
@ -48,7 +48,7 @@ def create_user_group(
def patch_user_group(
user_group_id: int,
user_group: UserGroupUpdate,
_: db_models.User = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
@ -62,7 +62,7 @@ def patch_user_group(
@router.delete("/admin/user-group/{user_group_id}")
def delete_user_group(
user_group_id: int,
_: db_models.User = Depends(current_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try: