Support for Request-accessed GenAI Models (#270)

This commit is contained in:
Yuhong Sun
2023-08-06 18:31:47 -07:00
committed by GitHub
parent 0e667d3384
commit 3bfc72484d
19 changed files with 613 additions and 351 deletions

View File

@@ -138,12 +138,6 @@ CHUNK_WORD_OVERLAP = 5
CHUNK_MAX_CHAR_OVERLAP = 50
#####
# Other API Keys
#####
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
#####
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
#####

View File

@@ -12,7 +12,7 @@ SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
METADATA = "metadata"
OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
@@ -30,3 +30,26 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
NOTION = "notion"
class DanswerGenAIModel(str, Enum):
"""This represents the internal Danswer GenAI model which determines the class that is used
to generate responses to the user query. Different models/services require different internal
handling, this allows for modularity of implementation within Danswer"""
OPENAI = "openai-completion"
OPENAI_CHAT = "openai-chat-completion"
GPT4ALL = "gpt4all-completion"
GPT4ALL_CHAT = "gpt4all-chat-completion"
HUGGINGFACE = "huggingface-inference-completion"
HUGGINGFACE_CHAT = "huggingface-inference-chat-completion"
REQUEST = "request-completion"
class ModelHostType(str, Enum):
"""For GenAI models interfaced via requests, different services have different
expectations for what fields are included in the request"""
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
HUGGINGFACE = "huggingface" # HuggingFace test-generation Inference API
# TODO support for Azure, AWS, GCP GenAI model hosting
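For illustration only (not part of this commit): a minimal sketch of turning the configured model string into these enums; resolve_internal_model is a hypothetical helper, not something added by this change.
import os

from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType


def resolve_internal_model(raw: str | None) -> DanswerGenAIModel:
    # Hypothetical helper: fall back to the OpenAI chat model if the value is unrecognized
    try:
        return DanswerGenAIModel(raw)
    except ValueError:
        return DanswerGenAIModel.OPENAI_CHAT


internal_model = resolve_internal_model(os.environ.get("INTERNAL_MODEL_VERSION"))
host_type = ModelHostType(os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value))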

View File

@@ -1,4 +1,8 @@
import os
from enum import Enum
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
@@ -30,35 +34,46 @@ CROSS_EMBED_CONTEXT_SIZE = 512
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# QA Model API Configs
# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
# Valid list:
# - openai-completion
# - openai-chat-completion
# - gpt4all-completion -> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
# - gpt4all-chat-completion-> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
# To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
# These support HuggingFace Inference API, Inference Endpoints and servers running the text-generation-inference backend
# - huggingface-inference-completion
# - huggingface-inference-chat-completion
#####
# Generative AI Model Configs
#####
# Other models should work as well, check the library/API compatibility.
# But these are the models that have been verified to work with the existing prompts.
# Using a different model may require some prompt tuning. See qa_prompts.py
VERIFIED_MODELS = {
DanswerGenAIModel.OPENAI: ["text-davinci-003"],
DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
# The "chat" model below is actually "instruction finetuned" and does not support conversational
DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
}
# Sets the internal Danswer model class to use
INTERNAL_MODEL_VERSION = os.environ.get(
"INTERNAL_MODEL_VERSION", "openai-chat-completion"
"INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
)
# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", "")
# If using GPT4All or OpenAI, specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(
"GEN_AI_MODEL_VERSION",
VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
)
# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
# set the two below to specify
# - Where to hit the endpoint
# - How should the request be formed
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
# Set this to be enough for an answer + quotes
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
# Use HuggingFace API Token for Huggingface inference client
GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
# Use the conversational API with the huggingface-inference-chat-completion internal model
# Note - this only works with models that support conversational interfaces
GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
)
# Disable streaming responses. Set this to true to "polyfill" streaming for models that don't support streaming
GEN_AI_HUGGINGFACE_DISABLE_STREAM = (
os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true"
)
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
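For illustration only (not part of this commit): a minimal sketch of how these new settings might be supplied for a request-accessed model; the endpoint URL and token below are placeholders.
import os

# Placeholder values for a self-hosted HuggingFace-style endpoint;
# set these before the Danswer config modules are imported.
os.environ["INTERNAL_MODEL_VERSION"] = "request-completion"  # DanswerGenAIModel.REQUEST
os.environ["GEN_AI_ENDPOINT"] = "https://my-endpoint.example.com"  # hypothetical URL
os.environ["GEN_AI_HOST_TYPE"] = "huggingface"  # ModelHostType.HUGGINGFACE
os.environ["GEN_AI_API_KEY"] = "hf_..."  # only if the hosted endpoint requires a token
os.environ["GEN_AI_MAX_OUTPUT_TOKENS"] = "512"  # enough for an answer plus quotes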

View File

@@ -1,25 +1,30 @@
from typing import Any
import pkg_resources
from openai.error import AuthenticationError
from openai.error import Timeout
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import (
GEN_AI_HUGGINGFACE_API_TOKEN,
INTERNAL_MODEL_VERSION,
)
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.huggingface_inference import (
HuggingFaceInferenceChatCompletionQA,
HuggingFaceInferenceCompletionQA,
)
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
logger = setup_logger()
# Imports commented out temporarily due to incompatibility of gpt4all with M1 Mac hardware currently
# from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
# from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
def check_model_api_key_is_valid(model_api_key: str) -> bool:
@@ -35,32 +40,66 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
return True
except AuthenticationError:
return False
except Timeout:
pass
except Exception as e:
logger.warning(f"GenAI API key failed for the following reason: {e}")
return False
def get_default_backend_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
api_key: str | None = None,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
**kwargs: Any
**kwargs: Any,
) -> QAModel:
if internal_model == "openai-completion":
if not api_key:
try:
api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
pass
if internal_model in [
DanswerGenAIModel.GPT4ALL.value,
DanswerGenAIModel.GPT4ALL_CHAT.value,
]:
# gpt4all is not compatible with M1 Mac hardware as of Aug 2023
pkg_resources.get_distribution("gpt4all")
if internal_model == DanswerGenAIModel.OPENAI.value:
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == "openai-chat-completion":
elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == "huggingface-inference-completion":
api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == "huggingface-inference-chat-completion":
api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs)
# Note GPT4All is not supported for M1 Mac machines currently, removing until support is added
# elif internal_model == "gpt4all-completion":
# return GPT4AllCompletionQA(**kwargs)
# elif internal_model == "gpt4all-chat-completion":
# return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
return GPT4AllCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.REQUEST.value:
if endpoint is None or model_host_type is None:
raise ValueError(
"Request based GenAI model requires an endpoint and host type"
)
if model_host_type == ModelHostType.HUGGINGFACE.value:
# Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
# With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
# This is the sum of input and output tokens, so cannot take in full Danswer context
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
prompt_processor=WeakModelFreeformProcessor(),
timeout=timeout,
)
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
timeout=timeout,
)
else:
raise UnknownModelError(internal_model)
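For illustration only (not part of this commit): a hedged usage sketch of the factory above for the new request-based path; the endpoint is a placeholder and an empty context list is passed only to show the call shape.
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.direct_qa import get_default_backend_qa_model

qa_model = get_default_backend_qa_model(
    internal_model=DanswerGenAIModel.REQUEST.value,
    endpoint="https://my-endpoint.example.com",  # placeholder
    model_host_type=ModelHostType.HUGGINGFACE.value,
    api_key=None,  # falls back to the dynamic config store if a key was saved via the UI
)
answer, quotes = qa_model.answer_question("What is Danswer?", context_docs=[])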

View File

@@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any
from gpt4all import GPT4All # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -18,9 +16,30 @@ from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
class DummyGPT4All:
"""In the case of import failure due to M1 Mac incompatibility,
so this module does not raise exceptions during server startup,
as long as this module isn't actually used"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("GPT4All library not installed.")
try:
from gpt4all import GPT4All # type:ignore
except ImportError:
logger.warning(
"GPT4All library not installed. "
"If you wish to run GPT4ALL (in memory) to power Danswer's "
"Generative AI features, please install gpt4all==1.0.5. "
"As of Aug 2023, this library is not compatible with M1 Mac."
)
GPT4All = DummyGPT4All
GPT4ALL_MODEL: GPT4All | None = None
@@ -56,6 +75,10 @@ class GPT4AllCompletionQA(QAModel):
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@property
def requires_api_key(self) -> bool:
return False
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@@ -117,6 +140,13 @@ class GPT4AllChatCompletionQA(QAModel):
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@property
def requires_api_key(self) -> bool:
return False
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]

View File

@@ -0,0 +1,189 @@
from collections.abc import Generator
from typing import Any
from huggingface_hub import InferenceClient # type:ignore
from huggingface_hub.utils import HfHubHTTPError # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"do_sample": False,
"seed": 69, # For reproducibility
**kwargs,
}
class HuggingFaceCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = FreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_stream = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class HuggingFaceChatCompletionQA(QAModel):
"""Chat in this class refers to the HuggingFace Conversational API.
Not to be confused with Chat/Instruction finetuned models.
Llama2-chat... means it is an Instruction finetuned model, not necessarily that
it supports Conversational API"""
def __init__(
self,
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@staticmethod
def _convert_chat_to_hf_conversational_format(
dialog: list[dict[str, str]]
) -> tuple[str, list[str], list[str]]:
if dialog[-1]["role"] != "user":
raise Exception(
"Last message in conversational dialog must be User message"
)
user_message = dialog[-1]["content"]
dialog = dialog[0:-1]
generated_responses = []
past_user_inputs = []
for message in dialog:
# HuggingFace inference client doesn't support system messages today
# so lumping them in with user messages
if message["role"] in ["user", "system"]:
past_user_inputs += [message["content"]]
else:
generated_responses += [message["content"]]
return user_message, generated_responses, past_user_inputs
def _get_hf_model_output(
self, query: str, context_docs: list[InferenceChunk]
) -> str:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
(
query,
past_responses,
past_inputs,
) = self._convert_chat_to_hf_conversational_format(filled_prompt)
logger.debug(f"Last Input: {query}")
logger.debug(f"Past Inputs: {past_inputs}")
logger.debug(f"Past Responses: {past_responses}")
try:
model_output = self.client.conversational(
query,
generated_responses=past_responses,
past_user_inputs=past_inputs,
parameters={"max_length": self.max_output_tokens},
)
except HfHubHTTPError as model_error:
if model_error.response.status_code == 422:
raise ValueError(
"Selected HuggingFace Model does not support HuggingFace Conversational API,"
"try using the huggingface-inference-completion in Danswer instead"
)
raise
logger.debug(model_output)
return model_output
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
model_output = self._get_hf_model_output(query, context_docs)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
"""As of Aug 2023, HF conversational (chat) endpoints do not support streaming
So here it is faked by streaming characters within Danswer from the model output
"""
model_output = self._get_hf_model_output(query, context_docs)
model_stream = simulate_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
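For illustration only (not part of this commit): a small sketch of what _convert_chat_to_hf_conversational_format produces for a short role/content dialog like the ones built by the chat prompt processors.
dialog = [
    {"role": "system", "content": "Answer using only the provided documents."},
    {"role": "user", "content": "Where are the release notes kept?"},
    {"role": "assistant", "content": "They live in the docs repository."},
    {"role": "user", "content": "Which file exactly?"},
]

last_user_message, generated_responses, past_user_inputs = (
    HuggingFaceChatCompletionQA._convert_chat_to_hf_conversational_format(dialog)
)
# last_user_message == "Which file exactly?"
# generated_responses == ["They live in the docs repository."]
# past_user_inputs holds the system and earlier user messages, in order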

View File

@@ -1,248 +0,0 @@
from collections.abc import Generator
from typing import Any, Iterator
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import (
GEN_AI_HUGGINGFACE_DISABLE_STREAM,
GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL,
GEN_AI_MAX_OUTPUT_TOKENS,
)
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import (
ChatPromptProcessor,
JsonChatProcessor,
JsonProcessor,
)
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from huggingface_hub import InferenceClient
logger = setup_logger()
def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"do_sample": False,
"seed": 69, # For reproducibility
**kwargs,
}
def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str:
"""
Utility to convert chat dialog to a text-generation prompt for models tuned for chat.
Note - This is a "best guess" attempt at a generic completions prompt for chat
completion models. It isn't optimized for all chat trained models, but tries
to serialize to a format that most models understand.
Models like Llama2-chat have been optimized for certain formatting of chat
completions, and this function doesn't take that into account, so you won't
always get the best possible outcome.
TODO - Add the ability to pass custom formatters for chat dialogue
"""
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.
If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly.
If you don't know the answer to a question, don't share false information."""
prompt = ""
if dialog[0]["role"] != "system":
dialog = [
{
"role": "system",
"content": DEFAULT_SYSTEM_PROMPT,
}
] + dialog
for message in dialog:
prompt += f"{message['role'].upper()}: {message['content']}\n"
prompt += "ASSISTANT:"
return prompt
def _mock_streaming_response(tokens: str) -> Generator[str, None, None]:
"""Utility to mock a streaming response"""
for token in tokens:
yield token
class HuggingFaceInferenceCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
model_stream = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
else:
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
model_stream = _mock_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class HuggingFaceInferenceChatCompletionQA(QAModel):
def __init__(
self,
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@staticmethod
def convert_dialog_to_conversational_format(
dialog: list[dict[str, str]]
) -> tuple[str, list[str], list[str]]:
if dialog[-1]["role"] != "user":
raise Exception(
"Last message in conversational dialog must be User message"
)
user_message = dialog[-1]["content"]
dialog = dialog[0:-1]
generated_responses = []
past_user_inputs = []
for message in dialog:
# HuggingFace inference client doesn't support system messages today
# so lumping them in with user messages
if message["role"] in ["user", "system"]:
past_user_inputs += [message["content"]]
else:
generated_responses += [message["content"]]
return user_message, generated_responses, past_user_inputs
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
(
message,
generated_responses,
past_user_inputs,
) = self.convert_dialog_to_conversational_format(filled_prompt)
model_output = self.client.conversational(
message,
generated_responses=generated_responses,
past_user_inputs=past_user_inputs,
parameters={"max_length": self.max_output_tokens},
)
else:
chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
logger.debug(chat_prompt)
model_output = self.client.text_generation(
chat_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
raise Exception(
"Conversational API is not available with streaming enabled. Please either "
+ "disable streaming, or disable using conversational API."
)
chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
logger.debug(chat_prompt)
model_stream = self.client.text_generation(
chat_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
else:
if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
(
message,
generated_responses,
past_user_inputs,
) = self.convert_dialog_to_conversational_format(filled_prompt)
model_output = self.client.conversational(
message,
generated_responses=generated_responses,
past_user_inputs=past_user_inputs,
parameters={"max_length": self.max_output_tokens},
)
else:
chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
logger.debug(chat_prompt)
model_output = self.client.text_generation(
chat_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens
),
)
logger.debug(model_output)
model_stream = _mock_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -23,6 +23,12 @@ class DanswerQuote:
class QAModel:
@property
def requires_api_key(self) -> bool:
"""Is this model protected by security features
Does it need an api key to access the model for inference"""
return True
def warm_up_model(self) -> None:
"""This is called during server start up to load the models into memory
pass if model is accessed via API"""

View File

@@ -15,8 +15,6 @@ from openai.error import Timeout
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
@@ -28,9 +26,9 @@ from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -41,15 +39,9 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable)
def get_openai_api_key() -> str:
return OPENAI_API_KEY or cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
)
def _ensure_openai_api_key(api_key: str | None) -> str:
try:
return api_key or get_openai_api_key()
return api_key or get_gen_ai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
@@ -131,7 +123,7 @@ class OpenAICompletionQA(OpenAIQAModel):
self.timeout = timeout
self.include_metadata = include_metadata
try:
self.api_key = api_key or get_openai_api_key()
self.api_key = api_key or get_gen_ai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()

View File

@@ -3,6 +3,7 @@ import math
import re
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import Optional
from typing import Tuple
@@ -12,14 +13,17 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
@@ -27,6 +31,12 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def get_gen_ai_api_key() -> str:
return GEN_AI_API_KEY or cast(
str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
)
def structure_quotes_for_response(
quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
@@ -246,3 +256,9 @@ def process_model_tokens(
quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
yield structure_quotes_for_response(quotes)
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character"""
for token in model_out:
yield token
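For illustration only (not part of this commit): get_gen_ai_api_key prefers the GEN_AI_API_KEY environment value and only falls back to the dynamic config store, while simulate_streaming_response yields the finished output one character at a time so non-streaming models can feed the same token-processing path. A tiny sketch of the latter:
from danswer.direct_qa.qa_utils import simulate_streaming_response

# Each yielded "token" is a single character of the completed model output
chunks = list(simulate_streaming_response("Answer: yes"))
assert "".join(chunks) == "Answer: yes"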

View File

@@ -0,0 +1,201 @@
import abc
import json
from collections.abc import Generator
from typing import Any
import requests
from requests.models import Response
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
class HostSpecificRequestModel(abc.ABC):
"""Provides a more minimal implementation requirement for extending to new models
hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
"""
@staticmethod
@abc.abstractmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None,
max_output_tokens: int,
stream: bool,
timeout: int | None,
) -> Response:
"""Given a filled out prompt, how to send it to the model API with the
correct request format with the correct parameters"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def extract_model_output_from_response(
response: Response,
) -> str:
"""Extract the full model output text from a response.
This is for nonstreaming endpoints"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
"""Generate tokens from a streaming response
This is for streaming endpoints"""
raise NotImplementedError
class HuggingFaceRequestModel(HostSpecificRequestModel):
@staticmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None,
max_output_tokens: int,
stream: bool, # Not supported by Inference Endpoints (as of Aug 2023)
timeout: int | None,
) -> Response:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
data = {
"inputs": filled_prompt,
"parameters": {
# HuggingFace requires this to be strictly positive from 0.0-100.0 noninclusive
"temperature": 0.01,
# Skip the long tail
"top_p": 0.9,
"max_new_tokens": max_output_tokens,
},
}
try:
return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
except TimeoutError as error:
raise TimeoutError(f"Model inference to {endpoint} timed out") from error
@staticmethod
def _hf_extract_model_output(
response: Response,
) -> str:
if response.status_code != 200:
response.raise_for_status()
return json.loads(response.content)[0].get("generated_text", "")
@staticmethod
def extract_model_output_from_response(
response: Response,
) -> str:
return HuggingFaceRequestModel._hf_extract_model_output(response)
@staticmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
"""HF endpoints do not do streaming currently so this function will
simulate streaming for the meantime but will need to be replaced in
the future once streaming is enabled."""
model_out = HuggingFaceRequestModel._hf_extract_model_output(response)
yield from simulate_streaming_response(model_out)
def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
if model_host_type == ModelHostType.HUGGINGFACE.value:
return HuggingFaceRequestModel()
else:
# TODO support Azure, GCP, AWS
raise ValueError(
"Invalid model hosting service selected, currently supports only huggingface"
)
class RequestCompletionQA(QAModel):
def __init__(
self,
endpoint: str = GEN_AI_ENDPOINT,
model_host_type: str = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
timeout: int | None = None,
) -> None:
self.endpoint = endpoint
self.api_key = api_key
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.model_class = get_host_specific_model_class(model_host_type)
self.timeout = timeout
def _get_request_response(
self, query: str, context_docs: list[InferenceChunk], stream: bool
) -> Response:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, include_metadata=False
)
logger.debug(filled_prompt)
return self.model_class.send_model_request(
filled_prompt,
self.endpoint,
self.api_key,
self.max_output_tokens,
stream,
self.timeout,
)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
model_output = self.model_class.extract_model_output_from_response(
model_api_response
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self,
query: str,
context_docs: list[InferenceChunk],
) -> Generator[dict[str, Any] | None, None, None]:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
token_generator = self.model_class.generate_model_tokens_from_response(
model_api_response
)
yield from process_model_tokens(
tokens=token_generator,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
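For illustration only (not part of this commit): a hedged sketch of how another hosted service could be wired in behind HostSpecificRequestModel, assuming it lives alongside the classes above so HostSpecificRequestModel and simulate_streaming_response are in scope. The payload shape and response schema are hypothetical, not a real provider API.
import json
from collections.abc import Generator

import requests
from requests.models import Response


class ExampleJsonRequestModel(HostSpecificRequestModel):
    """Hypothetical host type: a generic JSON-over-HTTP completion service."""

    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
        data = {"prompt": filled_prompt, "max_tokens": max_output_tokens}
        return requests.post(endpoint, headers=headers, json=data, timeout=timeout)

    @staticmethod
    def extract_model_output_from_response(response: Response) -> str:
        response.raise_for_status()
        # Assumes the service returns {"text": "..."}; adjust to the real schema
        return json.loads(response.content).get("text", "")

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        # Reuse the character-level simulation until true streaming is available
        yield from simulate_streaming_response(
            ExampleJsonRequestModel.extract_model_output_from_response(response)
        )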

View File

@@ -18,7 +18,7 @@ from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.connectors.file.utils import write_temp_files
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_KEY
from danswer.connectors.google_drive.connector_auth import get_auth_url
@@ -50,7 +50,7 @@ from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import User
from danswer.direct_qa import check_model_api_key_is_valid
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.open_ai import get_openai_api_key
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.open_ai import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -293,19 +293,17 @@ def connector_run_once(
)
@router.head("/admin/openai-api-key/validate")
def validate_existing_openai_api_key(
@router.head("/admin/genai-api-key/validate")
def validate_existing_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# OpenAI key is only used for generative QA, so no need to validate this
# if it's turned off or if a non-OpenAI model is being used
if DISABLE_GENERATIVE_AI or not isinstance(
get_default_backend_qa_model(), OpenAIQAModel
):
if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
return
# Only validate every so often
check_key_time = "openai_api_key_last_check_time"
check_key_time = "genai_api_key_last_check_time"
kv_store = get_dynamic_config_store()
curr_time = datetime.now()
try:
@@ -318,7 +316,7 @@ def validate_existing_openai_api_key(
pass
try:
openai_api_key = get_openai_api_key()
genai_api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
@@ -327,7 +325,7 @@ def validate_existing_openai_api_key(
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
try:
is_valid = check_model_api_key_is_valid(openai_api_key)
is_valid = check_model_api_key_is_valid(genai_api_key)
except ValueError:
# this is the case where they aren't using an OpenAI-based model
is_valid = True
@@ -336,8 +334,8 @@ def validate_existing_openai_api_key(
raise HTTPException(status_code=400, detail="Invalid API key provided")
@router.get("/admin/openai-api-key", response_model=ApiKey)
def get_openai_api_key_from_dynamic_config_store(
@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(
_: User = Depends(current_admin_user),
) -> ApiKey:
"""
@@ -347,15 +345,15 @@ def get_openai_api_key_from_dynamic_config_store(
# only get last 4 characters of key to not expose full key
return ApiKey(
api_key=cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
)[-4:]
)
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
@router.put("/admin/openai-api-key")
def store_openai_api_key(
@router.put("/admin/genai-api-key")
def store_genai_api_key(
request: ApiKey,
_: User = Depends(current_admin_user),
) -> None:
@@ -363,16 +361,16 @@ def store_openai_api_key(
is_valid = check_model_api_key_is_valid(request.api_key)
if not is_valid:
raise HTTPException(400, "Invalid API key provided")
get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
except RuntimeError as e:
raise HTTPException(400, str(e))
@router.delete("/admin/openai-api-key") @router.delete("/admin/genai-api-key")
def delete_openai_api_key( def delete_genai_api_key(
_: User = Depends(current_admin_user), _: User = Depends(current_admin_user),
) -> None: ) -> None:
get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY) get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
"""Endpoints for basic users""" """Endpoints for basic users"""

View File

@@ -6,6 +6,7 @@ reorder-python-imports==3.9.0
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-setuptools==68.0.0.3
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-regex==2023.3.23.1

View File

@@ -19,6 +19,9 @@ services:
environment:
- INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
- GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
- POSTGRES_HOST=relational_db
- QDRANT_HOST=vector_db
- TYPESENSE_HOST=search_engine
@@ -49,6 +52,9 @@ services:
environment:
- INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
- GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
- POSTGRES_HOST=relational_db
- QDRANT_HOST=vector_db
- TYPESENSE_HOST=search_engine

View File

@@ -5,7 +5,7 @@
# Insert your OpenAI API key here, currently the only Generative AI endpoint for QA that we support is OpenAI
# If not provided here, UI will prompt on setup
OPENAI_API_KEY=
GEN_AI_API_KEY=
# Choose between "openai-chat-completion" and "openai-completion"
INTERNAL_MODEL_VERSION=openai-chat-completion
# Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility

View File

@@ -3,13 +3,13 @@
import { LoadingAnimation } from "@/components/Loading";
import { KeyIcon, TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { OPENAI_API_KEY_URL } from "@/components/openai/constants";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { fetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
const ExistingKeys = () => {
const { data, isLoading, error } = useSWR<{ api_key: string }>(
OPENAI_API_KEY_URL,
GEN_AI_API_KEY_URL,
fetcher
);
@@ -33,7 +33,7 @@ const ExistingKeys = () => {
<button
className="ml-1 my-auto hover:bg-gray-700 rounded-full p-1"
onClick={async () => {
await fetch(OPENAI_API_KEY_URL, {
await fetch(GEN_AI_API_KEY_URL, {
method: "DELETE",
});
window.location.reload();
@@ -64,7 +64,7 @@ const Page = () => {
<ApiKeyForm
handleResponse={(response) => {
if (response.ok) {
mutate(OPENAI_API_KEY_URL);
mutate(GEN_AI_API_KEY_URL);
}
}}
/>

View File

@@ -2,7 +2,7 @@ import { Form, Formik } from "formik";
import { Popup } from "../admin/connectors/Popup";
import { useState } from "react";
import { TextFormField } from "../admin/connectors/Field";
import { OPENAI_API_KEY_URL } from "./constants";
import { GEN_AI_API_KEY_URL } from "./constants";
import { LoadingAnimation } from "../Loading";
interface Props {
@@ -21,7 +21,7 @@ export const ApiKeyForm = ({ handleResponse }: Props) => {
<Formik
initialValues={{ apiKey: "" }}
onSubmit={async ({ apiKey }, formikHelpers) => {
const response = await fetch(OPENAI_API_KEY_URL, {
const response = await fetch(GEN_AI_API_KEY_URL, {
method: "PUT",
headers: {
"Content-Type": "application/json",

View File

@@ -7,7 +7,7 @@ export const ApiKeyModal = () => {
const [isOpen, setIsOpen] = useState(false);
useEffect(() => {
fetch("/api/manage/admin/openai-api-key/validate", {
fetch("/api/manage/admin/genai-api-key/validate", {
method: "HEAD",
}).then((res) => {
// show popup if either the API key is not set or the API key is invalid

View File

@@ -1 +1 @@
export const OPENAI_API_KEY_URL = "/api/manage/admin/openai-api-key";
export const GEN_AI_API_KEY_URL = "/api/manage/admin/genai-api-key";