From 3bfc72484d91b7314576e0a6d673d9a2ab7f9378 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 6 Aug 2023 18:31:47 -0700 Subject: [PATCH] Support for Request accessed GenAI Models (#270) --- backend/danswer/configs/app_configs.py | 6 - backend/danswer/configs/constants.py | 25 +- backend/danswer/configs/model_configs.py | 65 +++-- backend/danswer/direct_qa/__init__.py | 97 +++++-- backend/danswer/direct_qa/gpt_4_all.py | 36 ++- backend/danswer/direct_qa/huggingface.py | 189 +++++++++++++ .../direct_qa/huggingface_inference.py | 248 ------------------ backend/danswer/direct_qa/interfaces.py | 6 + backend/danswer/direct_qa/open_ai.py | 14 +- backend/danswer/direct_qa/qa_utils.py | 16 ++ backend/danswer/direct_qa/request_model.py | 201 ++++++++++++++ backend/danswer/server/manage.py | 36 ++- backend/requirements/dev.txt | 1 + .../docker_compose/docker-compose.dev.yml | 6 + deployment/docker_compose/env.prod.template | 2 +- web/src/app/admin/keys/openai/page.tsx | 8 +- web/src/components/openai/ApiKeyForm.tsx | 4 +- web/src/components/openai/ApiKeyModal.tsx | 2 +- web/src/components/openai/constants.ts | 2 +- 19 files changed, 613 insertions(+), 351 deletions(-) create mode 100644 backend/danswer/direct_qa/huggingface.py delete mode 100644 backend/danswer/direct_qa/huggingface_inference.py create mode 100644 backend/danswer/direct_qa/request_model.py diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 8d8ae16980ae..2c54e34447d5 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -138,12 +138,6 @@ CHUNK_WORD_OVERLAP = 5 CHUNK_MAX_CHAR_OVERLAP = 50 -##### -# Other API Keys -##### -OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") - - ##### # Encoder Model Endpoint Configs (Currently unused, running the models in memory) ##### diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 5074efe81ef6..d18c2fefc3e8 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -12,7 +12,7 @@ SECTION_CONTINUATION = "section_continuation" ALLOWED_USERS = "allowed_users" ALLOWED_GROUPS = "allowed_groups" METADATA = "metadata" -OPENAI_API_KEY_STORAGE_KEY = "openai_api_key" +GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key" HTML_SEPARATOR = "\n" PUBLIC_DOC_PAT = "PUBLIC" @@ -30,3 +30,26 @@ class DocumentSource(str, Enum): PRODUCTBOARD = "productboard" FILE = "file" NOTION = "notion" + + +class DanswerGenAIModel(str, Enum): + """This represents the internal Danswer GenAI model which determines the class that is used + to generate responses to the user query. 
Different models/services require different internal
+    handling; this allows for modularity of implementation within Danswer"""
+
+    OPENAI = "openai-completion"
+    OPENAI_CHAT = "openai-chat-completion"
+    GPT4ALL = "gpt4all-completion"
+    GPT4ALL_CHAT = "gpt4all-chat-completion"
+    HUGGINGFACE = "huggingface-inference-completion"
+    HUGGINGFACE_CHAT = "huggingface-inference-chat-completion"
+    REQUEST = "request-completion"
+
+
+class ModelHostType(str, Enum):
+    """For GenAI models interfaced via requests, different services have different
+    expectations for what fields are included in the request"""
+
+    # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
+    HUGGINGFACE = "huggingface"  # HuggingFace text-generation Inference API
+    # TODO support for Azure, AWS, GCP GenAI model hosting
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index e5b929111396..320fdadee738 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -1,4 +1,8 @@
 import os
+from enum import Enum
+
+from danswer.configs.constants import DanswerGenAIModel
+from danswer.configs.constants import ModelHostType
 
 # Important considerations when choosing models
 # Max tokens count needs to be high considering use case (at least 512)
@@ -30,35 +34,46 @@ CROSS_EMBED_CONTEXT_SIZE = 512
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
 
-# QA Model API Configs
-# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
-# Valid list:
-# - openai-completion
-# - openai-chat-completion
-# - gpt4all-completion -> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
-# - gpt4all-chat-completion-> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
-# To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
-# These support HuggingFace Inference API, Inference Endpoints and servers running the text-generation-inference backend
-# - huggingface-inference-completion
-# - huggingface-inference-chat-completion
+#####
+# Generative AI Model Configs
+#####
+# Other models should work as well; check the library/API compatibility.
+# But these are the models that have been verified to work with the existing prompts.
+# Using a different model may require some prompt tuning. See qa_prompts.py
+VERIFIED_MODELS = {
+    DanswerGenAIModel.OPENAI: ["text-davinci-003"],
+    DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
+    DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
+    DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
+    # The "chat" model below is actually instruction finetuned and does not support the Conversational API
+    DanswerGenAIModel.HUGGINGFACE: ["meta-llama/Llama-2-70b-chat-hf"],
+    DanswerGenAIModel.HUGGINGFACE_CHAT: ["meta-llama/Llama-2-70b-hf"],
+}
+
+# Sets the internal Danswer model class to use
 INTERNAL_MODEL_VERSION = os.environ.get(
-    "INTERNAL_MODEL_VERSION", "openai-chat-completion"
+    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
 )
-# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
-GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
+
+# API key for the Generative AI model, if one is required for access; otherwise leave blank
+GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", "")
+
+# If using GPT4All or OpenAI, specify the model version
+GEN_AI_MODEL_VERSION = os.environ.get(
+    "GEN_AI_MODEL_VERSION",
+    VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
+)
+
+# If the Generative AI model is hosted to accept requests (DanswerGenAIModel.REQUEST),
+# set the two values below to specify
+# - where to hit the endpoint
+# - how the request should be formed
+GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
+GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
+
+# Set this to be enough for an answer + quotes
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
 
-# Use HuggingFace API Token for Huggingface inference client
-GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
-# Use the conversational API with the huggingface-inference-chat-completion internal model
-# Note - this only works with models that support conversational interfaces
-GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
-    os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
-)
-# Disable streaming responses.
Set this to true to "polyfill" streaming for models that don't support streaming -GEN_AI_HUGGINGFACE_DISABLE_STREAM = ( - os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true" -) # Danswer custom Deep Learning Models INTENT_MODEL_VERSION = "danswer/intent-model" diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py index 0f1c9ddefdbe..f3918bc333fc 100644 --- a/backend/danswer/direct_qa/__init__.py +++ b/backend/danswer/direct_qa/__init__.py @@ -1,25 +1,30 @@ from typing import Any +import pkg_resources from openai.error import AuthenticationError -from openai.error import Timeout from danswer.configs.app_configs import QA_TIMEOUT -from danswer.configs.model_configs import ( - GEN_AI_HUGGINGFACE_API_TOKEN, - INTERNAL_MODEL_VERSION, -) +from danswer.configs.constants import DanswerGenAIModel +from danswer.configs.constants import ModelHostType +from danswer.configs.model_configs import GEN_AI_API_KEY +from danswer.configs.model_configs import GEN_AI_ENDPOINT +from danswer.configs.model_configs import GEN_AI_HOST_TYPE +from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.direct_qa.exceptions import UnknownModelError +from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA +from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA +from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA +from danswer.direct_qa.huggingface import HuggingFaceCompletionQA from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.huggingface_inference import ( - HuggingFaceInferenceChatCompletionQA, - HuggingFaceInferenceCompletionQA, -) from danswer.direct_qa.open_ai import OpenAIChatCompletionQA from danswer.direct_qa.open_ai import OpenAICompletionQA +from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor +from danswer.direct_qa.qa_utils import get_gen_ai_api_key +from danswer.direct_qa.request_model import RequestCompletionQA +from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.utils.logger import setup_logger -# Imports commented out temporarily due to incompatibility of gpt4all with M1 Mac hardware currently -# from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA -# from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA +logger = setup_logger() def check_model_api_key_is_valid(model_api_key: str) -> bool: @@ -35,32 +40,66 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool: return True except AuthenticationError: return False - except Timeout: - pass + except Exception as e: + logger.warning(f"GenAI API key failed for the following reason: {e}") return False def get_default_backend_qa_model( internal_model: str = INTERNAL_MODEL_VERSION, - api_key: str | None = None, + endpoint: str | None = GEN_AI_ENDPOINT, + model_host_type: str | None = GEN_AI_HOST_TYPE, + api_key: str | None = GEN_AI_API_KEY, timeout: int = QA_TIMEOUT, - **kwargs: Any + **kwargs: Any, ) -> QAModel: - if internal_model == "openai-completion": + if not api_key: + try: + api_key = get_gen_ai_api_key() + except ConfigNotFoundError: + pass + + if internal_model in [ + DanswerGenAIModel.GPT4ALL.value, + DanswerGenAIModel.GPT4ALL_CHAT.value, + ]: + # gpt4all is not compatible M1 Mac hardware as of Aug 2023 + pkg_resources.get_distribution("gpt4all") + + if internal_model == DanswerGenAIModel.OPENAI.value: return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs) - elif internal_model == "openai-chat-completion": + elif internal_model == 
DanswerGenAIModel.OPENAI_CHAT.value: return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs) - elif internal_model == "huggingface-inference-completion": - api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN - return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs) - elif internal_model == "huggingface-inference-chat-completion": - api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN - return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs) - # Note GPT4All is not supported for M1 Mac machines currently, removing until support is added - # elif internal_model == "gpt4all-completion": - # return GPT4AllCompletionQA(**kwargs) - # elif internal_model == "gpt4all-chat-completion": - # return GPT4AllChatCompletionQA(**kwargs) + elif internal_model == DanswerGenAIModel.GPT4ALL.value: + return GPT4AllCompletionQA(**kwargs) + elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value: + return GPT4AllChatCompletionQA(**kwargs) + elif internal_model == DanswerGenAIModel.HUGGINGFACE.value: + return HuggingFaceCompletionQA(api_key=api_key, **kwargs) + elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value: + return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs) + elif internal_model == DanswerGenAIModel.REQUEST.value: + if endpoint is None or model_host_type is None: + raise ValueError( + "Request based GenAI model requires an endpoint and host type" + ) + if model_host_type == ModelHostType.HUGGINGFACE.value: + # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits + # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens + # This is the sum of input and output tokens, so cannot take in full Danswer context + return RequestCompletionQA( + endpoint=endpoint, + model_host_type=model_host_type, + api_key=api_key, + prompt_processor=WeakModelFreeformProcessor(), + timeout=timeout, + ) + return RequestCompletionQA( + endpoint=endpoint, + model_host_type=model_host_type, + api_key=api_key, + timeout=timeout, + ) else: raise UnknownModelError(internal_model) diff --git a/backend/danswer/direct_qa/gpt_4_all.py b/backend/danswer/direct_qa/gpt_4_all.py index 673cf94e7921..9402b234acf1 100644 --- a/backend/danswer/direct_qa/gpt_4_all.py +++ b/backend/danswer/direct_qa/gpt_4_all.py @@ -1,8 +1,6 @@ from collections.abc import Generator from typing import Any -from gpt4all import GPT4All # type:ignore - from danswer.chunking.models import InferenceChunk from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_VERSION @@ -18,9 +16,30 @@ from danswer.direct_qa.qa_utils import process_model_tokens from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time - logger = setup_logger() + +class DummyGPT4All: + """In the case of import failure due to M1 Mac incompatibility, + so this module does not raise exceptions during server startup, + as long as this module isn't actually used""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError("GPT4All library not installed.") + + +try: + from gpt4all import GPT4All # type:ignore +except ImportError: + logger.warning( + "GPT4All library not installed. " + "If you wish to run GPT4ALL (in memory) to power Danswer's " + "Generative AI features, please install gpt4all==1.0.5. " + "As of Aug 2023, this library is not compatible with M1 Mac." 
+ ) + GPT4All = DummyGPT4All + + GPT4ALL_MODEL: GPT4All | None = None @@ -56,6 +75,10 @@ class GPT4AllCompletionQA(QAModel): self.max_output_tokens = max_output_tokens self.include_metadata = include_metadata + @property + def requires_api_key(self) -> bool: + return False + def warm_up_model(self) -> None: get_gpt_4_all_model(self.model_version) @@ -117,6 +140,13 @@ class GPT4AllChatCompletionQA(QAModel): self.max_output_tokens = max_output_tokens self.include_metadata = include_metadata + @property + def requires_api_key(self) -> bool: + return False + + def warm_up_model(self) -> None: + get_gpt_4_all_model(self.model_version) + @log_function_time() def answer_question( self, query: str, context_docs: list[InferenceChunk] diff --git a/backend/danswer/direct_qa/huggingface.py b/backend/danswer/direct_qa/huggingface.py new file mode 100644 index 000000000000..ea8310f20947 --- /dev/null +++ b/backend/danswer/direct_qa/huggingface.py @@ -0,0 +1,189 @@ +from collections.abc import Generator +from typing import Any + +from huggingface_hub import InferenceClient # type:ignore +from huggingface_hub.utils import HfHubHTTPError # type:ignore + +from danswer.chunking.models import InferenceChunk +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.direct_qa.interfaces import DanswerAnswer +from danswer.direct_qa.interfaces import DanswerQuote +from danswer.direct_qa.interfaces import QAModel +from danswer.direct_qa.qa_prompts import ChatPromptProcessor +from danswer.direct_qa.qa_prompts import FreeformProcessor +from danswer.direct_qa.qa_prompts import JsonChatProcessor +from danswer.direct_qa.qa_prompts import NonChatPromptProcessor +from danswer.direct_qa.qa_utils import process_answer +from danswer.direct_qa.qa_utils import process_model_tokens +from danswer.direct_qa.qa_utils import simulate_streaming_response +from danswer.utils.logger import setup_logger +from danswer.utils.timing import log_function_time + +logger = setup_logger() + + +def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]: + """ + Utility to add in some common default values so they don't have to be set every time. 
+ """ + return { + "do_sample": False, + "seed": 69, # For reproducibility + **kwargs, + } + + +class HuggingFaceCompletionQA(QAModel): + def __init__( + self, + prompt_processor: NonChatPromptProcessor = FreeformProcessor(), + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + include_metadata: bool = False, + api_key: str | None = None, + ) -> None: + self.prompt_processor = prompt_processor + self.max_output_tokens = max_output_tokens + self.include_metadata = include_metadata + self.client = InferenceClient(model=model_version, token=api_key) + + @log_function_time() + def answer_question( + self, query: str, context_docs: list[InferenceChunk] + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + model_output = self.client.text_generation( + filled_prompt, + **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens), + ) + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + return answer, quotes_dict + + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + model_stream = self.client.text_generation( + filled_prompt, + **_build_hf_inference_settings( + max_new_tokens=self.max_output_tokens, stream=True + ), + ) + + yield from process_model_tokens( + tokens=model_stream, + context_docs=context_docs, + is_json_prompt=self.prompt_processor.specifies_json_output, + ) + + +class HuggingFaceChatCompletionQA(QAModel): + """Chat in this class refers to the HuggingFace Conversational API. + Not to be confused with Chat/Instruction finetuned models. + Llama2-chat... 
means it is an Instruction finetuned model, not necessarily that + it supports Conversational API""" + + def __init__( + self, + prompt_processor: ChatPromptProcessor = JsonChatProcessor(), + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + include_metadata: bool = False, + api_key: str | None = None, + ) -> None: + self.prompt_processor = prompt_processor + self.max_output_tokens = max_output_tokens + self.include_metadata = include_metadata + self.client = InferenceClient(model=model_version, token=api_key) + + @staticmethod + def _convert_chat_to_hf_conversational_format( + dialog: list[dict[str, str]] + ) -> tuple[str, list[str], list[str]]: + if dialog[-1]["role"] != "user": + raise Exception( + "Last message in conversational dialog must be User message" + ) + user_message = dialog[-1]["content"] + dialog = dialog[0:-1] + generated_responses = [] + past_user_inputs = [] + for message in dialog: + # HuggingFace inference client doesn't support system messages today + # so lumping them in with user messages + if message["role"] in ["user", "system"]: + past_user_inputs += [message["content"]] + else: + generated_responses += [message["content"]] + return user_message, generated_responses, past_user_inputs + + def _get_hf_model_output( + self, query: str, context_docs: list[InferenceChunk] + ) -> str: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + + ( + query, + past_responses, + past_inputs, + ) = self._convert_chat_to_hf_conversational_format(filled_prompt) + + logger.debug(f"Last Input: {query}") + logger.debug(f"Past Inputs: {past_inputs}") + logger.debug(f"Past Responses: {past_responses}") + try: + model_output = self.client.conversational( + query, + generated_responses=past_responses, + past_user_inputs=past_inputs, + parameters={"max_length": self.max_output_tokens}, + ) + except HfHubHTTPError as model_error: + if model_error.response.status_code == 422: + raise ValueError( + "Selected HuggingFace Model does not support HuggingFace Conversational API," + "try using the huggingface-inference-completion in Danswer instead" + ) + raise + logger.debug(model_output) + + return model_output + + @log_function_time() + def answer_question( + self, query: str, context_docs: list[InferenceChunk] + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + model_output = self._get_hf_model_output(query, context_docs) + + answer, quotes_dict = process_answer(model_output, context_docs) + + return answer, quotes_dict + + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + """As of Aug 2023, HF conversational (chat) endpoints do not support streaming + So here it is faked by streaming characters within Danswer from the model output + """ + model_output = self._get_hf_model_output(query, context_docs) + + model_stream = simulate_streaming_response(model_output) + + yield from process_model_tokens( + tokens=model_stream, + context_docs=context_docs, + is_json_prompt=self.prompt_processor.specifies_json_output, + ) diff --git a/backend/danswer/direct_qa/huggingface_inference.py b/backend/danswer/direct_qa/huggingface_inference.py deleted file mode 100644 index 7b9993586e99..000000000000 --- a/backend/danswer/direct_qa/huggingface_inference.py +++ /dev/null @@ -1,248 +0,0 @@ -from collections.abc import Generator -from typing import Any, Iterator - -from danswer.chunking.models import InferenceChunk -from 
danswer.configs.model_configs import ( - GEN_AI_HUGGINGFACE_DISABLE_STREAM, - GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL, - GEN_AI_MAX_OUTPUT_TOKENS, -) -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.direct_qa.interfaces import DanswerAnswer -from danswer.direct_qa.interfaces import DanswerQuote -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.qa_prompts import ( - ChatPromptProcessor, - JsonChatProcessor, - JsonProcessor, -) -from danswer.direct_qa.qa_prompts import NonChatPromptProcessor -from danswer.direct_qa.qa_utils import process_answer -from danswer.direct_qa.qa_utils import process_model_tokens -from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time -from huggingface_hub import InferenceClient - -logger = setup_logger() - - -def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]: - """ - Utility to add in some common default values so they don't have to be set every time. - """ - return { - "do_sample": False, - "seed": 69, # For reproducibility - **kwargs, - } - - -def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str: - """ - Utility to convert chat dialog to a text-generation prompt for models tuned for chat. - Note - This is a "best guess" attempt at a generic completions prompt for chat - completion models. It isn't optimized for all chat trained models, but tries - to serialize to a format that most models understand. - Models like Llama2-chat have been optimized for certain formatting of chat - completions, and this function doesn't take that into account, so you won't - always get the best possible outcome. - TODO - Add the ability to pass custom formatters for chat dialogue - """ - DEFAULT_SYSTEM_PROMPT = """\ - You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly. 
- If you don't know the answer to a question, don't share false information.""" - prompt = "" - if dialog[0]["role"] != "system": - dialog = [ - { - "role": "system", - "content": DEFAULT_SYSTEM_PROMPT, - } - ] + dialog - for message in dialog: - prompt += f"{message['role'].upper()}: {message['content']}\n" - prompt += "ASSISTANT:" - return prompt - - -def _mock_streaming_response(tokens: str) -> Generator[str, None, None]: - """Utility to mock a streaming response""" - for token in tokens: - yield token - - -class HuggingFaceInferenceCompletionQA(QAModel): - def __init__( - self, - prompt_processor: NonChatPromptProcessor = JsonProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - include_metadata: bool = False, - api_key: str | None = None, - ) -> None: - self.prompt_processor = prompt_processor - self.max_output_tokens = max_output_tokens - self.include_metadata = include_metadata - self.client = InferenceClient(model=model_version, token=api_key) - - @log_function_time() - def answer_question( - self, query: str, context_docs: list[InferenceChunk] - ) -> tuple[DanswerAnswer, list[DanswerQuote]]: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - model_output = self.client.text_generation( - filled_prompt, - **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens), - ) - logger.debug(model_output) - answer, quotes_dict = process_answer(model_output, context_docs) - return answer, quotes_dict - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> Generator[dict[str, Any] | None, None, None]: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - if not GEN_AI_HUGGINGFACE_DISABLE_STREAM: - model_stream = self.client.text_generation( - filled_prompt, - **_build_hf_inference_settings( - max_new_tokens=self.max_output_tokens, stream=True - ), - ) - else: - model_output = self.client.text_generation( - filled_prompt, - **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens), - ) - logger.debug(model_output) - model_stream = _mock_streaming_response(model_output) - yield from process_model_tokens( - tokens=model_stream, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) - - -class HuggingFaceInferenceChatCompletionQA(QAModel): - def __init__( - self, - prompt_processor: ChatPromptProcessor = JsonChatProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - include_metadata: bool = False, - api_key: str | None = None, - ) -> None: - self.prompt_processor = prompt_processor - self.max_output_tokens = max_output_tokens - self.include_metadata = include_metadata - self.client = InferenceClient(model=model_version, token=api_key) - - @staticmethod - def convert_dialog_to_conversational_format( - dialog: list[dict[str, str]] - ) -> tuple[str, list[str], list[str]]: - if dialog[-1]["role"] != "user": - raise Exception( - "Last message in conversational dialog must be User message" - ) - user_message = dialog[-1]["content"] - dialog = dialog[0:-1] - generated_responses = [] - past_user_inputs = [] - for message in dialog: - # HuggingFace inference client doesn't support system messages today - # so lumping them in with user messages - if message["role"] in ["user", "system"]: - past_user_inputs += [message["content"]] - 
else: - generated_responses += [message["content"]] - return user_message, generated_responses, past_user_inputs - - @log_function_time() - def answer_question( - self, query: str, context_docs: list[InferenceChunk] - ) -> tuple[DanswerAnswer, list[DanswerQuote]]: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL: - ( - message, - generated_responses, - past_user_inputs, - ) = self.convert_dialog_to_conversational_format(filled_prompt) - model_output = self.client.conversational( - message, - generated_responses=generated_responses, - past_user_inputs=past_user_inputs, - parameters={"max_length": self.max_output_tokens}, - ) - else: - chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt) - logger.debug(chat_prompt) - model_output = self.client.text_generation( - chat_prompt, - **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens), - ) - logger.debug(model_output) - answer, quotes_dict = process_answer(model_output, context_docs) - return answer, quotes_dict - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> Generator[dict[str, Any] | None, None, None]: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - if not GEN_AI_HUGGINGFACE_DISABLE_STREAM: - if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL: - raise Exception( - "Conversational API is not available with streaming enabled. Please either " - + "disable streaming, or disable using conversational API." - ) - chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt) - logger.debug(chat_prompt) - model_stream = self.client.text_generation( - chat_prompt, - **_build_hf_inference_settings( - max_new_tokens=self.max_output_tokens, stream=True - ), - ) - else: - if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL: - ( - message, - generated_responses, - past_user_inputs, - ) = self.convert_dialog_to_conversational_format(filled_prompt) - model_output = self.client.conversational( - message, - generated_responses=generated_responses, - past_user_inputs=past_user_inputs, - parameters={"max_length": self.max_output_tokens}, - ) - else: - chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt) - logger.debug(chat_prompt) - model_output = self.client.text_generation( - chat_prompt, - **_build_hf_inference_settings( - max_new_tokens=self.max_output_tokens - ), - ) - logger.debug(model_output) - model_stream = _mock_streaming_response(model_output) - yield from process_model_tokens( - tokens=model_stream, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py index a40a7af4247c..9a4fa8229cd5 100644 --- a/backend/danswer/direct_qa/interfaces.py +++ b/backend/danswer/direct_qa/interfaces.py @@ -23,6 +23,12 @@ class DanswerQuote: class QAModel: + @property + def requires_api_key(self) -> bool: + """Is this model protected by security features + Does it need an api key to access the model for inference""" + return True + def warm_up_model(self) -> None: """This is called during server start up to load the models into memory pass if model is accessed via API""" diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py index 022ea8cd272d..a766044275af 100644 --- a/backend/danswer/direct_qa/open_ai.py +++ 
b/backend/danswer/direct_qa/open_ai.py @@ -15,8 +15,6 @@ from openai.error import Timeout from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import INCLUDE_METADATA -from danswer.configs.app_configs import OPENAI_API_KEY -from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.direct_qa.exceptions import OpenAIKeyMissing @@ -28,9 +26,9 @@ from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg from danswer.direct_qa.qa_prompts import JsonChatProcessor from danswer.direct_qa.qa_prompts import JsonProcessor from danswer.direct_qa.qa_prompts import NonChatPromptProcessor +from danswer.direct_qa.qa_utils import get_gen_ai_api_key from danswer.direct_qa.qa_utils import process_answer from danswer.direct_qa.qa_utils import process_model_tokens -from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -41,15 +39,9 @@ logger = setup_logger() F = TypeVar("F", bound=Callable) -def get_openai_api_key() -> str: - return OPENAI_API_KEY or cast( - str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY) - ) - - def _ensure_openai_api_key(api_key: str | None) -> str: try: - return api_key or get_openai_api_key() + return api_key or get_gen_ai_api_key() except ConfigNotFoundError: raise OpenAIKeyMissing() @@ -131,7 +123,7 @@ class OpenAICompletionQA(OpenAIQAModel): self.timeout = timeout self.include_metadata = include_metadata try: - self.api_key = api_key or get_openai_api_key() + self.api_key = api_key or get_gen_ai_api_key() except ConfigNotFoundError: raise OpenAIKeyMissing() diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py index b06766159164..a3bb3d993752 100644 --- a/backend/danswer/direct_qa/qa_utils.py +++ b/backend/danswer/direct_qa/qa_utils.py @@ -3,6 +3,7 @@ import math import re from collections.abc import Generator from typing import Any +from typing import cast from typing import Optional from typing import Tuple @@ -12,14 +13,17 @@ from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT from danswer.configs.constants import BLURB from danswer.configs.constants import DOCUMENT_ID +from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINK from danswer.configs.constants import SOURCE_TYPE +from danswer.configs.model_configs import GEN_AI_API_KEY from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerQuote from danswer.direct_qa.qa_prompts import ANSWER_PAT from danswer.direct_qa.qa_prompts import QUOTE_PAT from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT +from danswer.dynamic_configs import get_dynamic_config_store from danswer.utils.logger import setup_logger from danswer.utils.text_processing import clean_model_quote from danswer.utils.text_processing import shared_precompare_cleanup @@ -27,6 +31,12 @@ from danswer.utils.text_processing import shared_precompare_cleanup logger = setup_logger() +def get_gen_ai_api_key() -> str: + return GEN_AI_API_KEY or cast( + str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY) + ) + + def 
structure_quotes_for_response( quotes: list[DanswerQuote] | None, ) -> dict[str, dict[str, str | None]]: @@ -246,3 +256,9 @@ def process_model_tokens( quotes = extract_quotes_from_completed_token_stream(model_output, context_docs) yield structure_quotes_for_response(quotes) + + +def simulate_streaming_response(model_out: str) -> Generator[str, None, None]: + """Mock streaming by generating the passed in model output, character by character""" + for token in model_out: + yield token diff --git a/backend/danswer/direct_qa/request_model.py b/backend/danswer/direct_qa/request_model.py new file mode 100644 index 000000000000..7e6fff12ed0a --- /dev/null +++ b/backend/danswer/direct_qa/request_model.py @@ -0,0 +1,201 @@ +import abc +import json +from collections.abc import Generator +from typing import Any + +import requests +from requests.models import Response + +from danswer.chunking.models import InferenceChunk +from danswer.configs.constants import ModelHostType +from danswer.configs.model_configs import GEN_AI_API_KEY +from danswer.configs.model_configs import GEN_AI_ENDPOINT +from danswer.configs.model_configs import GEN_AI_HOST_TYPE +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.direct_qa.interfaces import DanswerAnswer +from danswer.direct_qa.interfaces import DanswerQuote +from danswer.direct_qa.interfaces import QAModel +from danswer.direct_qa.qa_prompts import JsonProcessor +from danswer.direct_qa.qa_prompts import NonChatPromptProcessor +from danswer.direct_qa.qa_utils import process_answer +from danswer.direct_qa.qa_utils import process_model_tokens +from danswer.direct_qa.qa_utils import simulate_streaming_response +from danswer.utils.logger import setup_logger +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +class HostSpecificRequestModel(abc.ABC): + """Provides a more minimal implementation requirement for extending to new models + hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics + """ + + @staticmethod + @abc.abstractmethod + def send_model_request( + filled_prompt: str, + endpoint: str, + api_key: str | None, + max_output_tokens: int, + stream: bool, + timeout: int | None, + ) -> Response: + """Given a filled out prompt, how to send it to the model API with the + correct request format with the correct parameters""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def extract_model_output_from_response( + response: Response, + ) -> str: + """Extract the full model output text from a response. 
+ This is for nonstreaming endpoints""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def generate_model_tokens_from_response( + response: Response, + ) -> Generator[str, None, None]: + """Generate tokens from a streaming response + This is for streaming endpoints""" + raise NotImplementedError + + +class HuggingFaceRequestModel(HostSpecificRequestModel): + @staticmethod + def send_model_request( + filled_prompt: str, + endpoint: str, + api_key: str | None, + max_output_tokens: int, + stream: bool, # Not supported by Inference Endpoints (as of Aug 2023) + timeout: int | None, + ) -> Response: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + data = { + "inputs": filled_prompt, + "parameters": { + # HuggingFace requires this to be strictly positive from 0.0-100.0 noninclusive + "temperature": 0.01, + # Skip the long tail + "top_p": 0.9, + "max_new_tokens": max_output_tokens, + }, + } + try: + return requests.post(endpoint, headers=headers, json=data, timeout=timeout) + except TimeoutError as error: + raise TimeoutError(f"Model inference to {endpoint} timed out") from error + + @staticmethod + def _hf_extract_model_output( + response: Response, + ) -> str: + if response.status_code != 200: + response.raise_for_status() + + return json.loads(response.content)[0].get("generated_text", "") + + @staticmethod + def extract_model_output_from_response( + response: Response, + ) -> str: + return HuggingFaceRequestModel._hf_extract_model_output(response) + + @staticmethod + def generate_model_tokens_from_response( + response: Response, + ) -> Generator[str, None, None]: + """HF endpoints do not do streaming currently so this function will + simulate streaming for the meantime but will need to be replaced in + the future once streaming is enabled.""" + model_out = HuggingFaceRequestModel._hf_extract_model_output(response) + yield from simulate_streaming_response(model_out) + + +def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel: + if model_host_type == ModelHostType.HUGGINGFACE.value: + return HuggingFaceRequestModel() + else: + # TODO support Azure, GCP, AWS + raise ValueError( + "Invalid model hosting service selected, currently supports only huggingface" + ) + + +class RequestCompletionQA(QAModel): + def __init__( + self, + endpoint: str = GEN_AI_ENDPOINT, + model_host_type: str = GEN_AI_HOST_TYPE, + api_key: str | None = GEN_AI_API_KEY, + prompt_processor: NonChatPromptProcessor = JsonProcessor(), + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + timeout: int | None = None, + ) -> None: + self.endpoint = endpoint + self.api_key = api_key + self.prompt_processor = prompt_processor + self.max_output_tokens = max_output_tokens + self.model_class = get_host_specific_model_class(model_host_type) + self.timeout = timeout + + def _get_request_response( + self, query: str, context_docs: list[InferenceChunk], stream: bool + ) -> Response: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, include_metadata=False + ) + logger.debug(filled_prompt) + + return self.model_class.send_model_request( + filled_prompt, + self.endpoint, + self.api_key, + self.max_output_tokens, + stream, + self.timeout, + ) + + @log_function_time() + def answer_question( + self, query: str, context_docs: list[InferenceChunk] + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + model_api_response = self._get_request_response( + query, context_docs, stream=False + ) + + model_output = 
self.model_class.extract_model_output_from_response( + model_api_response + ) + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + return answer, quotes_dict + + def answer_question_stream( + self, + query: str, + context_docs: list[InferenceChunk], + ) -> Generator[dict[str, Any] | None, None, None]: + model_api_response = self._get_request_response( + query, context_docs, stream=False + ) + + token_generator = self.model_class.generate_model_tokens_from_response( + model_api_response + ) + + yield from process_model_tokens( + tokens=token_generator, + context_docs=context_docs, + is_json_prompt=self.prompt_processor.specifies_json_output, + ) diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 0a708c38c50d..0e1a40d62c84 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -18,7 +18,7 @@ from danswer.auth.users import current_user from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX -from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY +from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY from danswer.connectors.file.utils import write_temp_files from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_KEY from danswer.connectors.google_drive.connector_auth import get_auth_url @@ -50,7 +50,7 @@ from danswer.db.index_attempt import create_index_attempt from danswer.db.models import User from danswer.direct_qa import check_model_api_key_is_valid from danswer.direct_qa import get_default_backend_qa_model -from danswer.direct_qa.open_ai import get_openai_api_key +from danswer.direct_qa.open_ai import get_gen_ai_api_key from danswer.direct_qa.open_ai import OpenAIQAModel from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError @@ -293,19 +293,17 @@ def connector_run_once( ) -@router.head("/admin/openai-api-key/validate") -def validate_existing_openai_api_key( +@router.head("/admin/genai-api-key/validate") +def validate_existing_genai_api_key( _: User = Depends(current_admin_user), ) -> None: # OpenAI key is only used for generative QA, so no need to validate this # if it's turned off or if a non-OpenAI model is being used - if DISABLE_GENERATIVE_AI or not isinstance( - get_default_backend_qa_model(), OpenAIQAModel - ): + if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key: return # Only validate every so often - check_key_time = "openai_api_key_last_check_time" + check_key_time = "genai_api_key_last_check_time" kv_store = get_dynamic_config_store() curr_time = datetime.now() try: @@ -318,7 +316,7 @@ def validate_existing_openai_api_key( pass try: - openai_api_key = get_openai_api_key() + genai_api_key = get_gen_ai_api_key() except ConfigNotFoundError: raise HTTPException(status_code=404, detail="Key not found") except ValueError as e: @@ -327,7 +325,7 @@ def validate_existing_openai_api_key( get_dynamic_config_store().store(check_key_time, curr_time.timestamp()) try: - is_valid = check_model_api_key_is_valid(openai_api_key) + is_valid = check_model_api_key_is_valid(genai_api_key) except ValueError: # this is the case where they aren't using an OpenAI-based model is_valid = True @@ -336,8 +334,8 @@ def validate_existing_openai_api_key( raise HTTPException(status_code=400, detail="Invalid API 
key provided") -@router.get("/admin/openai-api-key", response_model=ApiKey) -def get_openai_api_key_from_dynamic_config_store( +@router.get("/admin/genai-api-key", response_model=ApiKey) +def get_gen_ai_api_key_from_dynamic_config_store( _: User = Depends(current_admin_user), ) -> ApiKey: """ @@ -347,15 +345,15 @@ def get_openai_api_key_from_dynamic_config_store( # only get last 4 characters of key to not expose full key return ApiKey( api_key=cast( - str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY) + str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY) )[-4:] ) except ConfigNotFoundError: raise HTTPException(status_code=404, detail="Key not found") -@router.put("/admin/openai-api-key") -def store_openai_api_key( +@router.put("/admin/genai-api-key") +def store_genai_api_key( request: ApiKey, _: User = Depends(current_admin_user), ) -> None: @@ -363,16 +361,16 @@ def store_openai_api_key( is_valid = check_model_api_key_is_valid(request.api_key) if not is_valid: raise HTTPException(400, "Invalid API key provided") - get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key) + get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key) except RuntimeError as e: raise HTTPException(400, str(e)) -@router.delete("/admin/openai-api-key") -def delete_openai_api_key( +@router.delete("/admin/genai-api-key") +def delete_genai_api_key( _: User = Depends(current_admin_user), ) -> None: - get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY) + get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY) """Endpoints for basic users""" diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index f2a87ef85b31..b88bf65d3700 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -6,6 +6,7 @@ reorder-python-imports==3.9.0 types-beautifulsoup4==4.12.0.3 types-html5lib==1.1.11.13 types-oauthlib==3.2.0.9 +types-setuptools==68.0.0.3 types-psycopg2==2.9.21.10 types-python-dateutil==2.8.19.13 types-regex==2023.3.23.1 diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index f06f881273a1..3b6663ddc920 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -19,6 +19,9 @@ services: environment: - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion} - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo} + - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} + - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-} + - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-} - POSTGRES_HOST=relational_db - QDRANT_HOST=vector_db - TYPESENSE_HOST=search_engine @@ -49,6 +52,9 @@ services: environment: - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion} - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo} + - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} + - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-} + - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-} - POSTGRES_HOST=relational_db - QDRANT_HOST=vector_db - TYPESENSE_HOST=search_engine diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template index 21b6cab43f7a..f3eeda318830 100644 --- a/deployment/docker_compose/env.prod.template +++ b/deployment/docker_compose/env.prod.template @@ -5,7 +5,7 @@ # Insert your OpenAI API key here, currently the only Generative AI endpoint for QA that we support is OpenAI # If not provided here, UI will prompt on setup -OPENAI_API_KEY= +GEN_AI_API_KEY= # Choose between 
"openai-chat-completion" and "openai-completion" INTERNAL_MODEL_VERSION=openai-chat-completion # Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility diff --git a/web/src/app/admin/keys/openai/page.tsx b/web/src/app/admin/keys/openai/page.tsx index f581ccaaebb4..52f96b43a7ec 100644 --- a/web/src/app/admin/keys/openai/page.tsx +++ b/web/src/app/admin/keys/openai/page.tsx @@ -3,13 +3,13 @@ import { LoadingAnimation } from "@/components/Loading"; import { KeyIcon, TrashIcon } from "@/components/icons/icons"; import { ApiKeyForm } from "@/components/openai/ApiKeyForm"; -import { OPENAI_API_KEY_URL } from "@/components/openai/constants"; +import { GEN_AI_API_KEY_URL } from "@/components/openai/constants"; import { fetcher } from "@/lib/fetcher"; import useSWR, { mutate } from "swr"; const ExistingKeys = () => { const { data, isLoading, error } = useSWR<{ api_key: string }>( - OPENAI_API_KEY_URL, + GEN_AI_API_KEY_URL, fetcher ); @@ -33,7 +33,7 @@ const ExistingKeys = () => {