Mirror of https://github.com/danswer-ai/danswer.git
Reworking the LLM layer (#666)
@@ -34,7 +34,7 @@ from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.document_index import get_default_document_index
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.build import get_default_llm
-from danswer.llm.llm import LLM
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import translate_danswer_msg_to_langchain
 from danswer.search.access_filters import build_access_filters_for_user
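Note: the one-line change above is the import-path move of the LLM base class from danswer.llm.llm to danswer.llm.interfaces. The interface itself is not touched in this diff, so the sketch below is only an assumption about its minimal shape (method names are illustrative, not taken from the commit):

    from abc import ABC, abstractmethod
    from collections.abc import Iterator

    class LLM(ABC):
        """Hypothetical minimal shape of the provider-agnostic LLM interface."""

        @abstractmethod
        def invoke(self, prompt: str) -> str:
            ...

        @abstractmethod
        def stream(self, prompt: str) -> Iterator[str]:
            ...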
@@ -77,21 +77,6 @@ class DocumentIndexType(str, Enum):
     SPLIT = "split"  # Typesense + Qdrant


-class DanswerGenAIModel(str, Enum):
-    """This represents the internal Danswer GenAI model which determines the class that is used
-    to generate responses to the user query. Different models/services require different internal
-    handling, this allows for modularity of implementation within Danswer"""
-
-    OPENAI = "openai-completion"
-    OPENAI_CHAT = "openai-chat-completion"
-    GPT4ALL = "gpt4all-completion"
-    GPT4ALL_CHAT = "gpt4all-chat-completion"
-    HUGGINGFACE = "huggingface-client-completion"
-    HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
-    REQUEST = "request-completion"
-    TRANSFORMERS = "transformers"
-
-
 class AuthType(str, Enum):
     DISABLED = "disabled"
     BASIC = "basic"
@@ -100,17 +85,6 @@ class AuthType(str, Enum):
     SAML = "saml"


-class ModelHostType(str, Enum):
-    """For GenAI models interfaced via requests, different services have different
-    expectations for what fields are included in the request"""
-
-    # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
-    HUGGINGFACE = "huggingface"  # HuggingFace test-generation Inference API
-    # https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
-    COLAB_DEMO = "colab-demo"
-    # TODO support for Azure, AWS, GCP GenAI model hosting
-
-
 class QAFeedbackType(str, Enum):
     LIKE = "like"  # User likes the answer, used for metrics
     DISLIKE = "dislike"  # User dislikes the answer, used for metrics
@@ -1,9 +1,5 @@
 import os

-from danswer.configs.constants import DanswerGenAIModel
-from danswer.configs.constants import ModelHostType
-
-
 #####
 # Embedding/Reranking Model Configs
 #####
@@ -55,62 +51,38 @@ SEARCH_DISTANCE_CUTOFF = 0
 # Intent model max context size
 QUERY_MAX_CONTEXT_SIZE = 256

+# Danswer custom Deep Learning Models
+INTENT_MODEL_VERSION = "danswer/intent-model"


 #####
 # Generative AI Model Configs
 #####
-# Other models should work as well, check the library/API compatibility.
-# But these are the models that have been verified to work with the existing prompts.
-# Using a different model may require some prompt tuning. See qa_prompts.py
-VERIFIED_MODELS = {
-    DanswerGenAIModel.OPENAI: ["text-davinci-003"],
-    DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
-    DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
-    DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
-    # The "chat" model below is actually "instruction finetuned" and does not support conversational
-    DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
-    DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
-    # Created by Deepset.ai
-    # https://huggingface.co/deepset/deberta-v3-large-squad2
-    # Model provided with no modifications
-    DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"],
-}
-
-# Sets the internal Danswer model class to use
-INTERNAL_MODEL_VERSION = os.environ.get(
-    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
-)
+# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
+# be sure to use one that is LiteLLM compatible:
+# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
+# The provider is the prefix before / in the model argument

+# Additionally Danswer supports GPT4All and custom request library based models
+# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
+# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
+GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
+# If using Azure, it's the engine name, for example: Danswer
+GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"

 # If the Generative AI model requires an API key for access, otherwise can leave blank
-GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
-# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
-GEN_AI_MODEL_VERSION = os.environ.get(
-    "GEN_AI_MODEL_VERSION",
-    VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
+GEN_AI_API_KEY = (
+    os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
 )

-# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
-# set the two below to specify
-# - Where to hit the endpoint
-# - How should the request be formed
-GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
-GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
+# API Base, such as (for Azure): https://danswer.openai.azure.com/
+GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
+# API Version, such as (for Azure): 2023-09-15-preview
+GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None

 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
 # This next restriction is only used for chat ATM, used to expire old messages as needed
 GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
 GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
-
-# Danswer custom Deep Learning Models
-INTENT_MODEL_VERSION = "danswer/intent-model"
-
-#####
-# OpenAI Azure
-#####
-API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "")
-API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower()
-API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "")
-# Deployment ID used interchangeably with "engine" parameter
-AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "")
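Note: GEN_AI_MODEL_PROVIDER and GEN_AI_MODEL_VERSION replace INTERNAL_MODEL_VERSION and the VERIFIED_MODELS table above. The exact wiring into LiteLLM is not part of this hunk, so the following is a hedged sketch of how such a provider/version pair is typically combined into a LiteLLM model string:

    import os

    provider = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
    version = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
    # LiteLLM addresses non-OpenAI backends as "<provider>/<model>", e.g. "azure/Danswer"
    model_name = version if provider == "openai" else f"{provider}/{version}"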
@@ -9,8 +9,6 @@ from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.db.feedback import create_query_event
 from danswer.db.feedback import update_query_event_retrieved_documents
 from danswer.db.models import User
-from danswer.direct_qa.exceptions import OpenAIKeyMissing
-from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -132,7 +130,7 @@ def answer_qa_query(
         qa_model = get_default_qa_model(
             timeout=answer_generation_timeout, real_time_flow=real_time_flow
         )
-    except (UnknownModelError, OpenAIKeyMissing) as e:
+    except Exception as e:
         return QAResponse(
             answer=None,
             quotes=None,
@@ -1,13 +0,0 @@
-class OpenAIKeyMissing(Exception):
-    default_msg = (
-        "Unable to find existing OpenAI Key. "
-        'A new key can be added from "Keys" section of the Admin Panel'
-    )
-
-    def __init__(self, msg: str = default_msg) -> None:
-        super().__init__(msg)
-
-
-class UnknownModelError(Exception):
-    def __init__(self, model_name: str) -> None:
-        super().__init__(f"Unknown Internal QA model name: {model_name}")
@@ -1,209 +0,0 @@
-from collections.abc import Callable
-from typing import Any
-
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.direct_qa.interfaces import AnswerQuestionReturn
-from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
-from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.models import LLMMetricsContainer
-from danswer.direct_qa.qa_prompts import ChatPromptProcessor
-from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
-from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
-from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
-from danswer.direct_qa.qa_utils import process_answer
-from danswer.direct_qa.qa_utils import process_model_tokens
-from danswer.indexing.models import InferenceChunk
-from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-
-logger = setup_logger()
-
-
-class DummyGPT4All:
-    """In the case of import failure due to M1 Mac incompatibility,
-    so this module does not raise exceptions during server startup,
-    as long as this module isn't actually used"""
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        raise RuntimeError("GPT4All library not installed.")
-
-
-try:
-    from gpt4all import GPT4All  # type:ignore
-except ImportError:
-    logger.warning(
-        "GPT4All library not installed. "
-        "If you wish to run GPT4ALL (in memory) to power Danswer's "
-        "Generative AI features, please install gpt4all==1.0.5. "
-        "As of Aug 2023, this library is not compatible with M1 Mac."
-    )
-    GPT4All = DummyGPT4All
-
-
-GPT4ALL_MODEL: GPT4All | None = None
-
-
-def get_gpt_4_all_model(
-    model_version: str = GEN_AI_MODEL_VERSION,
-) -> GPT4All:
-    global GPT4ALL_MODEL
-    if GPT4ALL_MODEL is None:
-        GPT4ALL_MODEL = GPT4All(model_version)
-    return GPT4ALL_MODEL
-
-
-def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]:
-    """
-    Utility to add in some common default values so they don't have to be set every time.
-    """
-    return {
-        "temp": 0,
-        **kwargs,
-    }
-
-
-class GPT4AllCompletionQA(QAModel):
-    def __init__(
-        self,
-        prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        include_metadata: bool = False,  # gpt4all models can't handle this atm
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.model_version = model_version
-        self.max_output_tokens = max_output_tokens
-        self.include_metadata = include_metadata
-
-    @property
-    def requires_api_key(self) -> bool:
-        return False
-
-    def warm_up_model(self) -> None:
-        get_gpt_4_all_model(self.model_version)
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
-    ) -> AnswerQuestionReturn:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        gen_ai_model = get_gpt_4_all_model(self.model_version)
-
-        model_output = gen_ai_model.generate(
-            **_build_gpt4all_settings(
-                prompt=filled_prompt, max_tokens=self.max_output_tokens
-            ),
-        )
-
-        logger.debug(model_output)
-
-        return process_answer(model_output, context_docs)
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        gen_ai_model = get_gpt_4_all_model(self.model_version)
-
-        model_stream = gen_ai_model.generate(
-            **_build_gpt4all_settings(
-                prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True
-            ),
-        )
-
-        yield from process_model_tokens(
-            tokens=model_stream,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
-
-
-class GPT4AllChatCompletionQA(QAModel):
-    def __init__(
-        self,
-        prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        include_metadata: bool = False,  # gpt4all models can't handle this atm
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.model_version = model_version
-        self.max_output_tokens = max_output_tokens
-        self.include_metadata = include_metadata
-
-    @property
-    def requires_api_key(self) -> bool:
-        return False
-
-    def warm_up_model(self) -> None:
-        get_gpt_4_all_model(self.model_version)
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
-    ) -> AnswerQuestionReturn:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        gen_ai_model = get_gpt_4_all_model(self.model_version)
-
-        with gen_ai_model.chat_session():
-            context_msgs = filled_prompt[:-1]
-            user_query = filled_prompt[-1].get("content")
-            for message in context_msgs:
-                gen_ai_model.current_chat_session.append(message)
-
-            model_output = gen_ai_model.generate(
-                **_build_gpt4all_settings(
-                    prompt=user_query, max_tokens=self.max_output_tokens
-                ),
-            )
-
-            logger.debug(model_output)
-
-            return process_answer(model_output, context_docs)
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        gen_ai_model = get_gpt_4_all_model(self.model_version)
-
-        with gen_ai_model.chat_session():
-            context_msgs = filled_prompt[:-1]
-            user_query = filled_prompt[-1].get("content")
-            for message in context_msgs:
-                gen_ai_model.current_chat_session.append(message)
-
-            model_stream = gen_ai_model.generate(
-                **_build_gpt4all_settings(
-                    prompt=user_query, max_tokens=self.max_output_tokens
-                ),
-            )
-
-            yield from process_model_tokens(
-                tokens=model_stream,
-                context_docs=context_docs,
-                is_json_prompt=self.prompt_processor.specifies_json_output,
-            )
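Note: with gpt_4_all.py deleted, GPT4All is no longer a standalone QAModel implementation. Going by the new model_configs comments, a locally running GPT4All model would instead be selected through the provider setting; the snippet below is an illustrative assumption that reuses the model filename from the removed VERIFIED_MODELS table:

    import os

    os.environ["GEN_AI_MODEL_PROVIDER"] = "gpt4all"
    os.environ["GEN_AI_MODEL_VERSION"] = "ggml-model-gpt4all-falcon-q4_0.bin"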
@@ -1,195 +0,0 @@
-from collections.abc import Callable
-from typing import Any
-
-from huggingface_hub import InferenceClient  # type:ignore
-from huggingface_hub.utils import HfHubHTTPError  # type:ignore
-
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.direct_qa.interfaces import AnswerQuestionReturn
-from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
-from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.models import LLMMetricsContainer
-from danswer.direct_qa.qa_prompts import ChatPromptProcessor
-from danswer.direct_qa.qa_prompts import FreeformProcessor
-from danswer.direct_qa.qa_prompts import JsonChatProcessor
-from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
-from danswer.direct_qa.qa_utils import process_answer
-from danswer.direct_qa.qa_utils import process_model_tokens
-from danswer.direct_qa.qa_utils import simulate_streaming_response
-from danswer.indexing.models import InferenceChunk
-from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-
-logger = setup_logger()
-
-
-def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
-    """
-    Utility to add in some common default values so they don't have to be set every time.
-    """
-    return {
-        "do_sample": False,
-        "seed": 69,  # For reproducibility
-        **kwargs,
-    }
-
-
-class HuggingFaceCompletionQA(QAModel):
-    def __init__(
-        self,
-        prompt_processor: NonChatPromptProcessor = FreeformProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        include_metadata: bool = False,
-        api_key: str | None = None,
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.max_output_tokens = max_output_tokens
-        self.include_metadata = include_metadata
-        self.client = InferenceClient(model=model_version, token=api_key)
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
-    ) -> AnswerQuestionReturn:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        model_output = self.client.text_generation(
-            filled_prompt,
-            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
-        )
-        logger.debug(model_output)
-
-        return process_answer(model_output, context_docs)
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        model_stream = self.client.text_generation(
-            filled_prompt,
-            **_build_hf_inference_settings(
-                max_new_tokens=self.max_output_tokens, stream=True
-            ),
-        )
-
-        yield from process_model_tokens(
-            tokens=model_stream,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
-
-
-class HuggingFaceChatCompletionQA(QAModel):
-    """Chat in this class refers to the HuggingFace Conversational API.
-    Not to be confused with Chat/Instruction finetuned models.
-    Llama2-chat... means it is an Instruction finetuned model, not necessarily that
-    it supports Conversational API"""
-
-    def __init__(
-        self,
-        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        include_metadata: bool = False,
-        api_key: str | None = None,
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.max_output_tokens = max_output_tokens
-        self.include_metadata = include_metadata
-        self.client = InferenceClient(model=model_version, token=api_key)
-
-    @staticmethod
-    def _convert_chat_to_hf_conversational_format(
-        dialog: list[dict[str, str]]
-    ) -> tuple[str, list[str], list[str]]:
-        if dialog[-1]["role"] != "user":
-            raise Exception(
-                "Last message in conversational dialog must be User message"
-            )
-        user_message = dialog[-1]["content"]
-        dialog = dialog[0:-1]
-        generated_responses = []
-        past_user_inputs = []
-        for message in dialog:
-            # HuggingFace inference client doesn't support system messages today
-            # so lumping them in with user messages
-            if message["role"] in ["user", "system"]:
-                past_user_inputs += [message["content"]]
-            else:
-                generated_responses += [message["content"]]
-        return user_message, generated_responses, past_user_inputs
-
-    def _get_hf_model_output(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> str:
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-
-        (
-            query,
-            past_responses,
-            past_inputs,
-        ) = self._convert_chat_to_hf_conversational_format(filled_prompt)
-
-        logger.debug(f"Last Input: {query}")
-        logger.debug(f"Past Inputs: {past_inputs}")
-        logger.debug(f"Past Responses: {past_responses}")
-        try:
-            model_output = self.client.conversational(
-                query,
-                generated_responses=past_responses,
-                past_user_inputs=past_inputs,
-                parameters={"max_length": self.max_output_tokens},
-            )
-        except HfHubHTTPError as model_error:
-            if model_error.response.status_code == 422:
-                raise ValueError(
-                    "Selected HuggingFace Model does not support HuggingFace Conversational API,"
-                    "try using the huggingface-inference-completion in Danswer instead"
-                )
-            raise
-        logger.debug(model_output)
-
-        return model_output
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
-    ) -> AnswerQuestionReturn:
-        model_output = self._get_hf_model_output(query, context_docs)
-
-        answer, quotes_dict = process_answer(model_output, context_docs)
-
-        return answer, quotes_dict
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        """As of Aug 2023, HF conversational (chat) endpoints do not support streaming
-        So here it is faked by streaming characters within Danswer from the model output
-        """
-        model_output = self._get_hf_model_output(query, context_docs)
-
-        model_stream = simulate_streaming_response(model_output)
-
-        yield from process_model_tokens(
-            tokens=model_stream,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
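Note: huggingface.py is removed as well, so the HuggingFace Inference and Conversational clients lose their dedicated QAModel wrappers. Under a LiteLLM-compatible layer a hosted model would presumably be addressed by model string instead; this call is shown purely for illustration and is not taken from the commit:

    import litellm

    response = litellm.completion(
        model="huggingface/meta-llama/Llama-2-70b-chat-hf",
        messages=[{"role": "user", "content": "What is Danswer?"}],
    )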
@@ -52,6 +52,7 @@ class QAModel:
     def requires_api_key(self) -> bool:
         """Is this model protected by security features
         Does it need an api key to access the model for inference"""
+        # TODO, this should be false for custom request model and gpt4all
         return True

     def warm_up_model(self) -> None:
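Note: the TODO added above says requires_api_key should eventually return False for the custom-request and gpt4all backends. A subclass override in the style the deleted GPT4All classes used would look like this (sketch only, assuming the same property-based pattern):

    class LocalQAModel(QAModel):
        @property
        def requires_api_key(self) -> bool:
            # Local models need no API key
            return False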
@@ -1,32 +1,11 @@
-from typing import Any
-
-import pkg_resources
 from openai.error import AuthenticationError

 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.constants import DanswerGenAIModel
-from danswer.configs.constants import ModelHostType
-from danswer.configs.model_configs import GEN_AI_API_KEY
-from danswer.configs.model_configs import GEN_AI_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
-from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
-from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
-from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
-from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.local_transformers import TransformerQA
-from danswer.direct_qa.open_ai import OpenAICompletionQA
 from danswer.direct_qa.qa_block import QABlock
 from danswer.direct_qa.qa_block import QAHandler
-from danswer.direct_qa.qa_block import SimpleChatQAHandler
 from danswer.direct_qa.qa_block import SingleMessageQAHandler
 from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
-from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
-from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.direct_qa.request_model import RequestCompletionQA
-from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.build import get_default_llm
 from danswer.utils.logger import setup_logger
@@ -52,93 +31,23 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     return False


-def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
-    if model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return (
-            SingleMessageQAHandler()
-            if real_time_flow
-            else SingleMessageScratchpadHandler()
-        )
-
-    return SimpleChatQAHandler()
+# TODO introduce the prompt choice parameter
+def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
+    return (
+        SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler()
+    )
+    # return SimpleChatQAHandler()


 def get_default_qa_model(
-    internal_model: str = INTERNAL_MODEL_VERSION,
-    endpoint: str | None = GEN_AI_ENDPOINT,
-    model_host_type: str | None = GEN_AI_HOST_TYPE,
-    api_key: str | None = GEN_AI_API_KEY,
+    api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
     real_time_flow: bool = True,
-    **kwargs: Any,
 ) -> QAModel:
-    if not api_key:
-        try:
-            api_key = get_gen_ai_api_key()
-        except ConfigNotFoundError:
-            pass
-
-    try:
-        # un-used arguments will be ignored by the underlying `LLM` class
-        # if any args are missing, a `TypeError` will be thrown
-        llm = get_default_llm(timeout=timeout)
-        qa_handler = get_default_qa_handler(
-            model=internal_model, real_time_flow=real_time_flow
-        )
-
-        return QABlock(
-            llm=llm,
-            qa_handler=qa_handler,
-        )
-    except Exception:
-        logger.exception(
-            "Unable to build a QABlock with the new approach, going back to the "
-            "legacy approach"
-        )
-
-    if internal_model in [
-        DanswerGenAIModel.GPT4ALL.value,
-        DanswerGenAIModel.GPT4ALL_CHAT.value,
-    ]:
-        # gpt4all is not compatible M1 Mac hardware as of Aug 2023
-        pkg_resources.get_distribution("gpt4all")
-
-    if internal_model == DanswerGenAIModel.OPENAI.value:
-        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
-        return GPT4AllCompletionQA(**kwargs)
-    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
-        return GPT4AllChatCompletionQA(**kwargs)
-    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
-        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
-        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
-        return TransformerQA()
-    elif internal_model == DanswerGenAIModel.REQUEST.value:
-        if endpoint is None or model_host_type is None:
-            raise ValueError(
-                "Request based GenAI model requires an endpoint and host type"
-            )
-        if (
-            model_host_type == ModelHostType.HUGGINGFACE.value
-            or model_host_type == ModelHostType.COLAB_DEMO.value
-        ):
-            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
-            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
-            # This is the sum of input and output tokens, so cannot take in full Danswer context
-            return RequestCompletionQA(
-                endpoint=endpoint,
-                model_host_type=model_host_type,
-                api_key=api_key,
-                prompt_processor=WeakModelFreeformProcessor(),
-                timeout=timeout,
-            )
-        return RequestCompletionQA(
-            endpoint=endpoint,
-            model_host_type=model_host_type,
-            api_key=api_key,
-            timeout=timeout,
-        )
-    else:
-        raise UnknownModelError(internal_model)
+    llm = get_default_llm(api_key=api_key, timeout=timeout)
+    qa_handler = get_default_qa_handler(real_time_flow=real_time_flow)
+
+    return QABlock(
+        llm=llm,
+        qa_handler=qa_handler,
+    )
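Note: get_default_qa_model is now a thin factory: build the default LLM, pick a QAHandler, and wrap both in a QABlock. A condensed usage sketch (keyword names come from this diff; the surrounding call site is assumed):

    qa_model = get_default_qa_model(timeout=10, real_time_flow=True)

    # which is roughly equivalent to spelling it out:
    llm = get_default_llm(api_key=None, timeout=10)
    qa_model = QABlock(llm=llm, qa_handler=SingleMessageQAHandler())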
@@ -1,156 +0,0 @@
-import re
-from collections.abc import Callable
-
-from transformers import pipeline  # type:ignore
-from transformers import QuestionAnsweringPipeline  # type:ignore
-
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.direct_qa.interfaces import AnswerQuestionReturn
-from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
-from danswer.direct_qa.interfaces import DanswerAnswer
-from danswer.direct_qa.interfaces import DanswerAnswerPiece
-from danswer.direct_qa.interfaces import DanswerQuote
-from danswer.direct_qa.interfaces import DanswerQuotes
-from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.models import LLMMetricsContainer
-from danswer.indexing.models import InferenceChunk
-from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-
-logger = setup_logger()
-
-TRANSFORMER_DEFAULT_MAX_CONTEXT = 512
-
-_TRANSFORMER_MODEL: QuestionAnsweringPipeline | None = None
-
-
-def get_default_transformer_model(
-    model_version: str = GEN_AI_MODEL_VERSION,
-    max_context: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
-) -> QuestionAnsweringPipeline:
-    global _TRANSFORMER_MODEL
-    if _TRANSFORMER_MODEL is None:
-        _TRANSFORMER_MODEL = pipeline(
-            "question-answering", model=model_version, max_seq_len=max_context
-        )
-
-    return _TRANSFORMER_MODEL
-
-
-def find_extended_answer(answer: str, context: str) -> str:
-    """Try to extend the answer by matching across the context text and extending before
-    and after the quote to some termination character"""
-    result = re.search(
-        r"(^|\n\r?|\.)(?P<content>[^\n]{{0,250}}{}[^\n]{{0,250}})(\.|$|\n\r?)".format(
-            re.escape(answer)
-        ),
-        context,
-        flags=re.MULTILINE | re.DOTALL,
-    )
-    if result:
-        return result.group("content")
-
-    return answer
-
-
-class TransformerQA(QAModel):
-    @staticmethod
-    def _answer_one_chunk(
-        query: str,
-        chunk: InferenceChunk,
-        max_context_len: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
-        max_cutoff: float = 0.9,
-        min_cutoff: float = 0.5,
-    ) -> tuple[str | None, DanswerQuote | None]:
-        """Because this type of QA model only takes 1 chunk of context with a fairly small token limit
-        We have to iterate the checks and check if the answer is found in any of the chunks.
-        This type of approach does not allow for interpolating answers across chunks
-        """
-        model = get_default_transformer_model()
-        model_out = model(question=query, context=chunk.content, max_answer_len=128)
-
-        answer = model_out.get("answer")
-        confidence = model_out.get("score")
-
-        if answer is None:
-            return None, None
-
-        logger.info(f"Transformer Answer: {answer}")
-        logger.debug(f"Transformer Confidence: {confidence}")
-
-        # Model tends to be overconfident on short chunks
-        # so min score required increases as chunk size decreases
-        # If it's at least 0.9, then it's good enough regardless
-        # Default minimum of 0.5 required
-        score_cutoff = max(
-            min(max_cutoff, 1 - len(chunk.content) / max_context_len), min_cutoff
-        )
-        if confidence < score_cutoff:
-            return None, None
-
-        extended_answer = find_extended_answer(answer, chunk.content)
-
-        danswer_quote = DanswerQuote(
-            quote=answer,
-            document_id=chunk.document_id,
-            link=chunk.source_links[0] if chunk.source_links else None,
-            source_type=chunk.source_type,
-            semantic_identifier=chunk.semantic_identifier,
-            blurb=chunk.blurb,
-        )
-
-        return extended_answer, danswer_quote
-
-    def warm_up_model(self) -> None:
-        get_default_transformer_model()
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
-    ) -> AnswerQuestionReturn:
-        danswer_quotes: list[DanswerQuote] = []
-        d_answers: list[str] = []
-        for chunk in context_docs:
-            answer, quote = self._answer_one_chunk(query, chunk)
-            if answer is not None and quote is not None:
-                d_answers.append(answer)
-                danswer_quotes.append(quote)
-
-        answers_list = [
-            f"Answer {ind}: {answer.strip()}"
-            for ind, answer in enumerate(d_answers, start=1)
-        ]
-        combined_answer = "\n".join(answers_list)
-        return DanswerAnswer(answer=combined_answer), DanswerQuotes(
-            quotes=danswer_quotes
-        )
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        quotes: list[DanswerQuote] = []
-        answers: list[str] = []
-        for chunk in context_docs:
-            answer, quote = self._answer_one_chunk(query, chunk)
-            if answer is not None and quote is not None:
-                answers.append(answer)
-                quotes.append(quote)
-
-        # Delay the output of the answers so there isn't long gap between first answer and quotes
-        answer_count = 1
-        for answer in answers:
-            if answer_count == 1:
-                yield DanswerAnswerPiece(answer_piece="Source 1: ")
-            else:
-                yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
-            answer_count += 1
-            for char in answer.strip():
-                yield DanswerAnswerPiece(answer_piece=char)
-
-        # signal end of answer
-        yield DanswerAnswerPiece(answer_piece=None)
-
-        yield DanswerQuotes(quotes=quotes)
@@ -1,209 +0,0 @@
-from abc import ABC
-from collections.abc import Callable
-from collections.abc import Generator
-from copy import copy
-from functools import wraps
-from typing import Any
-from typing import cast
-from typing import TypeVar
-
-import openai
-import tiktoken
-from openai.error import AuthenticationError
-from openai.error import Timeout
-
-from danswer.configs.app_configs import INCLUDE_METADATA
-from danswer.configs.model_configs import API_BASE_OPENAI
-from danswer.configs.model_configs import API_TYPE_OPENAI
-from danswer.configs.model_configs import API_VERSION_OPENAI
-from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.direct_qa.exceptions import OpenAIKeyMissing
-from danswer.direct_qa.interfaces import AnswerQuestionReturn
-from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
-from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.models import LLMMetricsContainer
-from danswer.direct_qa.qa_prompts import JsonProcessor
-from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
-from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.direct_qa.qa_utils import process_answer
-from danswer.direct_qa.qa_utils import process_model_tokens
-from danswer.dynamic_configs.interface import ConfigNotFoundError
-from danswer.indexing.models import InferenceChunk
-from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-
-
-logger = setup_logger()
-
-F = TypeVar("F", bound=Callable)
-
-
-if API_BASE_OPENAI:
-    openai.api_base = API_BASE_OPENAI
-if API_TYPE_OPENAI in ["azure"]:  # TODO: Azure AD support ["azure_ad", "azuread"]
-    openai.api_type = API_TYPE_OPENAI
-    openai.api_version = API_VERSION_OPENAI
-
-
-def _ensure_openai_api_key(api_key: str | None) -> str:
-    final_api_key = api_key or get_gen_ai_api_key()
-    if final_api_key is None:
-        raise OpenAIKeyMissing()
-
-    return final_api_key
-
-
-def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
-    """
-    Utility to add in some common default values so they don't have to be set every time.
-    """
-    return {
-        "temperature": 0,
-        "top_p": 1,
-        "frequency_penalty": 0,
-        "presence_penalty": 0,
-        **({"deployment_id": AZURE_DEPLOYMENT_ID} if AZURE_DEPLOYMENT_ID else {}),
-        **kwargs,
-    }
-
-
-def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
-    @wraps(openai_call)
-    def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
-        try:
-            # if streamed, the call returns a generator
-            if kwargs.get("stream"):
-
-                def _generator() -> Generator[Any, None, None]:
-                    yield from openai_call(*args, **kwargs)
-
-                return _generator()
-            return openai_call(*args, **kwargs)
-        except AuthenticationError:
-            logger.exception("Failed to authenticate with OpenAI API")
-            raise
-        except Timeout:
-            logger.exception("OpenAI API timed out for query: %s", query)
-            raise
-        except Exception:
-            logger.exception("Unexpected error with OpenAI API for query: %s", query)
-            raise
-
-    return cast(F, wrapped_call)
-
-
-def _tiktoken_trim_chunks(
-    chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512
-) -> list[InferenceChunk]:
-    """Edit chunks that have too high token count. Generally due to parsing issues or
-    characters from another language that are 1 char = 1 token
-    Trimming by tokens leads to information loss but currently no better way of handling
-    """
-    encoder = tiktoken.encoding_for_model(model_version)
-    new_chunks = copy(chunks)
-    for ind, chunk in enumerate(new_chunks):
-        tokens = encoder.encode(chunk.content)
-        if len(tokens) > max_chunk_toks:
-            new_chunk = copy(chunk)
-            new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
-            new_chunks[ind] = new_chunk
-    return new_chunks
-
-
-# used to check if the QAModel is an OpenAI model
-class OpenAIQAModel(QAModel, ABC):
-    pass
-
-
-class OpenAICompletionQA(OpenAIQAModel):
-    def __init__(
-        self,
-        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        api_key: str | None = None,
-        timeout: int | None = None,
-        include_metadata: bool = INCLUDE_METADATA,
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.model_version = model_version
-        self.max_output_tokens = max_output_tokens
-        self.timeout = timeout
-        self.include_metadata = include_metadata
-        try:
-            self.api_key = api_key or get_gen_ai_api_key()
-        except ConfigNotFoundError:
-            raise OpenAIKeyMissing()
-
-    @staticmethod
-    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
-        for event in response:
-            yield event["choices"][0]["text"]
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
-    ) -> AnswerQuestionReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        openai_call = _handle_openai_exceptions_wrapper(
-            openai_call=openai.Completion.create,
-            query=query,
-        )
-        response = openai_call(
-            **_build_openai_settings(
-                api_key=_ensure_openai_api_key(self.api_key),
-                prompt=filled_prompt,
-                model=self.model_version,
-                max_tokens=self.max_output_tokens,
-                request_timeout=self.timeout,
-            ),
-        )
-        model_output = cast(str, response["choices"][0]["text"]).strip()
-        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
-        logger.debug(model_output)
-
-        return process_answer(model_output, context_docs)
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        filled_prompt = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(filled_prompt)
-
-        openai_call = _handle_openai_exceptions_wrapper(
-            openai_call=openai.Completion.create,
-            query=query,
-        )
-        response = openai_call(
-            **_build_openai_settings(
-                api_key=_ensure_openai_api_key(self.api_key),
-                prompt=filled_prompt,
-                model=self.model_version,
-                max_tokens=self.max_output_tokens,
-                request_timeout=self.timeout,
-                stream=True,
-            ),
-        )
-
-        tokens = self._generate_tokens_from_response(response)
-
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
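Note: deleting open_ai.py also drops the direct openai SDK globals (API_BASE_OPENAI, API_TYPE_OPENAI, AZURE_DEPLOYMENT_ID, ...). Based on the new model_configs comments, an Azure deployment would be configured roughly as below; the endpoint, version, and engine name are the examples given in those comments, and the provider string is an assumption:

    import os

    os.environ["GEN_AI_MODEL_PROVIDER"] = "azure"
    os.environ["GEN_AI_MODEL_VERSION"] = "Danswer"  # Azure engine / deployment name
    os.environ["GEN_AI_API_ENDPOINT"] = "https://danswer.openai.azure.com/"
    os.environ["GEN_AI_API_VERSION"] = "2023-09-15-preview"
    os.environ["GEN_AI_API_KEY"] = "<api key>"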
@@ -28,7 +28,7 @@ from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import process_answer
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.indexing.models import InferenceChunk
-from danswer.llm.llm import LLM
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import get_default_llm_tokenizer
@@ -1,272 +0,0 @@
|
|||||||
import abc
|
|
||||||
import json
|
|
||||||
from collections.abc import Callable
|
|
||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from requests.exceptions import Timeout
|
|
||||||
from requests.models import Response
|
|
||||||
|
|
||||||
from danswer.configs.constants import ModelHostType
|
|
||||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
|
||||||
from danswer.configs.model_configs import GEN_AI_ENDPOINT
|
|
||||||
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
|
|
||||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
|
||||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
|
||||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
|
||||||
from danswer.direct_qa.interfaces import QAModel
|
|
||||||
from danswer.direct_qa.models import LLMMetricsContainer
|
|
||||||
from danswer.direct_qa.qa_prompts import JsonProcessor
|
|
||||||
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
|
|
||||||
from danswer.direct_qa.qa_utils import process_answer
|
|
||||||
from danswer.direct_qa.qa_utils import process_model_tokens
|
|
||||||
from danswer.direct_qa.qa_utils import simulate_streaming_response
|
|
||||||
from danswer.indexing.models import InferenceChunk
|
|
||||||
from danswer.utils.logger import setup_logger
|
|
||||||
from danswer.utils.timing import log_function_time
|
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class HostSpecificRequestModel(abc.ABC):
|
|
||||||
"""Provides a more minimal implementation requirement for extending to new models
|
|
||||||
hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_api_key(self) -> bool:
|
|
||||||
"""Is this model protected by security features
|
|
||||||
Does it need an api key to access the model for inference"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abc.abstractmethod
|
|
||||||
def send_model_request(
|
|
||||||
filled_prompt: str,
|
|
||||||
endpoint: str,
|
|
||||||
api_key: str | None,
|
|
||||||
max_output_tokens: int,
|
|
||||||
stream: bool,
|
|
||||||
timeout: int | None,
|
|
||||||
) -> Response:
|
|
||||||
"""Given a filled out prompt, how to send it to the model API with the
        correct request format with the correct parameters"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        """Extract the full model output text from a response.
        This is for nonstreaming endpoints"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        """Generate tokens from a streaming response
        This is for streaming endpoints"""
        raise NotImplementedError


class HuggingFaceRequestModel(HostSpecificRequestModel):
    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,  # Not supported by Inference Endpoints (as of Aug 2023)
        timeout: int | None,
    ) -> Response:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "inputs": filled_prompt,
            "parameters": {
                # HuggingFace requires this to be strictly positive from 0.0-100.0 noninclusive
                "temperature": 0.01,
                # Skip the long tail
                "top_p": 0.9,
                "max_new_tokens": max_output_tokens,
            },
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except Timeout as error:
            raise Timeout(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _hf_extract_model_output(
        response: Response,
    ) -> str:
        if response.status_code != 200:
            response.raise_for_status()

        return json.loads(response.content)[0].get("generated_text", "")

    @staticmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        return HuggingFaceRequestModel._hf_extract_model_output(response)

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        """HF endpoints do not do streaming currently so this function will
        simulate streaming for the meantime but will need to be replaced in
        the future once streaming is enabled."""
        model_out = HuggingFaceRequestModel._hf_extract_model_output(response)
        yield from simulate_streaming_response(model_out)


class ColabDemoRequestModel(HostSpecificRequestModel):
    """Guide found at:
    https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
    """

    @property
    def requires_api_key(self) -> bool:
        return False

    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,  # ngrok basic setup doesn't require this
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        headers = {
            "Content-Type": "application/json",
        }

        data = {
            "inputs": filled_prompt,
            "parameters": {
                "temperature": 0.0,
                "max_tokens": max_output_tokens,
            },
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except Timeout as error:
            raise Timeout(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _colab_demo_extract_model_output(
        response: Response,
    ) -> str:
        if response.status_code != 200:
            response.raise_for_status()

        return json.loads(response.content).get("generated_text", "")

    @staticmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        return ColabDemoRequestModel._colab_demo_extract_model_output(response)

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
        yield from simulate_streaming_response(model_out)


def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
    if model_host_type == ModelHostType.HUGGINGFACE.value:
        return HuggingFaceRequestModel()
    if model_host_type == ModelHostType.COLAB_DEMO.value:
        return ColabDemoRequestModel()
    else:
        # TODO support Azure, GCP, AWS
        raise ValueError("Invalid model hosting service selected")


class RequestCompletionQA(QAModel):
    def __init__(
        self,
        endpoint: str = GEN_AI_ENDPOINT,
        model_host_type: str = GEN_AI_HOST_TYPE,
        api_key: str | None = GEN_AI_API_KEY,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        timeout: int | None = None,
    ) -> None:
        self.endpoint = endpoint
        self.api_key = api_key
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.model_class = get_host_specific_model_class(model_host_type)
        self.timeout = timeout

    @property
    def requires_api_key(self) -> bool:
        return self.model_class.requires_api_key

    def _get_request_response(
        self, query: str, context_docs: list[InferenceChunk], stream: bool
    ) -> Response:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, include_metadata=False
        )
        logger.debug(filled_prompt)

        return self.model_class.send_model_request(
            filled_prompt,
            self.endpoint,
            self.api_key,
            self.max_output_tokens,
            stream,
            self.timeout,
        )

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
    ) -> AnswerQuestionReturn:
        model_api_response = self._get_request_response(
            query, context_docs, stream=False
        )

        model_output = self.model_class.extract_model_output_from_response(
            model_api_response
        )
        logger.debug(model_output)

        return process_answer(model_output, context_docs)

    def answer_question_stream(
        self,
        query: str,
        context_docs: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        model_api_response = self._get_request_response(
            query, context_docs, stream=False
        )

        token_generator = self.model_class.generate_model_tokens_from_response(
            model_api_response
        )

        yield from process_model_tokens(
            tokens=token_generator,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
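For reference, the request contract these removed classes wrapped is a plain HTTP POST. A minimal standalone sketch of the HuggingFace-style exchange shown above (the endpoint URL, key, and token limit are placeholder values, not part of this commit):

import requests

def query_hf_text_generation(prompt: str, endpoint: str, api_key: str, max_new_tokens: int = 256) -> str:
    # Mirrors HuggingFaceRequestModel.send_model_request plus _hf_extract_model_output above
    response = requests.post(
        endpoint,
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        json={
            "inputs": prompt,
            "parameters": {"temperature": 0.01, "top_p": 0.9, "max_new_tokens": max_new_tokens},
        },
        timeout=30,
    )
    response.raise_for_status()
    # The HF Inference API returns a list of generations; take the first one
    return response.json()[0].get("generated_text", "")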
backend/danswer/llm/__init__.py (Normal file, 0 lines)
@@ -1,46 +0,0 @@
from typing import Any

from langchain.chat_models.azure_openai import AzureChatOpenAI

from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


class AzureGPT(LangChainChatLLM):
    def __init__(
        self,
        api_key: str,
        max_output_tokens: int,
        timeout: int,
        model_version: str,
        api_base: str = API_BASE_OPENAI,
        api_version: str = API_VERSION_OPENAI,
        deployment_name: str = AZURE_DEPLOYMENT_ID,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
        self._llm = AzureChatOpenAI(
            model=model_version,
            openai_api_type="azure",
            openai_api_base=api_base,
            openai_api_version=api_version,
            deployment_name=deployment_name,
            openai_api_key=api_key,
            max_tokens=max_output_tokens,
            temperature=0,
            request_timeout=timeout,
            model_kwargs={
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
            },
            verbose=should_be_verbose(),
            max_retries=0,  # retries are handled outside of langchain
        )

    @property
    def llm(self) -> AzureChatOpenAI:
        return self._llm
@@ -1,48 +1,25 @@
 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.constants import DanswerGenAIModel
-from danswer.configs.constants import ModelHostType
-from danswer.configs.model_configs import API_TYPE_OPENAI
-from danswer.configs.model_configs import GEN_AI_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import GEN_AI_TEMPERATURE
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.llm.azure import AzureGPT
-from danswer.llm.google_colab_demo import GoogleColabDemo
-from danswer.llm.llm import LLM
-from danswer.llm.openai import OpenAIGPT
+from danswer.llm.custom_llm import CustomModelServer
+from danswer.llm.gpt_4_all import DanswerGPT4All
+from danswer.llm.interfaces import LLM
+from danswer.llm.multi_llm import DefaultMultiLLM


 def get_default_llm(
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
 ) -> LLM:
-    """NOTE: api_key/timeout must be a special args since we may want to check
-    if an API key is valid for the default model setup OR we may want to use the
-    default model with a different timeout specified."""
+    """A single place to fetch the configured LLM for Danswer
+    Also allows overriding certain LLM defaults"""
     if api_key is None:
         api_key = get_gen_ai_api_key()

-    model_args = {
-        # provide a dummy key since LangChain will throw an exception if not
-        # given, which would prevent server startup
-        "api_key": api_key or "dummy_api_key",
-        "timeout": timeout,
-        "model_version": GEN_AI_MODEL_VERSION,
-        "endpoint": GEN_AI_ENDPOINT,
-        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
-        "temperature": GEN_AI_TEMPERATURE,
-    }
-    if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
-        if API_TYPE_OPENAI == "azure":
-            return AzureGPT(**model_args)  # type: ignore
-        return OpenAIGPT(**model_args)  # type: ignore
-    if (
-        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
-        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
-    ):
-        return GoogleColabDemo(**model_args)  # type: ignore
-
-    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
+    if GEN_AI_MODEL_PROVIDER.lower() == "custom":
+        return CustomModelServer(api_key=api_key, timeout=timeout)
+
+    if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
+        return DanswerGPT4All(timeout=timeout)
+
+    return DefaultMultiLLM(api_key=api_key, timeout=timeout)
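A usage sketch for the reworked factory (not part of this commit): the provider is picked up from GEN_AI_MODEL_PROVIDER, and invoke/stream are the method names used by the LLM implementations added later in this diff.

from danswer.llm.build import get_default_llm

llm = get_default_llm(timeout=30)  # provider chosen via GEN_AI_MODEL_PROVIDER

# Blocking call
print(llm.invoke("What did this refactor change?"))

# Token-by-token streaming
for token in llm.stream("What did this refactor change?"):
    print(token, end="", flush=True)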
@@ -1,24 +1,36 @@
 import json
 from collections.abc import Iterator
-from typing import Any

 import requests
 from langchain.schema.language_model import LanguageModelInput
 from requests import Timeout

-from danswer.llm.llm import LLM
-from danswer.llm.utils import convert_input
+from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.llm.interfaces import LLM
+from danswer.llm.utils import convert_lm_input_to_basic_string


-class GoogleColabDemo(LLM):
+class CustomModelServer(LLM):
+    """This class is to provide an example for how to use Danswer
+    with any LLM, even servers with custom API definitions.
+    To use with your own model server, simply implement the functions
+    below to fit your model server expectation"""
+
     def __init__(
         self,
-        endpoint: str,
-        max_output_tokens: int,
+        # Not used here but you probably want a model server that isn't completely open
+        api_key: str | None,
         timeout: int,
-        *args: list[Any],
-        **kwargs: dict[str, Any],
+        endpoint: str | None = GEN_AI_API_ENDPOINT,
+        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     ):
+        if not endpoint:
+            raise ValueError(
+                "Cannot point Danswer to a custom LLM server without providing the "
+                "endpoint for the model server."
+            )
+
         self._endpoint = endpoint
         self._max_output_tokens = max_output_tokens
         self._timeout = timeout
@@ -29,7 +41,7 @@ class GoogleColabDemo(LLM):
         }

         data = {
-            "inputs": convert_input(input),
+            "inputs": convert_lm_input_to_basic_string(input),
             "parameters": {
                 "temperature": 0.0,
                 "max_tokens": self._max_output_tokens,
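CustomModelServer only defines the request side in this hunk: it POSTs {"inputs": ..., "parameters": {"temperature": ..., "max_tokens": ...}} to GEN_AI_API_ENDPOINT. A hypothetical FastAPI counterpart of such a model server, sketched here for illustration only; the "generated_text" response field is an assumption carried over from the colab-demo format this class replaces, not something this commit specifies.

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class GenerationRequest(BaseModel):
    inputs: str
    parameters: dict = {}

@app.post("/")
def generate(request: GenerationRequest) -> dict:
    max_tokens = request.parameters.get("max_tokens", 256)
    # Call your locally hosted model here; echoing the prompt keeps the sketch runnable
    return {"generated_text": request.inputs[:max_tokens]}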
backend/danswer/llm/gpt_4_all.py (Normal file, 60 lines)
@@ -0,0 +1,60 @@
from collections.abc import Iterator
from typing import Any

from langchain.schema.language_model import LanguageModelInput

from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger


logger = setup_logger()


class DummyGPT4All:
    """In the case of import failure due to architectural incompatibilities,
    this module does not raise exceptions during server startup,
    as long as the module isn't actually used"""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise RuntimeError("GPT4All library not installed.")


try:
    from gpt4all import GPT4All  # type:ignore
except ImportError:
    # Setting a low log level because users get scared when they see this
    logger.debug(
        "GPT4All library not installed. "
        "If you wish to run GPT4ALL (in memory) to power Danswer's "
        "Generative AI features, please install gpt4all==2.0.2."
    )
    GPT4All = DummyGPT4All


class DanswerGPT4All(LLM):
    """Option to run an LLM locally, however this is significantly slower and
    answers tend to be much worse"""

    def __init__(
        self,
        timeout: int,
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        temperature: float = GEN_AI_TEMPERATURE,
    ):
        self.timeout = timeout
        self.max_output_tokens = max_output_tokens
        self.temperature = temperature
        self.gpt4all_model = GPT4All(model_version)

    def invoke(self, prompt: LanguageModelInput) -> str:
        prompt_basic = convert_lm_input_to_basic_string(prompt)
        return self.gpt4all_model.generate(prompt_basic)

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        prompt_basic = convert_lm_input_to_basic_string(prompt)
        return self.gpt4all_model.generate(prompt_basic, streaming=True)
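A usage sketch for the in-memory option (not from this commit): it requires gpt4all==2.0.2 to be installed, and GEN_AI_MODEL_VERSION or the constructor argument must name a model file GPT4All can load; the orca-mini weights below are used purely as an example value.

from danswer.llm.gpt_4_all import DanswerGPT4All

llm = DanswerGPT4All(timeout=60, model_version="orca-mini-3b.ggmlv3.q4_0.bin")

# Blocking generation
print(llm.invoke("Explain semantic search in one sentence."))

# Streaming generation
for token in llm.stream("Explain semantic search in one sentence."):
    print(token, end="", flush=True)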
backend/danswer/llm/multi_llm.py (Normal file, 74 lines)
@@ -0,0 +1,74 @@
import litellm  # type:ignore
from langchain.chat_models import ChatLiteLLM

from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False


def _get_model_str(
    model_provider: str | None,
    model_version: str | None,
) -> str:
    if model_provider and model_version:
        return model_provider + "/" + model_version

    if model_version:
        # Litellm defaults to openai if no provider specified
        # It's implicit so no need to specify here either
        return model_version

    # User specified something wrong, just use Danswer default
    return GEN_AI_MODEL_VERSION


class DefaultMultiLLM(LangChainChatLLM):
    """Uses Litellm library to allow easy configuration to use a multitude of LLMs
    See https://python.langchain.com/docs/integrations/chat/litellm"""

    DEFAULT_MODEL_PARAMS = {
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }

    def __init__(
        self,
        api_key: str | None,
        timeout: int,
        model_provider: str | None = GEN_AI_MODEL_PROVIDER,
        model_version: str | None = GEN_AI_MODEL_VERSION,
        api_base: str | None = GEN_AI_API_ENDPOINT,
        api_version: str | None = GEN_AI_API_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        temperature: float = GEN_AI_TEMPERATURE,
    ):
        # Litellm Langchain integration currently doesn't take in the api key param
        # Can place this in the call below once integration is in
        litellm.api_key = api_key
        litellm.api_version = api_version

        self._llm = ChatLiteLLM(  # type: ignore
            model=_get_model_str(model_provider, model_version),
            api_base=api_base,
            max_tokens=max_output_tokens,
            temperature=temperature,
            request_timeout=timeout,
            model_kwargs=DefaultMultiLLM.DEFAULT_MODEL_PARAMS,
            verbose=should_be_verbose(),
            max_retries=0,  # retries are handled outside of langchain
        )

    @property
    def llm(self) -> ChatLiteLLM:
        return self._llm
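How _get_model_str maps the configured provider/version pair onto a litellm model identifier, with illustrative values (these are not defaults set by this commit):

from danswer.llm.multi_llm import _get_model_str

assert _get_model_str("openai", "gpt-3.5-turbo") == "openai/gpt-3.5-turbo"
assert _get_model_str("azure", "gpt-35-turbo") == "azure/gpt-35-turbo"
# With no provider given, litellm assumes OpenAI, so the version passes through as-is
assert _get_model_str(None, "gpt-4") == "gpt-4"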
@@ -1,50 +0,0 @@
from typing import Any
from typing import cast

from langchain.chat_models import ChatLiteLLM
import litellm  # type:ignore

from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params=True


class OpenAIGPT(LangChainChatLLM):

    DEFAULT_MODEL_PARAMS = {
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }

    def __init__(
        self,
        api_key: str,
        max_output_tokens: int,
        timeout: int,
        model_version: str,
        temperature: float = GEN_AI_TEMPERATURE,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
        litellm.api_key = api_key

        self._llm = ChatLiteLLM(  # type: ignore
            model=model_version,
            # Prefer using None which is the default value, endpoint could be empty string
            api_base=cast(str, kwargs.get("endpoint")) or None,
            max_tokens=max_output_tokens,
            temperature=temperature,
            request_timeout=timeout,
            model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
            verbose=should_be_verbose(),
            max_retries=0,  # retries are handled outside of langchain
        )

    @property
    def llm(self) -> ChatLiteLLM:
        return self._llm
@@ -22,6 +22,7 @@ _LLM_TOKENIZER: Callable[[str], Any] | None = None


 def get_default_llm_tokenizer() -> Callable:
+    """Currently only supports the OpenAI default tokenizer: tiktoken"""
     global _LLM_TOKENIZER
     if _LLM_TOKENIZER is None:
         _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode
@@ -71,14 +72,7 @@ def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
     return [HumanMessage(content=message)]


-def message_generator_to_string_generator(
-    messages: Iterator[BaseMessageChunk],
-) -> Iterator[str]:
-    for message in messages:
-        yield message.content
-
-
-def convert_input(lm_input: LanguageModelInput) -> str:
+def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str:
     """Heavily inspired by:
     https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
     """
@@ -99,6 +93,13 @@ def convert_input(lm_input: LanguageModelInput) -> str:
     return prompt_value.to_string()


+def message_generator_to_string_generator(
+    messages: Iterator[BaseMessageChunk],
+) -> Iterator[str]:
+    for message in messages:
+        yield message.content
+
+
 def should_be_verbose() -> bool:
     return LOG_LEVEL == "debug"
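A small usage sketch of the renamed helper (message contents are illustrative): convert_lm_input_to_basic_string flattens any LanguageModelInput, whether a plain string, a PromptValue, or a list of chat messages, into one prompt string for the non-chat model backends above.

from langchain.schema import HumanMessage, SystemMessage

from danswer.llm.utils import convert_lm_input_to_basic_string

prompt = convert_lm_input_to_basic_string(
    [
        SystemMessage(content="Answer using only the provided context."),
        HumanMessage(content="What changed in the LLM layer?"),
    ]
)
print(prompt)  # one flat string, suitable for CustomModelServer or DanswerGPT4All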
@@ -22,13 +22,12 @@ from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.constants import AuthType
-from danswer.configs.model_configs import API_BASE_OPENAI
-from danswer.configs.model_configs import API_TYPE_OPENAI
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.db.credentials import create_initial_public_credential
 from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -152,14 +151,6 @@ def get_application() -> FastAPI:
         warm_up_models,
     )

-    if DISABLE_GENERATIVE_AI:
-        logger.info("Generative AI Q&A disabled")
-    else:
-        logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
-        logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
-        if API_TYPE_OPENAI == "azure":
-            logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
-
     verify_auth = fetch_versioned_implementation(
         "danswer.auth.users", "verify_auth_setting"
     )
@@ -169,6 +160,14 @@ def get_application() -> FastAPI:
     if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
         logger.info("Both OAuth Client ID and Secret are configured.")

+    if DISABLE_GENERATIVE_AI:
+        logger.info("Generative AI Q&A disabled")
+    else:
+        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
+        logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
+        if GEN_AI_API_ENDPOINT:
+            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
+
     if SKIP_RERANKING:
         logger.info("Reranking step of search flow is disabled")
@@ -25,7 +25,7 @@ from danswer.db.feedback import update_document_hidden
 from danswer.db.models import User
 from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
 from danswer.direct_qa.llm_utils import get_default_qa_model
-from danswer.direct_qa.open_ai import get_gen_ai_api_key
+from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey
@@ -19,8 +19,6 @@ from danswer.db.feedback import update_query_event_feedback
 from danswer.db.feedback import update_query_event_retrieved_documents
 from danswer.db.models import User
 from danswer.direct_qa.answer_question import answer_qa_query
-from danswer.direct_qa.exceptions import OpenAIKeyMissing
-from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -302,7 +300,7 @@ def stream_direct_qa(

     try:
         qa_model = get_default_qa_model()
-    except (UnknownModelError, OpenAIKeyMissing) as e:
+    except Exception as e:
         logger.exception("Unable to get QA model")
         error = StreamingError(error=str(e))
         yield get_json_line(error.dict())
@@ -13,9 +13,9 @@ filelock==3.12.0
 google-api-python-client==2.86.0
 google-auth-httplib2==0.1.0
 google-auth-oauthlib==1.0.0
-# GPT4All library does not support M1 Mac architecture
+# GPT4All library has issues running on Macs and python:3.11.4-slim-bookworm
 # will reintroduce this when library version catches up
-# gpt4all==1.0.5
+# gpt4all==2.0.2
 httpcore==0.16.3
 httpx==0.23.3
 httpx-oauth==0.11.2
@@ -16,11 +16,11 @@ services:
     ports:
       - "8080:8080"
     environment:
-      - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
+      - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
-      - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
-      - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
+      - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
+      - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
       - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
@@ -30,10 +30,6 @@ services:
       - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
       - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
       - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
-      - API_BASE_OPENAI=${API_BASE_OPENAI:-}
-      - API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
-      - API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
-      - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
       - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
       - DISABLE_TIME_FILTER_EXTRACTION=${DISABLE_TIME_FILTER_EXTRACTION:-}
       # Don't change the NLP model configs unless you know what you're doing
@@ -63,17 +59,13 @@ services:
       - index
     restart: always
     environment:
-      - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
+      - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
-      - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
-      - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
+      - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
+      - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
-      - API_BASE_OPENAI=${API_BASE_OPENAI:-}
-      - API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
-      - API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
-      - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
       - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
       # Connector Configs
       - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
@@ -3,57 +3,57 @@
 # This is only necessary when using the docker-compose.prod.yml compose file.


-# Insert your OpenAI API key here If not provided here, UI will prompt on setup.
-# This env variable takes precedence over UI settings.
-GEN_AI_API_KEY=
-# Choose between "openai-chat-completion" and "openai-completion"
-INTERNAL_MODEL_VERSION=openai-chat-completion
-# Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
-GEN_AI_MODEL_VERSION=gpt-4
-
-# Neccessary environment variables for Azure OpenAI:
-API_BASE_OPENAI=
-API_TYPE_OPENAI=
-API_VERSION_OPENAI=
-AZURE_DEPLOYMENT_ID=
-
 # Could be something like danswer.companyname.com
 WEB_DOMAIN=http://localhost:3000

-# Default values here are what Postgres uses by default, feel free to change.
-POSTGRES_USER=postgres
-POSTGRES_PASSWORD=password
+# Generative AI settings, uncomment as needed, will work with defaults
+GEN_AI_MODEL_PROVIDER=openai
+GEN_AI_MODEL_VERSION=gpt-4
+# Provide this as a global default/backup, this can also be set via the UI
+#GEN_AI_API_KEY=
+# Set to use Azure OpenAI or other services, such as https://danswer.openai.azure.com/
+#GEN_AI_API_ENDPOINT=
+# Set up to use a specific API version, such as 2023-09-15-preview (example taken from Azure)
+#GEN_AI_API_VERSION=
+

 # If you want to setup a slack bot to answer questions automatically in Slack
-# channels it is added to, you must specify the below.
+# channels it is added to, you must specify the two below.
 # More information in the guide here: https://docs.danswer.dev/slack_bot_setup
-DANSWER_BOT_SLACK_APP_TOKEN=
-DANSWER_BOT_SLACK_BOT_TOKEN=
+#DANSWER_BOT_SLACK_APP_TOKEN=
+#DANSWER_BOT_SLACK_BOT_TOKEN=

-# Used to generate values for security verification, use a random string
-SECRET=
-
-# How long before user needs to reauthenticate, default to 1 day. (cookie expiration time)
-SESSION_EXPIRE_TIME_SECONDS=86400
-
 # The following are for configuring User Authentication, supported flows are:
 # disabled
 # google_oauth (login with google/gmail account)
 # oidc (only in Danswer enterprise edition)
 # saml (only in Danswer enterprise edition)
-AUTH_TYPE=
+AUTH_TYPE=google_oauth

-# Set the two below to use with Google OAuth
+# Set the values below to use with Google OAuth
 GOOGLE_OAUTH_CLIENT_ID=
 GOOGLE_OAUTH_CLIENT_SECRET=
+SECRET=

 # OpenID Connect (OIDC)
-OPENID_CONFIG_URL=
+#OPENID_CONFIG_URL=

 # SAML config directory for OneLogin compatible setups
-SAML_CONF_DIR=
+#SAML_CONF_DIR=

-# used to specify a list of allowed user domains, only checked if user Auth is turned on
+# How long before user needs to reauthenticate, default to 1 day. (cookie expiration time)
+SESSION_EXPIRE_TIME_SECONDS=86400
+
+
+# Use the below to specify a list of allowed user domains, only checked if user Auth is turned on
 # e.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will only allow users
 # with an @example.com or an @example.org email
-VALID_EMAIL_DOMAINS=
+#VALID_EMAIL_DOMAINS=
+
+
+# Default values here are what Postgres uses by default, feel free to change.
+POSTGRES_USER=postgres
+POSTGRES_PASSWORD=password