diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index c088825e867a..6b0c3e8c2a37 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -34,7 +34,7 @@ from danswer.direct_qa.qa_utils import get_usable_chunks from danswer.document_index import get_default_document_index from danswer.indexing.models import InferenceChunk from danswer.llm.build import get_default_llm -from danswer.llm.llm import LLM +from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.search.access_filters import build_access_filters_for_user diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index dd22608ed975..24886c1e7efb 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -77,21 +77,6 @@ class DocumentIndexType(str, Enum): SPLIT = "split" # Typesense + Qdrant -class DanswerGenAIModel(str, Enum): - """This represents the internal Danswer GenAI model which determines the class that is used - to generate responses to the user query. Different models/services require different internal - handling, this allows for modularity of implementation within Danswer""" - - OPENAI = "openai-completion" - OPENAI_CHAT = "openai-chat-completion" - GPT4ALL = "gpt4all-completion" - GPT4ALL_CHAT = "gpt4all-chat-completion" - HUGGINGFACE = "huggingface-client-completion" - HUGGINGFACE_CHAT = "huggingface-client-chat-completion" - REQUEST = "request-completion" - TRANSFORMERS = "transformers" - - class AuthType(str, Enum): DISABLED = "disabled" BASIC = "basic" @@ -100,17 +85,6 @@ class AuthType(str, Enum): SAML = "saml" -class ModelHostType(str, Enum): - """For GenAI models interfaced via requests, different services have different - expectations for what fields are included in the request""" - - # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task - HUGGINGFACE = "huggingface" # HuggingFace test-generation Inference API - # https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183 - COLAB_DEMO = "colab-demo" - # TODO support for Azure, AWS, GCP GenAI model hosting - - class QAFeedbackType(str, Enum): LIKE = "like" # User likes the answer, used for metrics DISLIKE = "dislike" # User dislikes the answer, used for metrics diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index 59a1961ce0ec..353853ac2815 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -1,9 +1,5 @@ import os -from danswer.configs.constants import DanswerGenAIModel -from danswer.configs.constants import ModelHostType - - ##### # Embedding/Reranking Model Configs ##### @@ -55,62 +51,38 @@ SEARCH_DISTANCE_CUTOFF = 0 # Intent model max context size QUERY_MAX_CONTEXT_SIZE = 256 +# Danswer custom Deep Learning Models +INTENT_MODEL_VERSION = "danswer/intent-model" + ##### # Generative AI Model Configs ##### -# Other models should work as well, check the library/API compatibility. -# But these are the models that have been verified to work with the existing prompts. -# Using a different model may require some prompt tuning. 
See qa_prompts.py -VERIFIED_MODELS = { - DanswerGenAIModel.OPENAI: ["text-davinci-003"], - DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"], - DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"], - DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"], - # The "chat" model below is actually "instruction finetuned" and does not support conversational - DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"], - DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"], - # Created by Deepset.ai - # https://huggingface.co/deepset/deberta-v3-large-squad2 - # Model provided with no modifications - DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"], -} -# Sets the internal Danswer model class to use -INTERNAL_MODEL_VERSION = os.environ.get( - "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value -) +# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default, +# be sure to use one that is LiteLLM compatible: +# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables +# The provider is the prefix before / in the model argument + +# Additionally Danswer supports GPT4All and custom request library based models +# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach +# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally +GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai" +# If using Azure, it's the engine name, for example: Danswer +GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo" # If the Generative AI model requires an API key for access, otherwise can leave blank -GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) - -# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version -GEN_AI_MODEL_VERSION = os.environ.get( - "GEN_AI_MODEL_VERSION", - VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0], +GEN_AI_API_KEY = ( + os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None ) -# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then -# set the two below to specify -# - Where to hit the endpoint -# - How should the request be formed -GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "") -GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value) +# API Base, such as (for Azure): https://danswer.openai.azure.com/ +GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None +# API Version, such as (for Azure): 2023-09-15-preview +GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None # Set this to be enough for an answer + quotes. 
Also used for Chat GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024) # This next restriction is only used for chat ATM, used to expire old messages as needed GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000) GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0) - -# Danswer custom Deep Learning Models -INTENT_MODEL_VERSION = "danswer/intent-model" - -##### -# OpenAI Azure -##### -API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "") -API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower() -API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "") -# Deployment ID used interchangeably with "engine" parameter -AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "") diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index f99b14c27426..8642e17e9597 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -9,8 +9,6 @@ from danswer.configs.constants import IGNORE_FOR_QA from danswer.db.feedback import create_query_event from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.models import User -from danswer.direct_qa.exceptions import OpenAIKeyMissing -from danswer.direct_qa.exceptions import UnknownModelError from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.qa_utils import get_usable_chunks @@ -132,7 +130,7 @@ def answer_qa_query( qa_model = get_default_qa_model( timeout=answer_generation_timeout, real_time_flow=real_time_flow ) - except (UnknownModelError, OpenAIKeyMissing) as e: + except Exception as e: return QAResponse( answer=None, quotes=None, diff --git a/backend/danswer/direct_qa/exceptions.py b/backend/danswer/direct_qa/exceptions.py deleted file mode 100644 index eb0434a7b1aa..000000000000 --- a/backend/danswer/direct_qa/exceptions.py +++ /dev/null @@ -1,13 +0,0 @@ -class OpenAIKeyMissing(Exception): - default_msg = ( - "Unable to find existing OpenAI Key. 
" - 'A new key can be added from "Keys" section of the Admin Panel' - ) - - def __init__(self, msg: str = default_msg) -> None: - super().__init__(msg) - - -class UnknownModelError(Exception): - def __init__(self, model_name: str) -> None: - super().__init__(f"Unknown Internal QA model name: {model_name}") diff --git a/backend/danswer/direct_qa/gpt_4_all.py b/backend/danswer/direct_qa/gpt_4_all.py deleted file mode 100644 index cb12883f035e..000000000000 --- a/backend/danswer/direct_qa/gpt_4_all.py +++ /dev/null @@ -1,209 +0,0 @@ -from collections.abc import Callable -from typing import Any - -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.direct_qa.interfaces import AnswerQuestionReturn -from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.models import LLMMetricsContainer -from danswer.direct_qa.qa_prompts import ChatPromptProcessor -from danswer.direct_qa.qa_prompts import NonChatPromptProcessor -from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor -from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor -from danswer.direct_qa.qa_utils import process_answer -from danswer.direct_qa.qa_utils import process_model_tokens -from danswer.indexing.models import InferenceChunk -from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time - -logger = setup_logger() - - -class DummyGPT4All: - """In the case of import failure due to M1 Mac incompatibility, - so this module does not raise exceptions during server startup, - as long as this module isn't actually used""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - raise RuntimeError("GPT4All library not installed.") - - -try: - from gpt4all import GPT4All # type:ignore -except ImportError: - logger.warning( - "GPT4All library not installed. " - "If you wish to run GPT4ALL (in memory) to power Danswer's " - "Generative AI features, please install gpt4all==1.0.5. " - "As of Aug 2023, this library is not compatible with M1 Mac." - ) - GPT4All = DummyGPT4All - - -GPT4ALL_MODEL: GPT4All | None = None - - -def get_gpt_4_all_model( - model_version: str = GEN_AI_MODEL_VERSION, -) -> GPT4All: - global GPT4ALL_MODEL - if GPT4ALL_MODEL is None: - GPT4ALL_MODEL = GPT4All(model_version) - return GPT4ALL_MODEL - - -def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]: - """ - Utility to add in some common default values so they don't have to be set every time. 
- """ - return { - "temp": 0, - **kwargs, - } - - -class GPT4AllCompletionQA(QAModel): - def __init__( - self, - prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - include_metadata: bool = False, # gpt4all models can't handle this atm - ) -> None: - self.prompt_processor = prompt_processor - self.model_version = model_version - self.max_output_tokens = max_output_tokens - self.include_metadata = include_metadata - - @property - def requires_api_key(self) -> bool: - return False - - def warm_up_model(self) -> None: - get_gpt_4_all_model(self.model_version) - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused - ) -> AnswerQuestionReturn: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - gen_ai_model = get_gpt_4_all_model(self.model_version) - - model_output = gen_ai_model.generate( - **_build_gpt4all_settings( - prompt=filled_prompt, max_tokens=self.max_output_tokens - ), - ) - - logger.debug(model_output) - - return process_answer(model_output, context_docs) - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - gen_ai_model = get_gpt_4_all_model(self.model_version) - - model_stream = gen_ai_model.generate( - **_build_gpt4all_settings( - prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True - ), - ) - - yield from process_model_tokens( - tokens=model_stream, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) - - -class GPT4AllChatCompletionQA(QAModel): - def __init__( - self, - prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - include_metadata: bool = False, # gpt4all models can't handle this atm - ) -> None: - self.prompt_processor = prompt_processor - self.model_version = model_version - self.max_output_tokens = max_output_tokens - self.include_metadata = include_metadata - - @property - def requires_api_key(self) -> bool: - return False - - def warm_up_model(self) -> None: - get_gpt_4_all_model(self.model_version) - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionReturn: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - gen_ai_model = get_gpt_4_all_model(self.model_version) - - with gen_ai_model.chat_session(): - context_msgs = filled_prompt[:-1] - user_query = filled_prompt[-1].get("content") - for message in context_msgs: - gen_ai_model.current_chat_session.append(message) - - model_output = gen_ai_model.generate( - **_build_gpt4all_settings( - prompt=user_query, max_tokens=self.max_output_tokens - ), - ) - - logger.debug(model_output) - - return process_answer(model_output, context_docs) - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - filled_prompt = 
self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - gen_ai_model = get_gpt_4_all_model(self.model_version) - - with gen_ai_model.chat_session(): - context_msgs = filled_prompt[:-1] - user_query = filled_prompt[-1].get("content") - for message in context_msgs: - gen_ai_model.current_chat_session.append(message) - - model_stream = gen_ai_model.generate( - **_build_gpt4all_settings( - prompt=user_query, max_tokens=self.max_output_tokens - ), - ) - - yield from process_model_tokens( - tokens=model_stream, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) diff --git a/backend/danswer/direct_qa/huggingface.py b/backend/danswer/direct_qa/huggingface.py deleted file mode 100644 index 3f6168b0a3ec..000000000000 --- a/backend/danswer/direct_qa/huggingface.py +++ /dev/null @@ -1,195 +0,0 @@ -from collections.abc import Callable -from typing import Any - -from huggingface_hub import InferenceClient # type:ignore -from huggingface_hub.utils import HfHubHTTPError # type:ignore - -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.direct_qa.interfaces import AnswerQuestionReturn -from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.models import LLMMetricsContainer -from danswer.direct_qa.qa_prompts import ChatPromptProcessor -from danswer.direct_qa.qa_prompts import FreeformProcessor -from danswer.direct_qa.qa_prompts import JsonChatProcessor -from danswer.direct_qa.qa_prompts import NonChatPromptProcessor -from danswer.direct_qa.qa_utils import process_answer -from danswer.direct_qa.qa_utils import process_model_tokens -from danswer.direct_qa.qa_utils import simulate_streaming_response -from danswer.indexing.models import InferenceChunk -from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time - -logger = setup_logger() - - -def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]: - """ - Utility to add in some common default values so they don't have to be set every time. 
- """ - return { - "do_sample": False, - "seed": 69, # For reproducibility - **kwargs, - } - - -class HuggingFaceCompletionQA(QAModel): - def __init__( - self, - prompt_processor: NonChatPromptProcessor = FreeformProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - include_metadata: bool = False, - api_key: str | None = None, - ) -> None: - self.prompt_processor = prompt_processor - self.max_output_tokens = max_output_tokens - self.include_metadata = include_metadata - self.client = InferenceClient(model=model_version, token=api_key) - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused - ) -> AnswerQuestionReturn: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - model_output = self.client.text_generation( - filled_prompt, - **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens), - ) - logger.debug(model_output) - - return process_answer(model_output, context_docs) - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - model_stream = self.client.text_generation( - filled_prompt, - **_build_hf_inference_settings( - max_new_tokens=self.max_output_tokens, stream=True - ), - ) - - yield from process_model_tokens( - tokens=model_stream, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) - - -class HuggingFaceChatCompletionQA(QAModel): - """Chat in this class refers to the HuggingFace Conversational API. - Not to be confused with Chat/Instruction finetuned models. - Llama2-chat... 
means it is an Instruction finetuned model, not necessarily that - it supports Conversational API""" - - def __init__( - self, - prompt_processor: ChatPromptProcessor = JsonChatProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - include_metadata: bool = False, - api_key: str | None = None, - ) -> None: - self.prompt_processor = prompt_processor - self.max_output_tokens = max_output_tokens - self.include_metadata = include_metadata - self.client = InferenceClient(model=model_version, token=api_key) - - @staticmethod - def _convert_chat_to_hf_conversational_format( - dialog: list[dict[str, str]] - ) -> tuple[str, list[str], list[str]]: - if dialog[-1]["role"] != "user": - raise Exception( - "Last message in conversational dialog must be User message" - ) - user_message = dialog[-1]["content"] - dialog = dialog[0:-1] - generated_responses = [] - past_user_inputs = [] - for message in dialog: - # HuggingFace inference client doesn't support system messages today - # so lumping them in with user messages - if message["role"] in ["user", "system"]: - past_user_inputs += [message["content"]] - else: - generated_responses += [message["content"]] - return user_message, generated_responses, past_user_inputs - - def _get_hf_model_output( - self, query: str, context_docs: list[InferenceChunk] - ) -> str: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - - ( - query, - past_responses, - past_inputs, - ) = self._convert_chat_to_hf_conversational_format(filled_prompt) - - logger.debug(f"Last Input: {query}") - logger.debug(f"Past Inputs: {past_inputs}") - logger.debug(f"Past Responses: {past_responses}") - try: - model_output = self.client.conversational( - query, - generated_responses=past_responses, - past_user_inputs=past_inputs, - parameters={"max_length": self.max_output_tokens}, - ) - except HfHubHTTPError as model_error: - if model_error.response.status_code == 422: - raise ValueError( - "Selected HuggingFace Model does not support HuggingFace Conversational API," - "try using the huggingface-inference-completion in Danswer instead" - ) - raise - logger.debug(model_output) - - return model_output - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionReturn: - model_output = self._get_hf_model_output(query, context_docs) - - answer, quotes_dict = process_answer(model_output, context_docs) - - return answer, quotes_dict - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - """As of Aug 2023, HF conversational (chat) endpoints do not support streaming - So here it is faked by streaming characters within Danswer from the model output - """ - model_output = self._get_hf_model_output(query, context_docs) - - model_stream = simulate_streaming_response(model_output) - - yield from process_model_tokens( - tokens=model_stream, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py index 688d0f002bd1..60897023fa1d 100644 --- a/backend/danswer/direct_qa/interfaces.py +++ b/backend/danswer/direct_qa/interfaces.py @@ -52,6 +52,7 @@ class QAModel: def requires_api_key(self) -> bool: """Is this model protected by security features Does it need an api key 
to access the model for inference""" + # TODO, this should be false for custom request model and gpt4all return True def warm_up_model(self) -> None: diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py index 95ca3aafe045..a43b05a215e6 100644 --- a/backend/danswer/direct_qa/llm_utils.py +++ b/backend/danswer/direct_qa/llm_utils.py @@ -1,32 +1,11 @@ -from typing import Any - -import pkg_resources from openai.error import AuthenticationError from danswer.configs.app_configs import QA_TIMEOUT -from danswer.configs.constants import DanswerGenAIModel -from danswer.configs.constants import ModelHostType -from danswer.configs.model_configs import GEN_AI_API_KEY -from danswer.configs.model_configs import GEN_AI_ENDPOINT -from danswer.configs.model_configs import GEN_AI_HOST_TYPE -from danswer.configs.model_configs import INTERNAL_MODEL_VERSION -from danswer.direct_qa.exceptions import UnknownModelError -from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA -from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA -from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA -from danswer.direct_qa.huggingface import HuggingFaceCompletionQA from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.local_transformers import TransformerQA -from danswer.direct_qa.open_ai import OpenAICompletionQA from danswer.direct_qa.qa_block import QABlock from danswer.direct_qa.qa_block import QAHandler -from danswer.direct_qa.qa_block import SimpleChatQAHandler from danswer.direct_qa.qa_block import SingleMessageQAHandler from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler -from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor -from danswer.direct_qa.qa_utils import get_gen_ai_api_key -from danswer.direct_qa.request_model import RequestCompletionQA -from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.llm.build import get_default_llm from danswer.utils.logger import setup_logger @@ -52,93 +31,23 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool: return False -def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler: - if model == DanswerGenAIModel.OPENAI_CHAT.value: - return ( - SingleMessageQAHandler() - if real_time_flow - else SingleMessageScratchpadHandler() - ) - - return SimpleChatQAHandler() +# TODO introduce the prompt choice parameter +def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler: + return ( + SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler() + ) + # return SimpleChatQAHandler() def get_default_qa_model( - internal_model: str = INTERNAL_MODEL_VERSION, - endpoint: str | None = GEN_AI_ENDPOINT, - model_host_type: str | None = GEN_AI_HOST_TYPE, - api_key: str | None = GEN_AI_API_KEY, + api_key: str | None = None, timeout: int = QA_TIMEOUT, real_time_flow: bool = True, - **kwargs: Any, ) -> QAModel: - if not api_key: - try: - api_key = get_gen_ai_api_key() - except ConfigNotFoundError: - pass + llm = get_default_llm(api_key=api_key, timeout=timeout) + qa_handler = get_default_qa_handler(real_time_flow=real_time_flow) - try: - # un-used arguments will be ignored by the underlying `LLM` class - # if any args are missing, a `TypeError` will be thrown - llm = get_default_llm(timeout=timeout) - qa_handler = get_default_qa_handler( - model=internal_model, real_time_flow=real_time_flow - ) - - return QABlock( - llm=llm, - qa_handler=qa_handler, - ) - except Exception: - logger.exception( - 
"Unable to build a QABlock with the new approach, going back to the " - "legacy approach" - ) - - if internal_model in [ - DanswerGenAIModel.GPT4ALL.value, - DanswerGenAIModel.GPT4ALL_CHAT.value, - ]: - # gpt4all is not compatible M1 Mac hardware as of Aug 2023 - pkg_resources.get_distribution("gpt4all") - - if internal_model == DanswerGenAIModel.OPENAI.value: - return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs) - elif internal_model == DanswerGenAIModel.GPT4ALL.value: - return GPT4AllCompletionQA(**kwargs) - elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value: - return GPT4AllChatCompletionQA(**kwargs) - elif internal_model == DanswerGenAIModel.HUGGINGFACE.value: - return HuggingFaceCompletionQA(api_key=api_key, **kwargs) - elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value: - return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs) - elif internal_model == DanswerGenAIModel.TRANSFORMERS: - return TransformerQA() - elif internal_model == DanswerGenAIModel.REQUEST.value: - if endpoint is None or model_host_type is None: - raise ValueError( - "Request based GenAI model requires an endpoint and host type" - ) - if ( - model_host_type == ModelHostType.HUGGINGFACE.value - or model_host_type == ModelHostType.COLAB_DEMO.value - ): - # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits - # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens - # This is the sum of input and output tokens, so cannot take in full Danswer context - return RequestCompletionQA( - endpoint=endpoint, - model_host_type=model_host_type, - api_key=api_key, - prompt_processor=WeakModelFreeformProcessor(), - timeout=timeout, - ) - return RequestCompletionQA( - endpoint=endpoint, - model_host_type=model_host_type, - api_key=api_key, - timeout=timeout, - ) - else: - raise UnknownModelError(internal_model) + return QABlock( + llm=llm, + qa_handler=qa_handler, + ) diff --git a/backend/danswer/direct_qa/local_transformers.py b/backend/danswer/direct_qa/local_transformers.py deleted file mode 100644 index f0102f863b6d..000000000000 --- a/backend/danswer/direct_qa/local_transformers.py +++ /dev/null @@ -1,156 +0,0 @@ -import re -from collections.abc import Callable - -from transformers import pipeline # type:ignore -from transformers import QuestionAnsweringPipeline # type:ignore - -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.direct_qa.interfaces import AnswerQuestionReturn -from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn -from danswer.direct_qa.interfaces import DanswerAnswer -from danswer.direct_qa.interfaces import DanswerAnswerPiece -from danswer.direct_qa.interfaces import DanswerQuote -from danswer.direct_qa.interfaces import DanswerQuotes -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.models import LLMMetricsContainer -from danswer.indexing.models import InferenceChunk -from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time - -logger = setup_logger() - -TRANSFORMER_DEFAULT_MAX_CONTEXT = 512 - -_TRANSFORMER_MODEL: QuestionAnsweringPipeline | None = None - - -def get_default_transformer_model( - model_version: str = GEN_AI_MODEL_VERSION, - max_context: int = TRANSFORMER_DEFAULT_MAX_CONTEXT, -) -> QuestionAnsweringPipeline: - global _TRANSFORMER_MODEL - if _TRANSFORMER_MODEL is None: - _TRANSFORMER_MODEL = pipeline( - "question-answering", model=model_version, max_seq_len=max_context - ) - - return 
_TRANSFORMER_MODEL - - -def find_extended_answer(answer: str, context: str) -> str: - """Try to extend the answer by matching across the context text and extending before - and after the quote to some termination character""" - result = re.search( - r"(^|\n\r?|\.)(?P[^\n]{{0,250}}{}[^\n]{{0,250}})(\.|$|\n\r?)".format( - re.escape(answer) - ), - context, - flags=re.MULTILINE | re.DOTALL, - ) - if result: - return result.group("content") - - return answer - - -class TransformerQA(QAModel): - @staticmethod - def _answer_one_chunk( - query: str, - chunk: InferenceChunk, - max_context_len: int = TRANSFORMER_DEFAULT_MAX_CONTEXT, - max_cutoff: float = 0.9, - min_cutoff: float = 0.5, - ) -> tuple[str | None, DanswerQuote | None]: - """Because this type of QA model only takes 1 chunk of context with a fairly small token limit - We have to iterate the checks and check if the answer is found in any of the chunks. - This type of approach does not allow for interpolating answers across chunks - """ - model = get_default_transformer_model() - model_out = model(question=query, context=chunk.content, max_answer_len=128) - - answer = model_out.get("answer") - confidence = model_out.get("score") - - if answer is None: - return None, None - - logger.info(f"Transformer Answer: {answer}") - logger.debug(f"Transformer Confidence: {confidence}") - - # Model tends to be overconfident on short chunks - # so min score required increases as chunk size decreases - # If it's at least 0.9, then it's good enough regardless - # Default minimum of 0.5 required - score_cutoff = max( - min(max_cutoff, 1 - len(chunk.content) / max_context_len), min_cutoff - ) - if confidence < score_cutoff: - return None, None - - extended_answer = find_extended_answer(answer, chunk.content) - - danswer_quote = DanswerQuote( - quote=answer, - document_id=chunk.document_id, - link=chunk.source_links[0] if chunk.source_links else None, - source_type=chunk.source_type, - semantic_identifier=chunk.semantic_identifier, - blurb=chunk.blurb, - ) - - return extended_answer, danswer_quote - - def warm_up_model(self) -> None: - get_default_transformer_model() - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused - ) -> AnswerQuestionReturn: - danswer_quotes: list[DanswerQuote] = [] - d_answers: list[str] = [] - for chunk in context_docs: - answer, quote = self._answer_one_chunk(query, chunk) - if answer is not None and quote is not None: - d_answers.append(answer) - danswer_quotes.append(quote) - - answers_list = [ - f"Answer {ind}: {answer.strip()}" - for ind, answer in enumerate(d_answers, start=1) - ] - combined_answer = "\n".join(answers_list) - return DanswerAnswer(answer=combined_answer), DanswerQuotes( - quotes=danswer_quotes - ) - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - quotes: list[DanswerQuote] = [] - answers: list[str] = [] - for chunk in context_docs: - answer, quote = self._answer_one_chunk(query, chunk) - if answer is not None and quote is not None: - answers.append(answer) - quotes.append(quote) - - # Delay the output of the answers so there isn't long gap between first answer and quotes - answer_count = 1 - for answer in answers: - if answer_count == 1: - yield DanswerAnswerPiece(answer_piece="Source 1: ") - else: - yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ") - answer_count += 1 - for char in 
answer.strip(): - yield DanswerAnswerPiece(answer_piece=char) - - # signal end of answer - yield DanswerAnswerPiece(answer_piece=None) - - yield DanswerQuotes(quotes=quotes) diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py deleted file mode 100644 index f8e8bbca87f3..000000000000 --- a/backend/danswer/direct_qa/open_ai.py +++ /dev/null @@ -1,209 +0,0 @@ -from abc import ABC -from collections.abc import Callable -from collections.abc import Generator -from copy import copy -from functools import wraps -from typing import Any -from typing import cast -from typing import TypeVar - -import openai -import tiktoken -from openai.error import AuthenticationError -from openai.error import Timeout - -from danswer.configs.app_configs import INCLUDE_METADATA -from danswer.configs.model_configs import API_BASE_OPENAI -from danswer.configs.model_configs import API_TYPE_OPENAI -from danswer.configs.model_configs import API_VERSION_OPENAI -from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.direct_qa.exceptions import OpenAIKeyMissing -from danswer.direct_qa.interfaces import AnswerQuestionReturn -from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.models import LLMMetricsContainer -from danswer.direct_qa.qa_prompts import JsonProcessor -from danswer.direct_qa.qa_prompts import NonChatPromptProcessor -from danswer.direct_qa.qa_utils import get_gen_ai_api_key -from danswer.direct_qa.qa_utils import process_answer -from danswer.direct_qa.qa_utils import process_model_tokens -from danswer.dynamic_configs.interface import ConfigNotFoundError -from danswer.indexing.models import InferenceChunk -from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time - - -logger = setup_logger() - -F = TypeVar("F", bound=Callable) - - -if API_BASE_OPENAI: - openai.api_base = API_BASE_OPENAI -if API_TYPE_OPENAI in ["azure"]: # TODO: Azure AD support ["azure_ad", "azuread"] - openai.api_type = API_TYPE_OPENAI - openai.api_version = API_VERSION_OPENAI - - -def _ensure_openai_api_key(api_key: str | None) -> str: - final_api_key = api_key or get_gen_ai_api_key() - if final_api_key is None: - raise OpenAIKeyMissing() - - return final_api_key - - -def _build_openai_settings(**kwargs: Any) -> dict[str, Any]: - """ - Utility to add in some common default values so they don't have to be set every time. 
- """ - return { - "temperature": 0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - **({"deployment_id": AZURE_DEPLOYMENT_ID} if AZURE_DEPLOYMENT_ID else {}), - **kwargs, - } - - -def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F: - @wraps(openai_call) - def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any: - try: - # if streamed, the call returns a generator - if kwargs.get("stream"): - - def _generator() -> Generator[Any, None, None]: - yield from openai_call(*args, **kwargs) - - return _generator() - return openai_call(*args, **kwargs) - except AuthenticationError: - logger.exception("Failed to authenticate with OpenAI API") - raise - except Timeout: - logger.exception("OpenAI API timed out for query: %s", query) - raise - except Exception: - logger.exception("Unexpected error with OpenAI API for query: %s", query) - raise - - return cast(F, wrapped_call) - - -def _tiktoken_trim_chunks( - chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512 -) -> list[InferenceChunk]: - """Edit chunks that have too high token count. Generally due to parsing issues or - characters from another language that are 1 char = 1 token - Trimming by tokens leads to information loss but currently no better way of handling - """ - encoder = tiktoken.encoding_for_model(model_version) - new_chunks = copy(chunks) - for ind, chunk in enumerate(new_chunks): - tokens = encoder.encode(chunk.content) - if len(tokens) > max_chunk_toks: - new_chunk = copy(chunk) - new_chunk.content = encoder.decode(tokens[:max_chunk_toks]) - new_chunks[ind] = new_chunk - return new_chunks - - -# used to check if the QAModel is an OpenAI model -class OpenAIQAModel(QAModel, ABC): - pass - - -class OpenAICompletionQA(OpenAIQAModel): - def __init__( - self, - prompt_processor: NonChatPromptProcessor = JsonProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - api_key: str | None = None, - timeout: int | None = None, - include_metadata: bool = INCLUDE_METADATA, - ) -> None: - self.prompt_processor = prompt_processor - self.model_version = model_version - self.max_output_tokens = max_output_tokens - self.timeout = timeout - self.include_metadata = include_metadata - try: - self.api_key = api_key or get_gen_ai_api_key() - except ConfigNotFoundError: - raise OpenAIKeyMissing() - - @staticmethod - def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]: - for event in response: - yield event["choices"][0]["text"] - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused - ) -> AnswerQuestionReturn: - context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) - - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.Completion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=_ensure_openai_api_key(self.api_key), - prompt=filled_prompt, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - ), - ) - model_output = cast(str, response["choices"][0]["text"]).strip() - logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")) - logger.debug(model_output) - - return process_answer(model_output, 
context_docs) - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) - - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.Completion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=_ensure_openai_api_key(self.api_key), - prompt=filled_prompt, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - stream=True, - ), - ) - - tokens = self._generate_tokens_from_response(response) - - yield from process_model_tokens( - tokens=tokens, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py index 4d9f5b3b00ff..b96736fba5d5 100644 --- a/backend/danswer/direct_qa/qa_block.py +++ b/backend/danswer/direct_qa/qa_block.py @@ -28,7 +28,7 @@ from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor from danswer.direct_qa.qa_utils import process_answer from danswer.direct_qa.qa_utils import process_model_tokens from danswer.indexing.models import InferenceChunk -from danswer.llm.llm import LLM +from danswer.llm.interfaces import LLM from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import get_default_llm_tokenizer diff --git a/backend/danswer/direct_qa/request_model.py b/backend/danswer/direct_qa/request_model.py deleted file mode 100644 index 32f98123492a..000000000000 --- a/backend/danswer/direct_qa/request_model.py +++ /dev/null @@ -1,272 +0,0 @@ -import abc -import json -from collections.abc import Callable -from collections.abc import Generator - -import requests -from requests.exceptions import Timeout -from requests.models import Response - -from danswer.configs.constants import ModelHostType -from danswer.configs.model_configs import GEN_AI_API_KEY -from danswer.configs.model_configs import GEN_AI_ENDPOINT -from danswer.configs.model_configs import GEN_AI_HOST_TYPE -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS -from danswer.direct_qa.interfaces import AnswerQuestionReturn -from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.models import LLMMetricsContainer -from danswer.direct_qa.qa_prompts import JsonProcessor -from danswer.direct_qa.qa_prompts import NonChatPromptProcessor -from danswer.direct_qa.qa_utils import process_answer -from danswer.direct_qa.qa_utils import process_model_tokens -from danswer.direct_qa.qa_utils import simulate_streaming_response -from danswer.indexing.models import InferenceChunk -from danswer.utils.logger import setup_logger -from danswer.utils.timing import log_function_time - - -logger = setup_logger() - - -class HostSpecificRequestModel(abc.ABC): - """Provides a more minimal implementation requirement for extending to new models - hosted behind REST APIs. 
Calling class abstracts away all Danswer internal specifics - """ - - @property - def requires_api_key(self) -> bool: - """Is this model protected by security features - Does it need an api key to access the model for inference""" - return True - - @staticmethod - @abc.abstractmethod - def send_model_request( - filled_prompt: str, - endpoint: str, - api_key: str | None, - max_output_tokens: int, - stream: bool, - timeout: int | None, - ) -> Response: - """Given a filled out prompt, how to send it to the model API with the - correct request format with the correct parameters""" - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def extract_model_output_from_response( - response: Response, - ) -> str: - """Extract the full model output text from a response. - This is for nonstreaming endpoints""" - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def generate_model_tokens_from_response( - response: Response, - ) -> Generator[str, None, None]: - """Generate tokens from a streaming response - This is for streaming endpoints""" - raise NotImplementedError - - -class HuggingFaceRequestModel(HostSpecificRequestModel): - @staticmethod - def send_model_request( - filled_prompt: str, - endpoint: str, - api_key: str | None, - max_output_tokens: int, - stream: bool, # Not supported by Inference Endpoints (as of Aug 2023) - timeout: int | None, - ) -> Response: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } - - data = { - "inputs": filled_prompt, - "parameters": { - # HuggingFace requires this to be strictly positive from 0.0-100.0 noninclusive - "temperature": 0.01, - # Skip the long tail - "top_p": 0.9, - "max_new_tokens": max_output_tokens, - }, - } - try: - return requests.post(endpoint, headers=headers, json=data, timeout=timeout) - except Timeout as error: - raise Timeout(f"Model inference to {endpoint} timed out") from error - - @staticmethod - def _hf_extract_model_output( - response: Response, - ) -> str: - if response.status_code != 200: - response.raise_for_status() - - return json.loads(response.content)[0].get("generated_text", "") - - @staticmethod - def extract_model_output_from_response( - response: Response, - ) -> str: - return HuggingFaceRequestModel._hf_extract_model_output(response) - - @staticmethod - def generate_model_tokens_from_response( - response: Response, - ) -> Generator[str, None, None]: - """HF endpoints do not do streaming currently so this function will - simulate streaming for the meantime but will need to be replaced in - the future once streaming is enabled.""" - model_out = HuggingFaceRequestModel._hf_extract_model_output(response) - yield from simulate_streaming_response(model_out) - - -class ColabDemoRequestModel(HostSpecificRequestModel): - """Guide found at: - https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183 - """ - - @property - def requires_api_key(self) -> bool: - return False - - @staticmethod - def send_model_request( - filled_prompt: str, - endpoint: str, - api_key: str | None, # ngrok basic setup doesn't require this - max_output_tokens: int, - stream: bool, - timeout: int | None, - ) -> Response: - headers = { - "Content-Type": "application/json", - } - - data = { - "inputs": filled_prompt, - "parameters": { - "temperature": 0.0, - "max_tokens": max_output_tokens, - }, - } - try: - return requests.post(endpoint, headers=headers, json=data, timeout=timeout) - except Timeout as error: - raise Timeout(f"Model inference to {endpoint} timed 
out") from error - - @staticmethod - def _colab_demo_extract_model_output( - response: Response, - ) -> str: - if response.status_code != 200: - response.raise_for_status() - - return json.loads(response.content).get("generated_text", "") - - @staticmethod - def extract_model_output_from_response( - response: Response, - ) -> str: - return ColabDemoRequestModel._colab_demo_extract_model_output(response) - - @staticmethod - def generate_model_tokens_from_response( - response: Response, - ) -> Generator[str, None, None]: - model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response) - yield from simulate_streaming_response(model_out) - - -def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel: - if model_host_type == ModelHostType.HUGGINGFACE.value: - return HuggingFaceRequestModel() - if model_host_type == ModelHostType.COLAB_DEMO.value: - return ColabDemoRequestModel() - else: - # TODO support Azure, GCP, AWS - raise ValueError("Invalid model hosting service selected") - - -class RequestCompletionQA(QAModel): - def __init__( - self, - endpoint: str = GEN_AI_ENDPOINT, - model_host_type: str = GEN_AI_HOST_TYPE, - api_key: str | None = GEN_AI_API_KEY, - prompt_processor: NonChatPromptProcessor = JsonProcessor(), - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - timeout: int | None = None, - ) -> None: - self.endpoint = endpoint - self.api_key = api_key - self.prompt_processor = prompt_processor - self.max_output_tokens = max_output_tokens - self.model_class = get_host_specific_model_class(model_host_type) - self.timeout = timeout - - @property - def requires_api_key(self) -> bool: - return self.model_class.requires_api_key - - def _get_request_response( - self, query: str, context_docs: list[InferenceChunk], stream: bool - ) -> Response: - filled_prompt = self.prompt_processor.fill_prompt( - query, context_docs, include_metadata=False - ) - logger.debug(filled_prompt) - - return self.model_class.send_model_request( - filled_prompt, - self.endpoint, - self.api_key, - self.max_output_tokens, - stream, - self.timeout, - ) - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused - ) -> AnswerQuestionReturn: - model_api_response = self._get_request_response( - query, context_docs, stream=False - ) - - model_output = self.model_class.extract_model_output_from_response( - model_api_response - ) - logger.debug(model_output) - - return process_answer(model_output, context_docs) - - def answer_question_stream( - self, - query: str, - context_docs: list[InferenceChunk], - ) -> AnswerQuestionStreamReturn: - model_api_response = self._get_request_response( - query, context_docs, stream=False - ) - - token_generator = self.model_class.generate_model_tokens_from_response( - model_api_response - ) - - yield from process_model_tokens( - tokens=token_generator, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) diff --git a/backend/danswer/llm/__init__.py b/backend/danswer/llm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/backend/danswer/llm/azure.py b/backend/danswer/llm/azure.py deleted file mode 100644 index e2ec1fa35c10..000000000000 --- a/backend/danswer/llm/azure.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any - -from langchain.chat_models.azure_openai import AzureChatOpenAI - -from danswer.configs.model_configs import API_BASE_OPENAI 
-from danswer.configs.model_configs import API_VERSION_OPENAI -from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID -from danswer.llm.llm import LangChainChatLLM -from danswer.llm.utils import should_be_verbose - - -class AzureGPT(LangChainChatLLM): - def __init__( - self, - api_key: str, - max_output_tokens: int, - timeout: int, - model_version: str, - api_base: str = API_BASE_OPENAI, - api_version: str = API_VERSION_OPENAI, - deployment_name: str = AZURE_DEPLOYMENT_ID, - *args: list[Any], - **kwargs: dict[str, Any] - ): - self._llm = AzureChatOpenAI( - model=model_version, - openai_api_type="azure", - openai_api_base=api_base, - openai_api_version=api_version, - deployment_name=deployment_name, - openai_api_key=api_key, - max_tokens=max_output_tokens, - temperature=0, - request_timeout=timeout, - model_kwargs={ - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - }, - verbose=should_be_verbose(), - max_retries=0, # retries are handled outside of langchain - ) - - @property - def llm(self) -> AzureChatOpenAI: - return self._llm diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/build.py index 5ec9fae9e980..e916323786f2 100644 --- a/backend/danswer/llm/build.py +++ b/backend/danswer/llm/build.py @@ -1,48 +1,25 @@ from danswer.configs.app_configs import QA_TIMEOUT -from danswer.configs.constants import DanswerGenAIModel -from danswer.configs.constants import ModelHostType -from danswer.configs.model_configs import API_TYPE_OPENAI -from danswer.configs.model_configs import GEN_AI_ENDPOINT -from danswer.configs.model_configs import GEN_AI_HOST_TYPE -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.configs.model_configs import GEN_AI_TEMPERATURE -from danswer.configs.model_configs import INTERNAL_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.direct_qa.qa_utils import get_gen_ai_api_key -from danswer.llm.azure import AzureGPT -from danswer.llm.google_colab_demo import GoogleColabDemo -from danswer.llm.llm import LLM -from danswer.llm.openai import OpenAIGPT +from danswer.llm.custom_llm import CustomModelServer +from danswer.llm.gpt_4_all import DanswerGPT4All +from danswer.llm.interfaces import LLM +from danswer.llm.multi_llm import DefaultMultiLLM def get_default_llm( api_key: str | None = None, timeout: int = QA_TIMEOUT, ) -> LLM: - """NOTE: api_key/timeout must be a special args since we may want to check - if an API key is valid for the default model setup OR we may want to use the - default model with a different timeout specified.""" + """A single place to fetch the configured LLM for Danswer + Also allows overriding certain LLM defaults""" if api_key is None: api_key = get_gen_ai_api_key() - model_args = { - # provide a dummy key since LangChain will throw an exception if not - # given, which would prevent server startup - "api_key": api_key or "dummy_api_key", - "timeout": timeout, - "model_version": GEN_AI_MODEL_VERSION, - "endpoint": GEN_AI_ENDPOINT, - "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS, - "temperature": GEN_AI_TEMPERATURE, - } - if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value: - if API_TYPE_OPENAI == "azure": - return AzureGPT(**model_args) # type: ignore - return OpenAIGPT(**model_args) # type: ignore - if ( - INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value - and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO - ): - return GoogleColabDemo(**model_args) # type: ignore + if 
GEN_AI_MODEL_PROVIDER.lower() == "custom": + return CustomModelServer(api_key=api_key, timeout=timeout) - raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}") + if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all": + return DanswerGPT4All(timeout=timeout) + + return DefaultMultiLLM(api_key=api_key, timeout=timeout) diff --git a/backend/danswer/llm/google_colab_demo.py b/backend/danswer/llm/custom_llm.py similarity index 54% rename from backend/danswer/llm/google_colab_demo.py rename to backend/danswer/llm/custom_llm.py index d1bdcf390dcf..9d997e2cbc11 100644 --- a/backend/danswer/llm/google_colab_demo.py +++ b/backend/danswer/llm/custom_llm.py @@ -1,24 +1,36 @@ import json from collections.abc import Iterator -from typing import Any import requests from langchain.schema.language_model import LanguageModelInput from requests import Timeout -from danswer.llm.llm import LLM -from danswer.llm.utils import convert_input +from danswer.configs.model_configs import GEN_AI_API_ENDPOINT +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.llm.interfaces import LLM +from danswer.llm.utils import convert_lm_input_to_basic_string -class GoogleColabDemo(LLM): +class CustomModelServer(LLM): + """This class is to provide an example for how to use Danswer + with any LLM, even servers with custom API definitions. + To use with your own model server, simply implement the functions + below to fit your model server expectation""" + def __init__( self, - endpoint: str, - max_output_tokens: int, + # Not used here but you probably want a model server that isn't completely open + api_key: str | None, timeout: int, - *args: list[Any], - **kwargs: dict[str, Any], + endpoint: str | None = GEN_AI_API_ENDPOINT, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, ): + if not endpoint: + raise ValueError( + "Cannot point Danswer to a custom LLM server without providing the " + "endpoint for the model server." + ) + self._endpoint = endpoint self._max_output_tokens = max_output_tokens self._timeout = timeout @@ -29,7 +41,7 @@ class GoogleColabDemo(LLM): } data = { - "inputs": convert_input(input), + "inputs": convert_lm_input_to_basic_string(input), "parameters": { "temperature": 0.0, "max_tokens": self._max_output_tokens, diff --git a/backend/danswer/llm/gpt_4_all.py b/backend/danswer/llm/gpt_4_all.py new file mode 100644 index 000000000000..57aeecc32609 --- /dev/null +++ b/backend/danswer/llm/gpt_4_all.py @@ -0,0 +1,60 @@ +from collections.abc import Iterator +from typing import Any + +from langchain.schema.language_model import LanguageModelInput + +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_TEMPERATURE +from danswer.llm.interfaces import LLM +from danswer.llm.utils import convert_lm_input_to_basic_string +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +class DummyGPT4All: + """In the case of import failure due to architectural incompatibilities, + this module does not raise exceptions during server startup, + as long as the module isn't actually used""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise RuntimeError("GPT4All library not installed.") + + +try: + from gpt4all import GPT4All # type:ignore +except ImportError: + # Setting a low log level because users get scared when they see this + logger.debug( + "GPT4All library not installed. 
" + "If you wish to run GPT4ALL (in memory) to power Danswer's " + "Generative AI features, please install gpt4all==2.0.2." + ) + GPT4All = DummyGPT4All + + +class DanswerGPT4All(LLM): + """Option to run an LLM locally, however this is significantly slower and + answers tend to be much worse""" + + def __init__( + self, + timeout: int, + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + temperature: float = GEN_AI_TEMPERATURE, + ): + self.timeout = timeout + self.max_output_tokens = max_output_tokens + self.temperature = temperature + self.gpt4all_model = GPT4All(model_version) + + def invoke(self, prompt: LanguageModelInput) -> str: + prompt_basic = convert_lm_input_to_basic_string(prompt) + return self.gpt4all_model.generate(prompt_basic) + + def stream(self, prompt: LanguageModelInput) -> Iterator[str]: + prompt_basic = convert_lm_input_to_basic_string(prompt) + return self.gpt4all_model.generate(prompt_basic, streaming=True) diff --git a/backend/danswer/llm/llm.py b/backend/danswer/llm/interfaces.py similarity index 100% rename from backend/danswer/llm/llm.py rename to backend/danswer/llm/interfaces.py diff --git a/backend/danswer/llm/multi_llm.py b/backend/danswer/llm/multi_llm.py new file mode 100644 index 000000000000..ae5ba2a19c1e --- /dev/null +++ b/backend/danswer/llm/multi_llm.py @@ -0,0 +1,74 @@ +import litellm # type:ignore +from langchain.chat_models import ChatLiteLLM + +from danswer.configs.model_configs import GEN_AI_API_ENDPOINT +from danswer.configs.model_configs import GEN_AI_API_VERSION +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_TEMPERATURE +from danswer.llm.interfaces import LangChainChatLLM +from danswer.llm.utils import should_be_verbose + + +# If a user configures a different model and it doesn't support all the same +# parameters like frequency and presence, just ignore them +litellm.drop_params = True +litellm.telemetry = False + + +def _get_model_str( + model_provider: str | None, + model_version: str | None, +) -> str: + if model_provider and model_version: + return model_provider + "/" + model_version + + if model_version: + # Litellm defaults to openai if no provider specified + # It's implicit so no need to specify here either + return model_version + + # User specified something wrong, just use Danswer default + return GEN_AI_MODEL_VERSION + + +class DefaultMultiLLM(LangChainChatLLM): + """Uses Litellm library to allow easy configuration to use a multitude of LLMs + See https://python.langchain.com/docs/integrations/chat/litellm""" + + DEFAULT_MODEL_PARAMS = { + "frequency_penalty": 0, + "presence_penalty": 0, + } + + def __init__( + self, + api_key: str | None, + timeout: int, + model_provider: str | None = GEN_AI_MODEL_PROVIDER, + model_version: str | None = GEN_AI_MODEL_VERSION, + api_base: str | None = GEN_AI_API_ENDPOINT, + api_version: str | None = GEN_AI_API_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + temperature: float = GEN_AI_TEMPERATURE, + ): + # Litellm Langchain integration currently doesn't take in the api key param + # Can place this in the call below once integration is in + litellm.api_key = api_key + litellm.api_version = api_version + + self._llm = ChatLiteLLM( # type: ignore + model=_get_model_str(model_provider, model_version), + api_base=api_base, + 
max_tokens=max_output_tokens, + temperature=temperature, + request_timeout=timeout, + model_kwargs=DefaultMultiLLM.DEFAULT_MODEL_PARAMS, + verbose=should_be_verbose(), + max_retries=0, # retries are handled outside of langchain + ) + + @property + def llm(self) -> ChatLiteLLM: + return self._llm diff --git a/backend/danswer/llm/openai.py b/backend/danswer/llm/openai.py deleted file mode 100644 index 48673aa22c44..000000000000 --- a/backend/danswer/llm/openai.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any -from typing import cast - -from langchain.chat_models import ChatLiteLLM -import litellm # type:ignore - -from danswer.configs.model_configs import GEN_AI_TEMPERATURE -from danswer.llm.llm import LangChainChatLLM -from danswer.llm.utils import should_be_verbose - - -# If a user configures a different model and it doesn't support all the same -# parameters like frequency and presence, just ignore them -litellm.drop_params=True - - -class OpenAIGPT(LangChainChatLLM): - - DEFAULT_MODEL_PARAMS = { - "frequency_penalty": 0, - "presence_penalty": 0, - } - - def __init__( - self, - api_key: str, - max_output_tokens: int, - timeout: int, - model_version: str, - temperature: float = GEN_AI_TEMPERATURE, - *args: list[Any], - **kwargs: dict[str, Any] - ): - litellm.api_key = api_key - - self._llm = ChatLiteLLM( # type: ignore - model=model_version, - # Prefer using None which is the default value, endpoint could be empty string - api_base=cast(str, kwargs.get("endpoint")) or None, - max_tokens=max_output_tokens, - temperature=temperature, - request_timeout=timeout, - model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS, - verbose=should_be_verbose(), - max_retries=0, # retries are handled outside of langchain - ) - - @property - def llm(self) -> ChatLiteLLM: - return self._llm diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index b3a48ae67739..93db51f25c9a 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -22,6 +22,7 @@ _LLM_TOKENIZER: Callable[[str], Any] | None = None def get_default_llm_tokenizer() -> Callable: + """Currently only supports the OpenAI default tokenizer: tiktoken""" global _LLM_TOKENIZER if _LLM_TOKENIZER is None: _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode @@ -71,14 +72,7 @@ def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]: return [HumanMessage(content=message)] -def message_generator_to_string_generator( - messages: Iterator[BaseMessageChunk], -) -> Iterator[str]: - for message in messages: - yield message.content - - -def convert_input(lm_input: LanguageModelInput) -> str: +def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str: """Heavily inspired by: https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86 """ @@ -99,6 +93,13 @@ def convert_input(lm_input: LanguageModelInput) -> str: return prompt_value.to_string() +def message_generator_to_string_generator( + messages: Iterator[BaseMessageChunk], +) -> Iterator[str]: + for message in messages: + yield message.content + + def should_be_verbose() -> bool: return LOG_LEVEL == "debug" diff --git a/backend/danswer/main.py b/backend/danswer/main.py index cdb621709a09..82f0ac26e098 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -22,13 +22,12 @@ from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import AuthType -from 
danswer.configs.model_configs import API_BASE_OPENAI -from danswer.configs.model_configs import API_TYPE_OPENAI from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL +from danswer.configs.model_configs import GEN_AI_API_ENDPOINT +from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.configs.model_configs import SKIP_RERANKING from danswer.db.credentials import create_initial_public_credential from danswer.direct_qa.llm_utils import get_default_qa_model @@ -152,14 +151,6 @@ def get_application() -> FastAPI: warm_up_models, ) - if DISABLE_GENERATIVE_AI: - logger.info("Generative AI Q&A disabled") - else: - logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}") - logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}") - if API_TYPE_OPENAI == "azure": - logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}") - verify_auth = fetch_versioned_implementation( "danswer.auth.users", "verify_auth_setting" ) @@ -169,6 +160,14 @@ def get_application() -> FastAPI: if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET: logger.info("Both OAuth Client ID and Secret are configured.") + if DISABLE_GENERATIVE_AI: + logger.info("Generative AI Q&A disabled") + else: + logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}") + logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}") + if GEN_AI_API_ENDPOINT: + logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}") + if SKIP_RERANKING: logger.info("Reranking step of search flow is disabled") diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 4108d5aed944..c496da1c6e96 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -25,7 +25,7 @@ from danswer.db.feedback import update_document_hidden from danswer.db.models import User from danswer.direct_qa.llm_utils import check_model_api_key_is_valid from danswer.direct_qa.llm_utils import get_default_qa_model -from danswer.direct_qa.open_ai import get_gen_ai_api_key +from danswer.direct_qa.qa_utils import get_gen_ai_api_key from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.server.models import ApiKey diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index fa85a2e326d9..4682a5ba37ce 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -19,8 +19,6 @@ from danswer.db.feedback import update_query_event_feedback from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.models import User from danswer.direct_qa.answer_question import answer_qa_query -from danswer.direct_qa.exceptions import OpenAIKeyMissing -from danswer.direct_qa.exceptions import UnknownModelError from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.llm_utils import get_default_qa_model @@ -302,7 +300,7 @@ def stream_direct_qa( try: qa_model = get_default_qa_model() - except (UnknownModelError, OpenAIKeyMissing) as e: + except Exception as e: logger.exception("Unable to get QA model") error = StreamingError(error=str(e)) yield get_json_line(error.dict()) diff --git 
a/backend/requirements/default.txt b/backend/requirements/default.txt index dbe2619a6aa1..5d276647ab69 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -13,9 +13,9 @@ filelock==3.12.0 google-api-python-client==2.86.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 -# GPT4All library does not support M1 Mac architecture +# GPT4All library has issues running on Macs and python:3.11.4-slim-bookworm # will reintroduce this when library version catches up -# gpt4all==1.0.5 +# gpt4all==2.0.2 httpcore==0.16.3 httpx==0.23.3 httpx-oauth==0.11.2 diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index ac2fd220fd65..b2ae2aa1430a 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -16,11 +16,11 @@ services: ports: - "8080:8080" environment: - - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion} + - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai} - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo} - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-} - - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-} + - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} + - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-} - POSTGRES_HOST=relational_db - VESPA_HOST=index @@ -30,10 +30,6 @@ services: - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-} - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} - - API_BASE_OPENAI=${API_BASE_OPENAI:-} - - API_TYPE_OPENAI=${API_TYPE_OPENAI:-} - - API_VERSION_OPENAI=${API_VERSION_OPENAI:-} - - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - DISABLE_TIME_FILTER_EXTRACTION=${DISABLE_TIME_FILTER_EXTRACTION:-} # Don't change the NLP model configs unless you know what you're doing @@ -63,17 +59,13 @@ services: - index restart: always environment: - - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion} + - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai} - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo} - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-} - - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-} + - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} + - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - POSTGRES_HOST=relational_db - VESPA_HOST=index - - API_BASE_OPENAI=${API_BASE_OPENAI:-} - - API_TYPE_OPENAI=${API_TYPE_OPENAI:-} - - API_VERSION_OPENAI=${API_VERSION_OPENAI:-} - - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-} - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-} # Connector Configs - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-} diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template index a652d7e370c3..1f600b63b381 100644 --- a/deployment/docker_compose/env.prod.template +++ b/deployment/docker_compose/env.prod.template @@ -3,57 +3,57 @@ # This is only necessary when using the docker-compose.prod.yml compose file. -# Insert your OpenAI API key here If not provided here, UI will prompt on setup. -# This env variable takes precedence over UI settings. 
-GEN_AI_API_KEY= -# Choose between "openai-chat-completion" and "openai-completion" -INTERNAL_MODEL_VERSION=openai-chat-completion -# Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility -GEN_AI_MODEL_VERSION=gpt-4 - -# Neccessary environment variables for Azure OpenAI: -API_BASE_OPENAI= -API_TYPE_OPENAI= -API_VERSION_OPENAI= -AZURE_DEPLOYMENT_ID= - # Could be something like danswer.companyname.com WEB_DOMAIN=http://localhost:3000 -# Default values here are what Postgres uses by default, feel free to change. -POSTGRES_USER=postgres -POSTGRES_PASSWORD=password + +# Generative AI settings; uncomment and adjust as needed, the defaults will work +GEN_AI_MODEL_PROVIDER=openai +GEN_AI_MODEL_VERSION=gpt-4 +# Provide this as a global default/backup; it can also be set via the UI +#GEN_AI_API_KEY= +# Set this when using Azure OpenAI or other hosted services, for example: https://danswer.openai.azure.com/ +#GEN_AI_API_ENDPOINT= +# Set this to pin a specific API version, such as 2023-09-15-preview (example taken from Azure) +#GEN_AI_API_VERSION= + # If you want to setup a slack bot to answer questions automatically in Slack -# channels it is added to, you must specify the below. +# channels it is added to, you must specify the two below. # More information in the guide here: https://docs.danswer.dev/slack_bot_setup -DANSWER_BOT_SLACK_APP_TOKEN= -DANSWER_BOT_SLACK_BOT_TOKEN= +#DANSWER_BOT_SLACK_APP_TOKEN= +#DANSWER_BOT_SLACK_BOT_TOKEN= -# Used to generate values for security verification, use a random string -SECRET= - -# How long before user needs to reauthenticate, default to 1 day. (cookie expiration time) -SESSION_EXPIRE_TIME_SECONDS=86400 # The following are for configuring User Authentication, supported flows are: # disabled # google_oauth (login with google/gmail account) # oidc (only in Danswer enterprise edition) # saml (only in Danswer enterprise edition) -AUTH_TYPE= +AUTH_TYPE=google_oauth -# Set the two below to use with Google OAuth +# Set the values below to use with Google OAuth (SECRET is a random string used for security verification) GOOGLE_OAUTH_CLIENT_ID= GOOGLE_OAUTH_CLIENT_SECRET= +SECRET= # OpenID Connect (OIDC) -OPENID_CONFIG_URL= +#OPENID_CONFIG_URL= # SAML config directory for OneLogin compatible setups -SAML_CONF_DIR= +#SAML_CONF_DIR= -# used to specify a list of allowed user domains, only checked if user Auth is turned on + +# How long before user needs to reauthenticate, default to 1 day. (cookie expiration time) +SESSION_EXPIRE_TIME_SECONDS=86400 + + +# Use the below to specify a list of allowed user domains, only checked if user Auth is turned on # e.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will only allow users # with an @example.com or an @example.org email -VALID_EMAIL_DOMAINS= +#VALID_EMAIL_DOMAINS= + + +# Default values here are what Postgres uses by default, feel free to change. +POSTGRES_USER=postgres +POSTGRES_PASSWORD=password
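The renamed `CustomModelServer` class above is intended as a template: any model server can be plugged into Danswer by subclassing the `LLM` interface (now at `danswer.llm.interfaces`) and adapting the request/response handling to that server's API. Below is a minimal sketch of that pattern, assuming the same `invoke`/`stream` methods implemented by `DanswerGPT4All` above; the `MyModelServer` name, the example endpoint, and the request/response JSON fields are hypothetical and not part of this change.

```python
from collections.abc import Iterator

import requests
from langchain.schema.language_model import LanguageModelInput

from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string


class MyModelServer(LLM):
    """Hypothetical example, not part of this PR: adapt _generate() to
    whatever request/response schema your own model server uses."""

    def __init__(self, endpoint: str, timeout: int = 60) -> None:
        self._endpoint = endpoint  # e.g. "http://localhost:9000/generate" (hypothetical)
        self._timeout = timeout

    def _generate(self, prompt: str) -> str:
        response = requests.post(
            self._endpoint,
            json={"inputs": prompt},  # assumed request schema
            timeout=self._timeout,
        )
        response.raise_for_status()
        # "generated_text" is an assumed response field; match your server's output
        return response.json()["generated_text"]

    def invoke(self, prompt: LanguageModelInput) -> str:
        return self._generate(convert_lm_input_to_basic_string(prompt))

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        # No streaming support assumed; yield the full completion as one chunk
        yield self._generate(convert_lm_input_to_basic_string(prompt))
```

With the factory change in `llm/build.py`, pointing Danswer at a server like this only requires setting `GEN_AI_MODEL_PROVIDER=custom` and `GEN_AI_API_ENDPOINT=<model server URL>` in the environment (see the docker-compose and `env.prod.template` updates above), since `CustomModelServer` reads its endpoint from `GEN_AI_API_ENDPOINT` by default.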