Reworking the LLM layer (#666)

This commit is contained in:
Yuhong Sun
2023-10-31 18:22:42 -07:00
committed by GitHub
parent d9e5795b36
commit fbf7c642a3
28 changed files with 265 additions and 1448 deletions

View File

@@ -34,7 +34,7 @@ from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.document_index import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.access_filters import build_access_filters_for_user

View File

@@ -77,21 +77,6 @@ class DocumentIndexType(str, Enum):
SPLIT = "split" # Typesense + Qdrant
class DanswerGenAIModel(str, Enum):
"""This represents the internal Danswer GenAI model which determines the class that is used
to generate responses to the user query. Different models/services require different internal
handling; this allows for modularity of implementation within Danswer"""
OPENAI = "openai-completion"
OPENAI_CHAT = "openai-chat-completion"
GPT4ALL = "gpt4all-completion"
GPT4ALL_CHAT = "gpt4all-chat-completion"
HUGGINGFACE = "huggingface-client-completion"
HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
REQUEST = "request-completion"
TRANSFORMERS = "transformers"
class AuthType(str, Enum):
DISABLED = "disabled"
BASIC = "basic"
@@ -100,17 +85,6 @@ class AuthType(str, Enum):
SAML = "saml"
class ModelHostType(str, Enum):
"""For GenAI models interfaced via requests, different services have different
expectations for what fields are included in the request"""
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
HUGGINGFACE = "huggingface" # HuggingFace text-generation Inference API
# https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
COLAB_DEMO = "colab-demo"
# TODO support for Azure, AWS, GCP GenAI model hosting
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics

View File

@@ -1,9 +1,5 @@
import os
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
#####
# Embedding/Reranking Model Configs
#####
@@ -55,62 +51,38 @@ SEARCH_DISTANCE_CUTOFF = 0
# Intent model max context size
QUERY_MAX_CONTEXT_SIZE = 256
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
#####
# Generative AI Model Configs
#####
# Other models should work as well, check the library/API compatibility.
# But these are the models that have been verified to work with the existing prompts.
# Using a different model may require some prompt tuning. See qa_prompts.py
VERIFIED_MODELS = {
DanswerGenAIModel.OPENAI: ["text-davinci-003"],
DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
# The "chat" model below is actually "instruction finetuned" and does not support the conversational API
DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
# Created by Deepset.ai
# https://huggingface.co/deepset/deberta-v3-large-squad2
# Model provided with no modifications
DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"],
}
# Sets the internal Danswer model class to use
INTERNAL_MODEL_VERSION = os.environ.get(
"INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
)
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
# be sure to use one that is LiteLLM compatible:
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
# The provider is the prefix before / in the model argument
# Additionally Danswer supports GPT4All and custom request library based models
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(
"GEN_AI_MODEL_VERSION",
VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
)
GEN_AI_API_KEY = (
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
)
# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
# set the two below to specify
# - Where to hit the endpoint
# - How should the request be formed
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
# API Base, such as (for Azure): https://danswer.openai.azure.com/
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This next restriction is only used for chat ATM, used to expire old messages as needed
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
#####
# OpenAI Azure
#####
API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "")
API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower()
API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "")
# Deployment ID used interchangeably with "engine" parameter
AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "")
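The comments above describe the new LiteLLM-based configuration scheme. As a minimal sketch (the provider and model values below are illustrative, not defaults from this commit), the two environment variables combine into a LiteLLM model string, where the provider is the prefix before the slash:

import os

# Illustrative values; any LiteLLM-compatible provider/model pair works
os.environ["GEN_AI_MODEL_PROVIDER"] = "azure"
os.environ["GEN_AI_MODEL_VERSION"] = "gpt-35-turbo"

provider = os.environ["GEN_AI_MODEL_PROVIDER"]
version = os.environ["GEN_AI_MODEL_VERSION"]
# LiteLLM expects "<provider>/<model>", e.g. "azure/gpt-35-turbo"
litellm_model_str = f"{provider}/{version}"
print(litellm_model_str)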

View File

@@ -9,8 +9,6 @@ from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -132,7 +130,7 @@ def answer_qa_query(
qa_model = get_default_qa_model(
timeout=answer_generation_timeout, real_time_flow=real_time_flow
)
except (UnknownModelError, OpenAIKeyMissing) as e:
except Exception as e:
return QAResponse(
answer=None,
quotes=None,

View File

@@ -1,13 +0,0 @@
class OpenAIKeyMissing(Exception):
default_msg = (
"Unable to find existing OpenAI Key. "
'A new key can be added from "Keys" section of the Admin Panel'
)
def __init__(self, msg: str = default_msg) -> None:
super().__init__(msg)
class UnknownModelError(Exception):
def __init__(self, model_name: str) -> None:
super().__init__(f"Unknown Internal QA model name: {model_name}")

View File

@@ -1,209 +0,0 @@
from collections.abc import Callable
from typing import Any
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
class DummyGPT4All:
"""In the case of import failure due to M1 Mac incompatibility,
this module does not raise exceptions during server startup,
as long as this module isn't actually used"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("GPT4All library not installed.")
try:
from gpt4all import GPT4All # type:ignore
except ImportError:
logger.warning(
"GPT4All library not installed. "
"If you wish to run GPT4ALL (in memory) to power Danswer's "
"Generative AI features, please install gpt4all==1.0.5. "
"As of Aug 2023, this library is not compatible with M1 Mac."
)
GPT4All = DummyGPT4All
GPT4ALL_MODEL: GPT4All | None = None
def get_gpt_4_all_model(
model_version: str = GEN_AI_MODEL_VERSION,
) -> GPT4All:
global GPT4ALL_MODEL
if GPT4ALL_MODEL is None:
GPT4ALL_MODEL = GPT4All(model_version)
return GPT4ALL_MODEL
def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"temp": 0,
**kwargs,
}
class GPT4AllCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False, # gpt4all models can't handle this atm
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@property
def requires_api_key(self) -> bool:
return False
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
model_output = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=filled_prompt, max_tokens=self.max_output_tokens
),
)
logger.debug(model_output)
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
model_stream = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class GPT4AllChatCompletionQA(QAModel):
def __init__(
self,
prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False, # gpt4all models can't handle this atm
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@property
def requires_api_key(self) -> bool:
return False
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
with gen_ai_model.chat_session():
context_msgs = filled_prompt[:-1]
user_query = filled_prompt[-1].get("content")
for message in context_msgs:
gen_ai_model.current_chat_session.append(message)
model_output = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=user_query, max_tokens=self.max_output_tokens
),
)
logger.debug(model_output)
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
with gen_ai_model.chat_session():
context_msgs = filled_prompt[:-1]
user_query = filled_prompt[-1].get("content")
for message in context_msgs:
gen_ai_model.current_chat_session.append(message)
model_stream = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=user_query, max_tokens=self.max_output_tokens
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -1,195 +0,0 @@
from collections.abc import Callable
from typing import Any
from huggingface_hub import InferenceClient # type:ignore
from huggingface_hub.utils import HfHubHTTPError # type:ignore
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"do_sample": False,
"seed": 69, # For reproducibility
**kwargs,
}
class HuggingFaceCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = FreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_stream = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class HuggingFaceChatCompletionQA(QAModel):
"""Chat in this class refers to the HuggingFace Conversational API.
Not to be confused with Chat/Instruction finetuned models.
Llama2-chat... means it is an Instruction finetuned model, not necessarily that
it supports Conversational API"""
def __init__(
self,
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@staticmethod
def _convert_chat_to_hf_conversational_format(
dialog: list[dict[str, str]]
) -> tuple[str, list[str], list[str]]:
if dialog[-1]["role"] != "user":
raise Exception(
"Last message in conversational dialog must be User message"
)
user_message = dialog[-1]["content"]
dialog = dialog[0:-1]
generated_responses = []
past_user_inputs = []
for message in dialog:
# HuggingFace inference client doesn't support system messages today
# so lumping them in with user messages
if message["role"] in ["user", "system"]:
past_user_inputs += [message["content"]]
else:
generated_responses += [message["content"]]
return user_message, generated_responses, past_user_inputs
def _get_hf_model_output(
self, query: str, context_docs: list[InferenceChunk]
) -> str:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
(
query,
past_responses,
past_inputs,
) = self._convert_chat_to_hf_conversational_format(filled_prompt)
logger.debug(f"Last Input: {query}")
logger.debug(f"Past Inputs: {past_inputs}")
logger.debug(f"Past Responses: {past_responses}")
try:
model_output = self.client.conversational(
query,
generated_responses=past_responses,
past_user_inputs=past_inputs,
parameters={"max_length": self.max_output_tokens},
)
except HfHubHTTPError as model_error:
if model_error.response.status_code == 422:
raise ValueError(
"Selected HuggingFace Model does not support HuggingFace Conversational API,"
"try using the huggingface-inference-completion in Danswer instead"
)
raise
logger.debug(model_output)
return model_output
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
model_output = self._get_hf_model_output(query, context_docs)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
"""As of Aug 2023, HF conversational (chat) endpoints do not support streaming
So here it is faked by streaming characters within Danswer from the model output
"""
model_output = self._get_hf_model_output(query, context_docs)
model_stream = simulate_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -52,6 +52,7 @@ class QAModel:
def requires_api_key(self) -> bool:
"""Is this model protected by security features
Does it need an api key to access the model for inference"""
# TODO, this should be false for custom request model and gpt4all
return True
def warm_up_model(self) -> None:

View File

@@ -1,32 +1,11 @@
from typing import Any
import pkg_resources
from openai.error import AuthenticationError
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_block import QABlock
from danswer.direct_qa.qa_block import QAHandler
from danswer.direct_qa.qa_block import SimpleChatQAHandler
from danswer.direct_qa.qa_block import SingleMessageQAHandler
from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.build import get_default_llm
from danswer.utils.logger import setup_logger
@@ -52,93 +31,23 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
return False
def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
if model == DanswerGenAIModel.OPENAI_CHAT.value:
return (
SingleMessageQAHandler()
if real_time_flow
else SingleMessageScratchpadHandler()
)
return SimpleChatQAHandler()
# TODO introduce the prompt choice parameter
def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
return (
SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler()
)
# return SimpleChatQAHandler()
def get_default_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
**kwargs: Any,
) -> QAModel:
if not api_key:
try:
api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
pass
llm = get_default_llm(api_key=api_key, timeout=timeout)
qa_handler = get_default_qa_handler(real_time_flow=real_time_flow)
try:
# un-used arguments will be ignored by the underlying `LLM` class
# if any args are missing, a `TypeError` will be thrown
llm = get_default_llm(timeout=timeout)
qa_handler = get_default_qa_handler(
model=internal_model, real_time_flow=real_time_flow
)
return QABlock(
llm=llm,
qa_handler=qa_handler,
)
except Exception:
logger.exception(
"Unable to build a QABlock with the new approach, going back to the "
"legacy approach"
)
if internal_model in [
DanswerGenAIModel.GPT4ALL.value,
DanswerGenAIModel.GPT4ALL_CHAT.value,
]:
# gpt4all is not compatible with M1 Mac hardware as of Aug 2023
pkg_resources.get_distribution("gpt4all")
if internal_model == DanswerGenAIModel.OPENAI.value:
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
return GPT4AllCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.TRANSFORMERS:
return TransformerQA()
elif internal_model == DanswerGenAIModel.REQUEST.value:
if endpoint is None or model_host_type is None:
raise ValueError(
"Request based GenAI model requires an endpoint and host type"
)
if (
model_host_type == ModelHostType.HUGGINGFACE.value
or model_host_type == ModelHostType.COLAB_DEMO.value
):
# Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
# With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
# This is the sum of input and output tokens, so cannot take in full Danswer context
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
prompt_processor=WeakModelFreeformProcessor(),
timeout=timeout,
)
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
timeout=timeout,
)
else:
raise UnknownModelError(internal_model)
return QABlock(
llm=llm,
qa_handler=qa_handler,
)
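A rough usage sketch of the new QABlock-based flow, assuming the (answer, quotes) return shape used by the QA models elsewhere in this diff; the empty chunk list is a placeholder for the search step's output:

from danswer.direct_qa.llm_utils import get_default_qa_model

retrieved_chunks: list = []  # list[InferenceChunk] from retrieval in the real flow

qa_model = get_default_qa_model(timeout=30, real_time_flow=True)
answer, quotes = qa_model.answer_question(
    query="How do I configure the LLM provider?",
    context_docs=retrieved_chunks,
)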

View File

@@ -1,156 +0,0 @@
import re
from collections.abc import Callable
from transformers import pipeline # type:ignore
from transformers import QuestionAnsweringPipeline # type:ignore
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
TRANSFORMER_DEFAULT_MAX_CONTEXT = 512
_TRANSFORMER_MODEL: QuestionAnsweringPipeline | None = None
def get_default_transformer_model(
model_version: str = GEN_AI_MODEL_VERSION,
max_context: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
) -> QuestionAnsweringPipeline:
global _TRANSFORMER_MODEL
if _TRANSFORMER_MODEL is None:
_TRANSFORMER_MODEL = pipeline(
"question-answering", model=model_version, max_seq_len=max_context
)
return _TRANSFORMER_MODEL
def find_extended_answer(answer: str, context: str) -> str:
"""Try to extend the answer by matching across the context text and extending before
and after the quote to some termination character"""
result = re.search(
r"(^|\n\r?|\.)(?P<content>[^\n]{{0,250}}{}[^\n]{{0,250}})(\.|$|\n\r?)".format(
re.escape(answer)
),
context,
flags=re.MULTILINE | re.DOTALL,
)
if result:
return result.group("content")
return answer
class TransformerQA(QAModel):
@staticmethod
def _answer_one_chunk(
query: str,
chunk: InferenceChunk,
max_context_len: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
max_cutoff: float = 0.9,
min_cutoff: float = 0.5,
) -> tuple[str | None, DanswerQuote | None]:
"""Because this type of QA model only takes 1 chunk of context with a fairly small token limit
we have to iterate over the chunks and check whether the answer is found in any of them.
This type of approach does not allow for interpolating answers across chunks
"""
model = get_default_transformer_model()
model_out = model(question=query, context=chunk.content, max_answer_len=128)
answer = model_out.get("answer")
confidence = model_out.get("score")
if answer is None:
return None, None
logger.info(f"Transformer Answer: {answer}")
logger.debug(f"Transformer Confidence: {confidence}")
# Model tends to be overconfident on short chunks
# so min score required increases as chunk size decreases
# If it's at least 0.9, then it's good enough regardless
# Default minimum of 0.5 required
score_cutoff = max(
min(max_cutoff, 1 - len(chunk.content) / max_context_len), min_cutoff
)
if confidence < score_cutoff:
return None, None
extended_answer = find_extended_answer(answer, chunk.content)
danswer_quote = DanswerQuote(
quote=answer,
document_id=chunk.document_id,
link=chunk.source_links[0] if chunk.source_links else None,
source_type=chunk.source_type,
semantic_identifier=chunk.semantic_identifier,
blurb=chunk.blurb,
)
return extended_answer, danswer_quote
def warm_up_model(self) -> None:
get_default_transformer_model()
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
danswer_quotes: list[DanswerQuote] = []
d_answers: list[str] = []
for chunk in context_docs:
answer, quote = self._answer_one_chunk(query, chunk)
if answer is not None and quote is not None:
d_answers.append(answer)
danswer_quotes.append(quote)
answers_list = [
f"Answer {ind}: {answer.strip()}"
for ind, answer in enumerate(d_answers, start=1)
]
combined_answer = "\n".join(answers_list)
return DanswerAnswer(answer=combined_answer), DanswerQuotes(
quotes=danswer_quotes
)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
quotes: list[DanswerQuote] = []
answers: list[str] = []
for chunk in context_docs:
answer, quote = self._answer_one_chunk(query, chunk)
if answer is not None and quote is not None:
answers.append(answer)
quotes.append(quote)
# Delay the output of the answers so there isn't long gap between first answer and quotes
answer_count = 1
for answer in answers:
if answer_count == 1:
yield DanswerAnswerPiece(answer_piece="Source 1: ")
else:
yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
answer_count += 1
for char in answer.strip():
yield DanswerAnswerPiece(answer_piece=char)
# signal end of answer
yield DanswerAnswerPiece(answer_piece=None)
yield DanswerQuotes(quotes=quotes)

View File

@@ -1,209 +0,0 @@
from abc import ABC
from collections.abc import Callable
from collections.abc import Generator
from copy import copy
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
import openai
import tiktoken
from openai.error import AuthenticationError
from openai.error import Timeout
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
F = TypeVar("F", bound=Callable)
if API_BASE_OPENAI:
openai.api_base = API_BASE_OPENAI
if API_TYPE_OPENAI in ["azure"]: # TODO: Azure AD support ["azure_ad", "azuread"]
openai.api_type = API_TYPE_OPENAI
openai.api_version = API_VERSION_OPENAI
def _ensure_openai_api_key(api_key: str | None) -> str:
final_api_key = api_key or get_gen_ai_api_key()
if final_api_key is None:
raise OpenAIKeyMissing()
return final_api_key
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"temperature": 0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
**({"deployment_id": AZURE_DEPLOYMENT_ID} if AZURE_DEPLOYMENT_ID else {}),
**kwargs,
}
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
@wraps(openai_call)
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
try:
# if streamed, the call returns a generator
if kwargs.get("stream"):
def _generator() -> Generator[Any, None, None]:
yield from openai_call(*args, **kwargs)
return _generator()
return openai_call(*args, **kwargs)
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Timeout:
logger.exception("OpenAI API timed out for query: %s", query)
raise
except Exception:
logger.exception("Unexpected error with OpenAI API for query: %s", query)
raise
return cast(F, wrapped_call)
def _tiktoken_trim_chunks(
chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512
) -> list[InferenceChunk]:
"""Edit chunks that have too high token count. Generally due to parsing issues or
characters from another language that are 1 char = 1 token
Trimming by tokens leads to information loss but currently no better way of handling
"""
encoder = tiktoken.encoding_for_model(model_version)
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
tokens = encoder.encode(chunk.content)
if len(tokens) > max_chunk_toks:
new_chunk = copy(chunk)
new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
new_chunks[ind] = new_chunk
return new_chunks
# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel, ABC):
pass
class OpenAICompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
api_key: str | None = None,
timeout: int | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.timeout = timeout
self.include_metadata = include_metadata
try:
self.api_key = api_key or get_gen_ai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
@staticmethod
def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
for event in response:
yield event["choices"][0]["text"]
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.Completion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=_ensure_openai_api_key(self.api_key),
prompt=filled_prompt,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
),
)
model_output = cast(str, response["choices"][0]["text"]).strip()
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
logger.debug(model_output)
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.Completion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=_ensure_openai_api_key(self.api_key),
prompt=filled_prompt,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
stream=True,
),
)
tokens = self._generate_tokens_from_response(response)
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -28,7 +28,7 @@ from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.indexing.models import InferenceChunk
from danswer.llm.llm import LLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import get_default_llm_tokenizer

View File

@@ -1,272 +0,0 @@
import abc
import json
from collections.abc import Callable
from collections.abc import Generator
import requests
from requests.exceptions import Timeout
from requests.models import Response
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
class HostSpecificRequestModel(abc.ABC):
"""Provides a more minimal implementation requirement for extending to new models
hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
"""
@property
def requires_api_key(self) -> bool:
"""Is this model protected by security features
Does it need an api key to access the model for inference"""
return True
@staticmethod
@abc.abstractmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None,
max_output_tokens: int,
stream: bool,
timeout: int | None,
) -> Response:
"""Given a filled out prompt, how to send it to the model API with the
correct request format with the correct parameters"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def extract_model_output_from_response(
response: Response,
) -> str:
"""Extract the full model output text from a response.
This is for nonstreaming endpoints"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
"""Generate tokens from a streaming response
This is for streaming endpoints"""
raise NotImplementedError
class HuggingFaceRequestModel(HostSpecificRequestModel):
@staticmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None,
max_output_tokens: int,
stream: bool, # Not supported by Inference Endpoints (as of Aug 2023)
timeout: int | None,
) -> Response:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
data = {
"inputs": filled_prompt,
"parameters": {
# HuggingFace requires this to be strictly positive from 0.0-100.0 noninclusive
"temperature": 0.01,
# Skip the long tail
"top_p": 0.9,
"max_new_tokens": max_output_tokens,
},
}
try:
return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
except Timeout as error:
raise Timeout(f"Model inference to {endpoint} timed out") from error
@staticmethod
def _hf_extract_model_output(
response: Response,
) -> str:
if response.status_code != 200:
response.raise_for_status()
return json.loads(response.content)[0].get("generated_text", "")
@staticmethod
def extract_model_output_from_response(
response: Response,
) -> str:
return HuggingFaceRequestModel._hf_extract_model_output(response)
@staticmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
"""HF endpoints do not do streaming currently so this function will
simulate streaming for the meantime but will need to be replaced in
the future once streaming is enabled."""
model_out = HuggingFaceRequestModel._hf_extract_model_output(response)
yield from simulate_streaming_response(model_out)
class ColabDemoRequestModel(HostSpecificRequestModel):
"""Guide found at:
https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
"""
@property
def requires_api_key(self) -> bool:
return False
@staticmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None, # ngrok basic setup doesn't require this
max_output_tokens: int,
stream: bool,
timeout: int | None,
) -> Response:
headers = {
"Content-Type": "application/json",
}
data = {
"inputs": filled_prompt,
"parameters": {
"temperature": 0.0,
"max_tokens": max_output_tokens,
},
}
try:
return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
except Timeout as error:
raise Timeout(f"Model inference to {endpoint} timed out") from error
@staticmethod
def _colab_demo_extract_model_output(
response: Response,
) -> str:
if response.status_code != 200:
response.raise_for_status()
return json.loads(response.content).get("generated_text", "")
@staticmethod
def extract_model_output_from_response(
response: Response,
) -> str:
return ColabDemoRequestModel._colab_demo_extract_model_output(response)
@staticmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
yield from simulate_streaming_response(model_out)
def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
if model_host_type == ModelHostType.HUGGINGFACE.value:
return HuggingFaceRequestModel()
if model_host_type == ModelHostType.COLAB_DEMO.value:
return ColabDemoRequestModel()
else:
# TODO support Azure, GCP, AWS
raise ValueError("Invalid model hosting service selected")
class RequestCompletionQA(QAModel):
def __init__(
self,
endpoint: str = GEN_AI_ENDPOINT,
model_host_type: str = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
timeout: int | None = None,
) -> None:
self.endpoint = endpoint
self.api_key = api_key
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.model_class = get_host_specific_model_class(model_host_type)
self.timeout = timeout
@property
def requires_api_key(self) -> bool:
return self.model_class.requires_api_key
def _get_request_response(
self, query: str, context_docs: list[InferenceChunk], stream: bool
) -> Response:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, include_metadata=False
)
logger.debug(filled_prompt)
return self.model_class.send_model_request(
filled_prompt,
self.endpoint,
self.api_key,
self.max_output_tokens,
stream,
self.timeout,
)
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
model_output = self.model_class.extract_model_output_from_response(
model_api_response
)
logger.debug(model_output)
return process_answer(model_output, context_docs)
def answer_question_stream(
self,
query: str,
context_docs: list[InferenceChunk],
) -> AnswerQuestionStreamReturn:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
token_generator = self.model_class.generate_model_tokens_from_response(
model_api_response
)
yield from process_model_tokens(
tokens=token_generator,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -1,46 +0,0 @@
from typing import Any
from langchain.chat_models.azure_openai import AzureChatOpenAI
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
class AzureGPT(LangChainChatLLM):
def __init__(
self,
api_key: str,
max_output_tokens: int,
timeout: int,
model_version: str,
api_base: str = API_BASE_OPENAI,
api_version: str = API_VERSION_OPENAI,
deployment_name: str = AZURE_DEPLOYMENT_ID,
*args: list[Any],
**kwargs: dict[str, Any]
):
self._llm = AzureChatOpenAI(
model=model_version,
openai_api_type="azure",
openai_api_base=api_base,
openai_api_version=api_version,
deployment_name=deployment_name,
openai_api_key=api_key,
max_tokens=max_output_tokens,
temperature=0,
request_timeout=timeout,
model_kwargs={
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
verbose=should_be_verbose(),
max_retries=0, # retries are handled outside of langchain
)
@property
def llm(self) -> AzureChatOpenAI:
return self._llm

View File

@@ -1,48 +1,25 @@
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.llm.azure import AzureGPT
from danswer.llm.google_colab_demo import GoogleColabDemo
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.multi_llm import DefaultMultiLLM
def get_default_llm(
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
) -> LLM:
"""NOTE: api_key/timeout must be a special args since we may want to check
if an API key is valid for the default model setup OR we may want to use the
default model with a different timeout specified."""
"""A single place to fetch the configured LLM for Danswer
Also allows overriding certain LLM defaults"""
if api_key is None:
api_key = get_gen_ai_api_key()
model_args = {
# provide a dummy key since LangChain will throw an exception if not
# given, which would prevent server startup
"api_key": api_key or "dummy_api_key",
"timeout": timeout,
"model_version": GEN_AI_MODEL_VERSION,
"endpoint": GEN_AI_ENDPOINT,
"max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
"temperature": GEN_AI_TEMPERATURE,
}
if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
if API_TYPE_OPENAI == "azure":
return AzureGPT(**model_args) # type: ignore
return OpenAIGPT(**model_args) # type: ignore
if (
INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
):
return GoogleColabDemo(**model_args) # type: ignore
if GEN_AI_MODEL_PROVIDER.lower() == "custom":
return CustomModelServer(api_key=api_key, timeout=timeout)
raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
return DanswerGPT4All(timeout=timeout)
return DefaultMultiLLM(api_key=api_key, timeout=timeout)
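A minimal caller sketch, assuming the LLM interface exposes the invoke call implemented by the GPT4All wrapper later in this diff; the prompt is just an example:

from danswer.llm.build import get_default_llm

llm = get_default_llm(timeout=10)  # provider/model come from the env-based configs
print(llm.invoke("Summarize what Danswer does in one sentence."))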

View File

@@ -1,24 +1,36 @@
import json
from collections.abc import Iterator
from typing import Any
import requests
from langchain.schema.language_model import LanguageModelInput
from requests import Timeout
from danswer.llm.llm import LLM
from danswer.llm.utils import convert_input
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
class GoogleColabDemo(LLM):
class CustomModelServer(LLM):
"""This class is to provide an example for how to use Danswer
with any LLM, even servers with custom API definitions.
To use with your own model server, simply implement the functions
below to fit your model server expectation"""
def __init__(
self,
endpoint: str,
max_output_tokens: int,
# Not used here but you probably want a model server that isn't completely open
api_key: str | None,
timeout: int,
*args: list[Any],
**kwargs: dict[str, Any],
endpoint: str | None = GEN_AI_API_ENDPOINT,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
):
if not endpoint:
raise ValueError(
"Cannot point Danswer to a custom LLM server without providing the "
"endpoint for the model server."
)
self._endpoint = endpoint
self._max_output_tokens = max_output_tokens
self._timeout = timeout
@@ -29,7 +41,7 @@ class GoogleColabDemo(LLM):
}
data = {
"inputs": convert_input(input),
"inputs": convert_lm_input_to_basic_string(input),
"parameters": {
"temperature": 0.0,
"max_tokens": self._max_output_tokens,

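Based on the request body shown in this hunk, a custom model server only needs to accept a JSON payload of roughly this shape; the endpoint URL below is a placeholder for whatever GEN_AI_API_ENDPOINT points at:

import requests

endpoint = "http://localhost:8000/generate"  # placeholder model server URL
payload = {
    "inputs": "Answer the question using only the provided context ...",
    "parameters": {
        "temperature": 0.0,
        "max_tokens": 1024,
    },
}
resp = requests.post(endpoint, json=payload, timeout=10)
resp.raise_for_status()
print(resp.json())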
View File

@@ -0,0 +1,60 @@
from collections.abc import Iterator
from typing import Any
from langchain.schema.language_model import LanguageModelInput
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger
logger = setup_logger()
class DummyGPT4All:
"""In the case of import failure due to architectural incompatibilities,
this module does not raise exceptions during server startup,
as long as the module isn't actually used"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("GPT4All library not installed.")
try:
from gpt4all import GPT4All # type:ignore
except ImportError:
# Setting a low log level because users get scared when they see this
logger.debug(
"GPT4All library not installed. "
"If you wish to run GPT4ALL (in memory) to power Danswer's "
"Generative AI features, please install gpt4all==2.0.2."
)
GPT4All = DummyGPT4All
class DanswerGPT4All(LLM):
"""Option to run an LLM locally, however this is significantly slower and
answers tend to be much worse"""
def __init__(
self,
timeout: int,
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
):
self.timeout = timeout
self.max_output_tokens = max_output_tokens
self.temperature = temperature
self.gpt4all_model = GPT4All(model_version)
def invoke(self, prompt: LanguageModelInput) -> str:
prompt_basic = convert_lm_input_to_basic_string(prompt)
return self.gpt4all_model.generate(prompt_basic)
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
prompt_basic = convert_lm_input_to_basic_string(prompt)
return self.gpt4all_model.generate(prompt_basic, streaming=True)
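A short usage sketch for the local option, assuming the gpt4all package is installed and using the Falcon model file mentioned earlier in this diff as an example:

from danswer.llm.gpt_4_all import DanswerGPT4All

llm = DanswerGPT4All(timeout=60, model_version="ggml-model-gpt4all-falcon-q4_0.bin")
print(llm.invoke("What is Danswer?"))          # blocking call
for token in llm.stream("What is Danswer?"):   # token-by-token streaming
    print(token, end="")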

View File

@@ -0,0 +1,74 @@
import litellm # type:ignore
from langchain.chat_models import ChatLiteLLM
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False
def _get_model_str(
model_provider: str | None,
model_version: str | None,
) -> str:
if model_provider and model_version:
return model_provider + "/" + model_version
if model_version:
# Litellm defaults to openai if no provider specified
# It's implicit so no need to specify here either
return model_version
# User specified something wrong, just use Danswer default
return GEN_AI_MODEL_VERSION
class DefaultMultiLLM(LangChainChatLLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""
DEFAULT_MODEL_PARAMS = {
"frequency_penalty": 0,
"presence_penalty": 0,
}
def __init__(
self,
api_key: str | None,
timeout: int,
model_provider: str | None = GEN_AI_MODEL_PROVIDER,
model_version: str | None = GEN_AI_MODEL_VERSION,
api_base: str | None = GEN_AI_API_ENDPOINT,
api_version: str | None = GEN_AI_API_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
):
# Litellm Langchain integration currently doesn't take in the api key param
# Can place this in the call below once integration is in
litellm.api_key = api_key
litellm.api_version = api_version
self._llm = ChatLiteLLM( # type: ignore
model=_get_model_str(model_provider, model_version),
api_base=api_base,
max_tokens=max_output_tokens,
temperature=temperature,
request_timeout=timeout,
model_kwargs=DefaultMultiLLM.DEFAULT_MODEL_PARAMS,
verbose=should_be_verbose(),
max_retries=0, # retries are handled outside of langchain
)
@property
def llm(self) -> ChatLiteLLM:
return self._llm
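A construction sketch for the LiteLLM-backed default, with illustrative provider/model values (any LiteLLM-supported pair works):

from danswer.llm.multi_llm import DefaultMultiLLM

llm = DefaultMultiLLM(
    api_key="sk-...",  # or None for providers that need no key
    timeout=30,
    model_provider="openai",
    model_version="gpt-3.5-turbo",
)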

View File

@@ -1,50 +0,0 @@
from typing import Any
from typing import cast
from langchain.chat_models import ChatLiteLLM
import litellm # type:ignore
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params=True
class OpenAIGPT(LangChainChatLLM):
DEFAULT_MODEL_PARAMS = {
"frequency_penalty": 0,
"presence_penalty": 0,
}
def __init__(
self,
api_key: str,
max_output_tokens: int,
timeout: int,
model_version: str,
temperature: float = GEN_AI_TEMPERATURE,
*args: list[Any],
**kwargs: dict[str, Any]
):
litellm.api_key = api_key
self._llm = ChatLiteLLM( # type: ignore
model=model_version,
# Prefer using None which is the default value, endpoint could be empty string
api_base=cast(str, kwargs.get("endpoint")) or None,
max_tokens=max_output_tokens,
temperature=temperature,
request_timeout=timeout,
model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
verbose=should_be_verbose(),
max_retries=0, # retries are handled outside of langchain
)
@property
def llm(self) -> ChatLiteLLM:
return self._llm

View File

@@ -22,6 +22,7 @@ _LLM_TOKENIZER: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Callable:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode
@@ -71,14 +72,7 @@ def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]
def message_generator_to_string_generator(
messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
for message in messages:
yield message.content
def convert_input(lm_input: LanguageModelInput) -> str:
def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str:
"""Heavily inspired by:
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
"""
@@ -99,6 +93,13 @@ def convert_input(lm_input: LanguageModelInput) -> str:
return prompt_value.to_string()
def message_generator_to_string_generator(
messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
for message in messages:
yield message.content
def should_be_verbose() -> bool:
return LOG_LEVEL == "debug"
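A small sketch of the renamed helper, assuming the standard langchain message classes; it flattens a chat-style prompt into the single string that non-chat backends (the custom server, GPT4All) expect:

from langchain.schema.messages import HumanMessage, SystemMessage

from danswer.llm.utils import convert_lm_input_to_basic_string

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is Danswer?"),
]
prompt_str = convert_lm_input_to_basic_string(messages)
print(prompt_str)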

View File

@@ -22,13 +22,12 @@ from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -152,14 +151,6 @@ def get_application() -> FastAPI:
warm_up_models,
)
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
if API_TYPE_OPENAI == "azure":
logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
@@ -169,6 +160,14 @@ def get_application() -> FastAPI:
if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
logger.info("Both OAuth Client ID and Secret are configured.")
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
if SKIP_RERANKING:
logger.info("Reranking step of search flow is disabled")

View File

@@ -25,7 +25,7 @@ from danswer.db.feedback import update_document_hidden
from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey

View File

@@ -19,8 +19,6 @@ from danswer.db.feedback import update_query_event_feedback
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -302,7 +300,7 @@ def stream_direct_qa(
try:
qa_model = get_default_qa_model()
except (UnknownModelError, OpenAIKeyMissing) as e:
except Exception as e:
logger.exception("Unable to get QA model")
error = StreamingError(error=str(e))
yield get_json_line(error.dict())

View File

@@ -13,9 +13,9 @@ filelock==3.12.0
google-api-python-client==2.86.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
# GPT4All library does not support M1 Mac architecture
# GPT4All library has issues running on Macs and python:3.11.4-slim-bookworm
# will reintroduce this when library version catches up
# gpt4all==1.0.5
# gpt4all==2.0.2
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2