mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Prompt user for OpenAI key
This commit is contained in:
@@ -12,6 +12,7 @@ SECTION_CONTINUATION = "section_continuation"
|
||||
ALLOWED_USERS = "allowed_users"
|
||||
ALLOWED_GROUPS = "allowed_groups"
|
||||
NO_AUTH_USER = "FooBarUser" # TODO rework this temporary solution
|
||||
OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
|
@@ -1,15 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
|
||||
from danswer.direct_qa.question_answer import OpenAICompletionQA
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
|
||||
|
||||
def get_default_backend_qa_model(
|
||||
internal_model: str = INTERNAL_MODEL_VERSION,
|
||||
internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: dict[str, Any]
|
||||
) -> QAModel:
|
||||
if internal_model == "openai-completion":
|
||||
return OpenAICompletionQA()
|
||||
return OpenAICompletionQA(**kwargs)
|
||||
elif internal_model == "openai-chat-completion":
|
||||
return OpenAIChatCompletionQA()
|
||||
return OpenAIChatCompletionQA(**kwargs)
|
||||
else:
|
||||
raise ValueError("Wrong internal QA model set.")
|
||||
|
21
backend/danswer/direct_qa/key_validation.py
Normal file
21
backend/danswer/direct_qa/key_validation.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.question_answer import OpenAIQAModel
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from openai.error import AuthenticationError
|
||||
|
||||
|
||||
def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
|
||||
if not openai_api_key:
|
||||
return False
|
||||
|
||||
qa_model = get_default_backend_qa_model(api_key=openai_api_key)
|
||||
if not isinstance(qa_model, OpenAIQAModel):
|
||||
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
|
||||
|
||||
try:
|
||||
qa_model.answer_question("Do not respond", [])
|
||||
return True
|
||||
except AuthenticationError:
|
||||
return False
|
@@ -3,6 +3,7 @@ import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
@@ -17,6 +18,7 @@ from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import DOCUMENT_ID
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
from danswer.configs.constants import SOURCE_LINK
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
@@ -29,15 +31,19 @@ from danswer.direct_qa.qa_prompts import json_chat_processor
|
||||
from danswer.direct_qa.qa_prompts import json_processor
|
||||
from danswer.direct_qa.qa_prompts import QUOTE_PAT
|
||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.utils.logging import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
from danswer.utils.timing import log_function_time
|
||||
from openai.error import AuthenticationError
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
openai.api_key = OPENAI_API_KEY
|
||||
|
||||
def get_openai_api_key():
|
||||
return OPENAI_API_KEY or get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
|
||||
|
||||
|
||||
def get_json_line(json_dict: dict) -> str:
|
||||
@@ -181,16 +187,23 @@ def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class OpenAICompletionQA(QAModel):
|
||||
# used to check if the QAModel is an OpenAI model
|
||||
class OpenAIQAModel(QAModel):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAICompletionQA(OpenAIQAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: Callable[[str, list[str]], str] = json_processor,
|
||||
model_version: str = OPENAI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
api_key: str | None = None,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
@@ -202,6 +215,7 @@ class OpenAICompletionQA(QAModel):
|
||||
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
api_key=self.api_key,
|
||||
prompt=filled_prompt,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
@@ -214,6 +228,9 @@ class OpenAICompletionQA(QAModel):
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
@@ -232,6 +249,7 @@ class OpenAICompletionQA(QAModel):
|
||||
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
api_key=self.api_key,
|
||||
prompt=filled_prompt,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
@@ -263,7 +281,9 @@ class OpenAICompletionQA(QAModel):
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
@@ -276,7 +296,7 @@ class OpenAICompletionQA(QAModel):
|
||||
yield quotes_dict
|
||||
|
||||
|
||||
class OpenAIChatCompletionQA(QAModel):
|
||||
class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: Callable[
|
||||
@@ -285,11 +305,13 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
model_version: str = OPENAI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
reflexion_try_count: int = 0,
|
||||
api_key: str | None = None,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.reflexion_try_count = reflexion_try_count
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
@@ -302,6 +324,7 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
for _ in range(self.reflexion_try_count + 1):
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
api_key=self.api_key,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
@@ -316,6 +339,9 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.warning(f"Model failure for query: {query}")
|
||||
@@ -335,6 +361,7 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
api_key=self.api_key,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
@@ -370,7 +397,9 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
@@ -36,3 +37,11 @@ class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
|
||||
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
|
||||
with open(self.dir_path / key) as f:
|
||||
return cast(JSON_ro, json.load(f))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
file_path = self.dir_path / key
|
||||
if not file_path.exists():
|
||||
raise ConfigNotFoundError
|
||||
lock = _get_file_lock(file_path)
|
||||
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
|
||||
os.remove(file_path)
|
||||
|
@@ -21,3 +21,7 @@ class DynamicConfigStore:
|
||||
@abc.abstractmethod
|
||||
def load(self, key: str) -> JSON_ro:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import NO_AUTH_USER
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.connectors.factory import build_connector
|
||||
from danswer.connectors.google_drive.connector_auth import get_auth_url
|
||||
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
|
||||
@@ -17,7 +19,13 @@ from danswer.db.index_attempt import insert_index_attempt
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa.key_validation import (
|
||||
check_openai_api_key_is_valid,
|
||||
)
|
||||
from danswer.direct_qa.question_answer import get_openai_api_key
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.server.models import ApiKey
|
||||
from danswer.server.models import AuthStatus
|
||||
from danswer.server.models import AuthUrl
|
||||
from danswer.server.models import GDriveCallback
|
||||
@@ -27,6 +35,7 @@ from danswer.server.models import ListIndexAttemptsResponse
|
||||
from danswer.utils.logging import setup_logger
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/admin")
|
||||
@@ -140,3 +149,59 @@ def list_all_index_attempts(
|
||||
for index_attempt in index_attempts
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.head("/openai-api-key/validate")
|
||||
def validate_existing_openai_api_key(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
is_valid = False
|
||||
try:
|
||||
openai_api_key = get_openai_api_key()
|
||||
is_valid = check_openai_api_key_is_valid(openai_api_key)
|
||||
except ConfigNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key provided")
|
||||
|
||||
|
||||
@router.get("/openai-api-key")
|
||||
def get_openai_api_key_from_dynamic_config_store(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> ApiKey:
|
||||
"""
|
||||
NOTE: Only gets value from dynamic config store as to not expose env variables.
|
||||
"""
|
||||
try:
|
||||
# only get last 4 characters of key to not expose full key
|
||||
return ApiKey(
|
||||
api_key=cast(
|
||||
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
|
||||
)[-4:]
|
||||
)
|
||||
except ConfigNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
|
||||
|
||||
@router.post("/openai-api-key")
|
||||
def store_openai_api_key(
|
||||
request: ApiKey,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
is_valid = check_openai_api_key_is_valid(request.api_key)
|
||||
if not is_valid:
|
||||
raise HTTPException(400, "Invalid API key provided")
|
||||
get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
||||
|
||||
@router.delete("/openai-api-key")
|
||||
def delete_openai_api_key(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY)
|
||||
|
@@ -73,3 +73,7 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
|
||||
class ListIndexAttemptsResponse(BaseModel):
|
||||
index_attempts: list[IndexAttemptSnapshot]
|
||||
|
||||
|
||||
class ApiKey(BaseModel):
|
||||
api_key: str
|
||||
|
Reference in New Issue
Block a user