https://github.com/danswer-ai/danswer.git
Reworking the LLM layer (#666)
@@ -34,7 +34,7 @@ from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.document_index import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.access_filters import build_access_filters_for_user
@@ -77,21 +77,6 @@ class DocumentIndexType(str, Enum):
    SPLIT = "split" # Typesense + Qdrant


class DanswerGenAIModel(str, Enum):
    """This represents the internal Danswer GenAI model which determines the class that is used
    to generate responses to the user query. Different models/services require different internal
    handling, this allows for modularity of implementation within Danswer"""

    OPENAI = "openai-completion"
    OPENAI_CHAT = "openai-chat-completion"
    GPT4ALL = "gpt4all-completion"
    GPT4ALL_CHAT = "gpt4all-chat-completion"
    HUGGINGFACE = "huggingface-client-completion"
    HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
    REQUEST = "request-completion"
    TRANSFORMERS = "transformers"


class AuthType(str, Enum):
    DISABLED = "disabled"
    BASIC = "basic"
@@ -100,17 +85,6 @@ class AuthType(str, Enum):
    SAML = "saml"


class ModelHostType(str, Enum):
    """For GenAI models interfaced via requests, different services have different
    expectations for what fields are included in the request"""

    # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
    HUGGINGFACE = "huggingface" # HuggingFace test-generation Inference API
    # https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
    COLAB_DEMO = "colab-demo"
    # TODO support for Azure, AWS, GCP GenAI model hosting


class QAFeedbackType(str, Enum):
    LIKE = "like" # User likes the answer, used for metrics
    DISLIKE = "dislike" # User dislikes the answer, used for metrics
@@ -1,9 +1,5 @@
import os

from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType


#####
# Embedding/Reranking Model Configs
#####
@@ -55,62 +51,38 @@ SEARCH_DISTANCE_CUTOFF = 0
# Intent model max context size
QUERY_MAX_CONTEXT_SIZE = 256

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"


#####
# Generative AI Model Configs
#####
# Other models should work as well, check the library/API compatibility.
# But these are the models that have been verified to work with the existing prompts.
# Using a different model may require some prompt tuning. See qa_prompts.py
VERIFIED_MODELS = {
    DanswerGenAIModel.OPENAI: ["text-davinci-003"],
    DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
    DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
    DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
    # The "chat" model below is actually "instruction finetuned" and does not support conversational
    DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
    DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
    # Created by Deepset.ai
    # https://huggingface.co/deepset/deberta-v3-large-squad2
    # Model provided with no modifications
    DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"],
}

# Sets the internal Danswer model class to use
INTERNAL_MODEL_VERSION = os.environ.get(
    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
)
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
# be sure to use one that is LiteLLM compatible:
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
# The provider is the prefix before / in the model argument

# Additionally Danswer supports GPT4All and custom request library based models
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"

# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))

# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(
    "GEN_AI_MODEL_VERSION",
    VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
GEN_AI_API_KEY = (
    os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
)

# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
# set the two below to specify
# - Where to hit the endpoint
# - How should the request be formed
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
# API Base, such as (for Azure): https://danswer.openai.azure.com/
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None

# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This next restriction is only used for chat ATM, used to expire old messages as needed
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"

#####
# OpenAI Azure
#####
API_BASE_OPENAI = os.environ.get("API_BASE_OPENAI", "")
API_TYPE_OPENAI = os.environ.get("API_TYPE_OPENAI", "").lower()
API_VERSION_OPENAI = os.environ.get("API_VERSION_OPENAI", "")
# Deployment ID used interchangeably with "engine" parameter
AZURE_DEPLOYMENT_ID = os.environ.get("AZURE_DEPLOYMENT_ID", "")
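The comments above describe the new provider/version convention. As a minimal sketch only, the two settings might combine into a LiteLLM-style model string as below; the helper name build_model_name is hypothetical and not part of this commit:

import os

def build_model_name() -> str:
    # The provider is the prefix before "/" in the model argument;
    # "custom" and "gpt4all" are handled by Danswer-specific classes instead.
    provider = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
    version = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
    if provider in ("custom", "gpt4all"):
        return version  # not routed through LiteLLM
    return f"{provider}/{version}"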
@@ -9,8 +9,6 @@ from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -132,7 +130,7 @@ def answer_qa_query(
        qa_model = get_default_qa_model(
            timeout=answer_generation_timeout, real_time_flow=real_time_flow
        )
    except (UnknownModelError, OpenAIKeyMissing) as e:
    except Exception as e:
        return QAResponse(
            answer=None,
            quotes=None,
@@ -1,13 +0,0 @@
class OpenAIKeyMissing(Exception):
    default_msg = (
        "Unable to find existing OpenAI Key. "
        'A new key can be added from "Keys" section of the Admin Panel'
    )

    def __init__(self, msg: str = default_msg) -> None:
        super().__init__(msg)


class UnknownModelError(Exception):
    def __init__(self, model_name: str) -> None:
        super().__init__(f"Unknown Internal QA model name: {model_name}")
@@ -1,209 +0,0 @@
from collections.abc import Callable
from typing import Any

from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()


class DummyGPT4All:
    """In the case of import failure due to M1 Mac incompatibility,
    so this module does not raise exceptions during server startup,
    as long as this module isn't actually used"""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise RuntimeError("GPT4All library not installed.")


try:
    from gpt4all import GPT4All # type:ignore
except ImportError:
    logger.warning(
        "GPT4All library not installed. "
        "If you wish to run GPT4ALL (in memory) to power Danswer's "
        "Generative AI features, please install gpt4all==1.0.5. "
        "As of Aug 2023, this library is not compatible with M1 Mac."
    )
    GPT4All = DummyGPT4All


GPT4ALL_MODEL: GPT4All | None = None


def get_gpt_4_all_model(
    model_version: str = GEN_AI_MODEL_VERSION,
) -> GPT4All:
    global GPT4ALL_MODEL
    if GPT4ALL_MODEL is None:
        GPT4ALL_MODEL = GPT4All(model_version)
    return GPT4ALL_MODEL


def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temp": 0,
        **kwargs,
    }


class GPT4AllCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False, # gpt4all models can't handle this atm
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    @property
    def requires_api_key(self) -> bool:
        return False

    def warm_up_model(self) -> None:
        get_gpt_4_all_model(self.model_version)

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
    ) -> AnswerQuestionReturn:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        model_output = gen_ai_model.generate(
            **_build_gpt4all_settings(
                prompt=filled_prompt, max_tokens=self.max_output_tokens
            ),
        )

        logger.debug(model_output)

        return process_answer(model_output, context_docs)

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        model_stream = gen_ai_model.generate(
            **_build_gpt4all_settings(
                prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True
            ),
        )

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class GPT4AllChatCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False, # gpt4all models can't handle this atm
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    @property
    def requires_api_key(self) -> bool:
        return False

    def warm_up_model(self) -> None:
        get_gpt_4_all_model(self.model_version)

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
    ) -> AnswerQuestionReturn:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        with gen_ai_model.chat_session():
            context_msgs = filled_prompt[:-1]
            user_query = filled_prompt[-1].get("content")
            for message in context_msgs:
                gen_ai_model.current_chat_session.append(message)

            model_output = gen_ai_model.generate(
                **_build_gpt4all_settings(
                    prompt=user_query, max_tokens=self.max_output_tokens
                ),
            )

            logger.debug(model_output)

            return process_answer(model_output, context_docs)

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        with gen_ai_model.chat_session():
            context_msgs = filled_prompt[:-1]
            user_query = filled_prompt[-1].get("content")
            for message in context_msgs:
                gen_ai_model.current_chat_session.append(message)

            model_stream = gen_ai_model.generate(
                **_build_gpt4all_settings(
                    prompt=user_query, max_tokens=self.max_output_tokens
                ),
            )

            yield from process_model_tokens(
                tokens=model_stream,
                context_docs=context_docs,
                is_json_prompt=self.prompt_processor.specifies_json_output,
            )
@@ -1,195 +0,0 @@
from collections.abc import Callable
from typing import Any

from huggingface_hub import InferenceClient # type:ignore
from huggingface_hub.utils import HfHubHTTPError # type:ignore

from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()


def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "do_sample": False,
        "seed": 69, # For reproducibility
        **kwargs,
    }


class HuggingFaceCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = FreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
    ) -> AnswerQuestionReturn:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        model_output = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
        )
        logger.debug(model_output)

        return process_answer(model_output, context_docs)

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        model_stream = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(
                max_new_tokens=self.max_output_tokens, stream=True
            ),
        )

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class HuggingFaceChatCompletionQA(QAModel):
    """Chat in this class refers to the HuggingFace Conversational API.
    Not to be confused with Chat/Instruction finetuned models.
    Llama2-chat... means it is an Instruction finetuned model, not necessarily that
    it supports Conversational API"""

    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @staticmethod
    def _convert_chat_to_hf_conversational_format(
        dialog: list[dict[str, str]]
    ) -> tuple[str, list[str], list[str]]:
        if dialog[-1]["role"] != "user":
            raise Exception(
                "Last message in conversational dialog must be User message"
            )
        user_message = dialog[-1]["content"]
        dialog = dialog[0:-1]
        generated_responses = []
        past_user_inputs = []
        for message in dialog:
            # HuggingFace inference client doesn't support system messages today
            # so lumping them in with user messages
            if message["role"] in ["user", "system"]:
                past_user_inputs += [message["content"]]
            else:
                generated_responses += [message["content"]]
        return user_message, generated_responses, past_user_inputs

    def _get_hf_model_output(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> str:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )

        (
            query,
            past_responses,
            past_inputs,
        ) = self._convert_chat_to_hf_conversational_format(filled_prompt)

        logger.debug(f"Last Input: {query}")
        logger.debug(f"Past Inputs: {past_inputs}")
        logger.debug(f"Past Responses: {past_responses}")
        try:
            model_output = self.client.conversational(
                query,
                generated_responses=past_responses,
                past_user_inputs=past_inputs,
                parameters={"max_length": self.max_output_tokens},
            )
        except HfHubHTTPError as model_error:
            if model_error.response.status_code == 422:
                raise ValueError(
                    "Selected HuggingFace Model does not support HuggingFace Conversational API,"
                    "try using the huggingface-inference-completion in Danswer instead"
                )
            raise
        logger.debug(model_output)

        return model_output

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
    ) -> AnswerQuestionReturn:
        model_output = self._get_hf_model_output(query, context_docs)

        answer, quotes_dict = process_answer(model_output, context_docs)

        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        """As of Aug 2023, HF conversational (chat) endpoints do not support streaming
        So here it is faked by streaming characters within Danswer from the model output
        """
        model_output = self._get_hf_model_output(query, context_docs)

        model_stream = simulate_streaming_response(model_output)

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
@@ -52,6 +52,7 @@ class QAModel:
    def requires_api_key(self) -> bool:
        """Is this model protected by security features
        Does it need an api key to access the model for inference"""
        # TODO, this should be false for custom request model and gpt4all
        return True

    def warm_up_model(self) -> None:
@@ -1,32 +1,11 @@
from typing import Any

import pkg_resources
from openai.error import AuthenticationError

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_block import QABlock
from danswer.direct_qa.qa_block import QAHandler
from danswer.direct_qa.qa_block import SimpleChatQAHandler
from danswer.direct_qa.qa_block import SingleMessageQAHandler
from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.build import get_default_llm
from danswer.utils.logger import setup_logger

@@ -52,93 +31,23 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
    return False


def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
    if model == DanswerGenAIModel.OPENAI_CHAT.value:
        return (
            SingleMessageQAHandler()
            if real_time_flow
            else SingleMessageScratchpadHandler()
        )

    return SimpleChatQAHandler()
# TODO introduce the prompt choice parameter
def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
    return (
        SingleMessageQAHandler() if real_time_flow else SingleMessageScratchpadHandler()
    )
    # return SimpleChatQAHandler()


def get_default_qa_model(
    internal_model: str = INTERNAL_MODEL_VERSION,
    endpoint: str | None = GEN_AI_ENDPOINT,
    model_host_type: str | None = GEN_AI_HOST_TYPE,
    api_key: str | None = GEN_AI_API_KEY,
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
    real_time_flow: bool = True,
    **kwargs: Any,
) -> QAModel:
    if not api_key:
        try:
            api_key = get_gen_ai_api_key()
        except ConfigNotFoundError:
            pass
    llm = get_default_llm(api_key=api_key, timeout=timeout)
    qa_handler = get_default_qa_handler(real_time_flow=real_time_flow)

    try:
        # un-used arguments will be ignored by the underlying `LLM` class
        # if any args are missing, a `TypeError` will be thrown
        llm = get_default_llm(timeout=timeout)
        qa_handler = get_default_qa_handler(
            model=internal_model, real_time_flow=real_time_flow
        )

        return QABlock(
            llm=llm,
            qa_handler=qa_handler,
        )
    except Exception:
        logger.exception(
            "Unable to build a QABlock with the new approach, going back to the "
            "legacy approach"
        )

    if internal_model in [
        DanswerGenAIModel.GPT4ALL.value,
        DanswerGenAIModel.GPT4ALL_CHAT.value,
    ]:
        # gpt4all is not compatible M1 Mac hardware as of Aug 2023
        pkg_resources.get_distribution("gpt4all")

    if internal_model == DanswerGenAIModel.OPENAI.value:
        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
        return GPT4AllCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
        return GPT4AllChatCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
        return TransformerQA()
    elif internal_model == DanswerGenAIModel.REQUEST.value:
        if endpoint is None or model_host_type is None:
            raise ValueError(
                "Request based GenAI model requires an endpoint and host type"
            )
        if (
            model_host_type == ModelHostType.HUGGINGFACE.value
            or model_host_type == ModelHostType.COLAB_DEMO.value
        ):
            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
            # This is the sum of input and output tokens, so cannot take in full Danswer context
            return RequestCompletionQA(
                endpoint=endpoint,
                model_host_type=model_host_type,
                api_key=api_key,
                prompt_processor=WeakModelFreeformProcessor(),
                timeout=timeout,
            )
        return RequestCompletionQA(
            endpoint=endpoint,
            model_host_type=model_host_type,
            api_key=api_key,
            timeout=timeout,
        )
    else:
        raise UnknownModelError(internal_model)
    return QABlock(
        llm=llm,
        qa_handler=qa_handler,
    )
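For context, a minimal usage sketch of the reworked entry point above; the query, timeout, and empty context list are illustrative values only, mirroring the call site in answer_qa_query() earlier in this diff:

from danswer.direct_qa.llm_utils import get_default_qa_model

# Builds a QABlock around the configured LLM and QA handler.
qa_model = get_default_qa_model(timeout=10, real_time_flow=True)
answer, quotes = qa_model.answer_question(query="What is Danswer?", context_docs=[])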
@@ -1,156 +0,0 @@
import re
from collections.abc import Callable

from transformers import pipeline # type:ignore
from transformers import QuestionAnsweringPipeline # type:ignore

from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()

TRANSFORMER_DEFAULT_MAX_CONTEXT = 512

_TRANSFORMER_MODEL: QuestionAnsweringPipeline | None = None


def get_default_transformer_model(
    model_version: str = GEN_AI_MODEL_VERSION,
    max_context: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
) -> QuestionAnsweringPipeline:
    global _TRANSFORMER_MODEL
    if _TRANSFORMER_MODEL is None:
        _TRANSFORMER_MODEL = pipeline(
            "question-answering", model=model_version, max_seq_len=max_context
        )

    return _TRANSFORMER_MODEL


def find_extended_answer(answer: str, context: str) -> str:
    """Try to extend the answer by matching across the context text and extending before
    and after the quote to some termination character"""
    result = re.search(
        r"(^|\n\r?|\.)(?P<content>[^\n]{{0,250}}{}[^\n]{{0,250}})(\.|$|\n\r?)".format(
            re.escape(answer)
        ),
        context,
        flags=re.MULTILINE | re.DOTALL,
    )
    if result:
        return result.group("content")

    return answer


class TransformerQA(QAModel):
    @staticmethod
    def _answer_one_chunk(
        query: str,
        chunk: InferenceChunk,
        max_context_len: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
        max_cutoff: float = 0.9,
        min_cutoff: float = 0.5,
    ) -> tuple[str | None, DanswerQuote | None]:
        """Because this type of QA model only takes 1 chunk of context with a fairly small token limit
        We have to iterate the checks and check if the answer is found in any of the chunks.
        This type of approach does not allow for interpolating answers across chunks
        """
        model = get_default_transformer_model()
        model_out = model(question=query, context=chunk.content, max_answer_len=128)

        answer = model_out.get("answer")
        confidence = model_out.get("score")

        if answer is None:
            return None, None

        logger.info(f"Transformer Answer: {answer}")
        logger.debug(f"Transformer Confidence: {confidence}")

        # Model tends to be overconfident on short chunks
        # so min score required increases as chunk size decreases
        # If it's at least 0.9, then it's good enough regardless
        # Default minimum of 0.5 required
        score_cutoff = max(
            min(max_cutoff, 1 - len(chunk.content) / max_context_len), min_cutoff
        )
        if confidence < score_cutoff:
            return None, None

        extended_answer = find_extended_answer(answer, chunk.content)

        danswer_quote = DanswerQuote(
            quote=answer,
            document_id=chunk.document_id,
            link=chunk.source_links[0] if chunk.source_links else None,
            source_type=chunk.source_type,
            semantic_identifier=chunk.semantic_identifier,
            blurb=chunk.blurb,
        )

        return extended_answer, danswer_quote

    def warm_up_model(self) -> None:
        get_default_transformer_model()

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
    ) -> AnswerQuestionReturn:
        danswer_quotes: list[DanswerQuote] = []
        d_answers: list[str] = []
        for chunk in context_docs:
            answer, quote = self._answer_one_chunk(query, chunk)
            if answer is not None and quote is not None:
                d_answers.append(answer)
                danswer_quotes.append(quote)

        answers_list = [
            f"Answer {ind}: {answer.strip()}"
            for ind, answer in enumerate(d_answers, start=1)
        ]
        combined_answer = "\n".join(answers_list)
        return DanswerAnswer(answer=combined_answer), DanswerQuotes(
            quotes=danswer_quotes
        )

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        quotes: list[DanswerQuote] = []
        answers: list[str] = []
        for chunk in context_docs:
            answer, quote = self._answer_one_chunk(query, chunk)
            if answer is not None and quote is not None:
                answers.append(answer)
                quotes.append(quote)

        # Delay the output of the answers so there isn't long gap between first answer and quotes
        answer_count = 1
        for answer in answers:
            if answer_count == 1:
                yield DanswerAnswerPiece(answer_piece="Source 1: ")
            else:
                yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
            answer_count += 1
            for char in answer.strip():
                yield DanswerAnswerPiece(answer_piece=char)

        # signal end of answer
        yield DanswerAnswerPiece(answer_piece=None)

        yield DanswerQuotes(quotes=quotes)
@@ -1,209 +0,0 @@
from abc import ABC
from collections.abc import Callable
from collections.abc import Generator
from copy import copy
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar

import openai
import tiktoken
from openai.error import AuthenticationError
from openai.error import Timeout

from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time


logger = setup_logger()

F = TypeVar("F", bound=Callable)


if API_BASE_OPENAI:
    openai.api_base = API_BASE_OPENAI
if API_TYPE_OPENAI in ["azure"]: # TODO: Azure AD support ["azure_ad", "azuread"]
    openai.api_type = API_TYPE_OPENAI
    openai.api_version = API_VERSION_OPENAI


def _ensure_openai_api_key(api_key: str | None) -> str:
    final_api_key = api_key or get_gen_ai_api_key()
    if final_api_key is None:
        raise OpenAIKeyMissing()

    return final_api_key


def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        **({"deployment_id": AZURE_DEPLOYMENT_ID} if AZURE_DEPLOYMENT_ID else {}),
        **kwargs,
    }


def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
    @wraps(openai_call)
    def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
        try:
            # if streamed, the call returns a generator
            if kwargs.get("stream"):

                def _generator() -> Generator[Any, None, None]:
                    yield from openai_call(*args, **kwargs)

                return _generator()
            return openai_call(*args, **kwargs)
        except AuthenticationError:
            logger.exception("Failed to authenticate with OpenAI API")
            raise
        except Timeout:
            logger.exception("OpenAI API timed out for query: %s", query)
            raise
        except Exception:
            logger.exception("Unexpected error with OpenAI API for query: %s", query)
            raise

    return cast(F, wrapped_call)


def _tiktoken_trim_chunks(
    chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512
) -> list[InferenceChunk]:
    """Edit chunks that have too high token count. Generally due to parsing issues or
    characters from another language that are 1 char = 1 token
    Trimming by tokens leads to information loss but currently no better way of handling
    """
    encoder = tiktoken.encoding_for_model(model_version)
    new_chunks = copy(chunks)
    for ind, chunk in enumerate(new_chunks):
        tokens = encoder.encode(chunk.content)
        if len(tokens) > max_chunk_toks:
            new_chunk = copy(chunk)
            new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
            new_chunks[ind] = new_chunk
    return new_chunks


# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel, ABC):
    pass


class OpenAICompletionQA(OpenAIQAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        api_key: str | None = None,
        timeout: int | None = None,
        include_metadata: bool = INCLUDE_METADATA,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.timeout = timeout
        self.include_metadata = include_metadata
        try:
            self.api_key = api_key or get_gen_ai_api_key()
        except ConfigNotFoundError:
            raise OpenAIKeyMissing()

    @staticmethod
    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
        for event in response:
            yield event["choices"][0]["text"]

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
    ) -> AnswerQuestionReturn:
        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=_ensure_openai_api_key(self.api_key),
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
            ),
        )
        model_output = cast(str, response["choices"][0]["text"]).strip()
        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
        logger.debug(model_output)

        return process_answer(model_output, context_docs)

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=_ensure_openai_api_key(self.api_key),
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )

        tokens = self._generate_tokens_from_response(response)

        yield from process_model_tokens(
            tokens=tokens,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
@@ -28,7 +28,7 @@ from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.indexing.models import InferenceChunk
from danswer.llm.llm import LLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import get_default_llm_tokenizer
@@ -1,272 +0,0 @@
import abc
import json
from collections.abc import Callable
from collections.abc import Generator

import requests
from requests.exceptions import Timeout
from requests.models import Response

from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.indexing.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time


logger = setup_logger()


class HostSpecificRequestModel(abc.ABC):
    """Provides a more minimal implementation requirement for extending to new models
    hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
    """

    @property
    def requires_api_key(self) -> bool:
        """Is this model protected by security features
        Does it need an api key to access the model for inference"""
        return True

    @staticmethod
    @abc.abstractmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        """Given a filled out prompt, how to send it to the model API with the
        correct request format with the correct parameters"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        """Extract the full model output text from a response.
        This is for nonstreaming endpoints"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        """Generate tokens from a streaming response
        This is for streaming endpoints"""
        raise NotImplementedError


class HuggingFaceRequestModel(HostSpecificRequestModel):
    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool, # Not supported by Inference Endpoints (as of Aug 2023)
        timeout: int | None,
    ) -> Response:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "inputs": filled_prompt,
            "parameters": {
                # HuggingFace requires this to be strictly positive from 0.0-100.0 noninclusive
                "temperature": 0.01,
                # Skip the long tail
                "top_p": 0.9,
                "max_new_tokens": max_output_tokens,
            },
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except Timeout as error:
            raise Timeout(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _hf_extract_model_output(
        response: Response,
    ) -> str:
        if response.status_code != 200:
            response.raise_for_status()

        return json.loads(response.content)[0].get("generated_text", "")

    @staticmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        return HuggingFaceRequestModel._hf_extract_model_output(response)

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        """HF endpoints do not do streaming currently so this function will
        simulate streaming for the meantime but will need to be replaced in
        the future once streaming is enabled."""
        model_out = HuggingFaceRequestModel._hf_extract_model_output(response)
        yield from simulate_streaming_response(model_out)


class ColabDemoRequestModel(HostSpecificRequestModel):
    """Guide found at:
    https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
    """

    @property
    def requires_api_key(self) -> bool:
        return False

    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None, # ngrok basic setup doesn't require this
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        headers = {
            "Content-Type": "application/json",
        }

        data = {
            "inputs": filled_prompt,
            "parameters": {
                "temperature": 0.0,
                "max_tokens": max_output_tokens,
            },
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except Timeout as error:
            raise Timeout(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _colab_demo_extract_model_output(
        response: Response,
    ) -> str:
        if response.status_code != 200:
            response.raise_for_status()

        return json.loads(response.content).get("generated_text", "")

    @staticmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        return ColabDemoRequestModel._colab_demo_extract_model_output(response)

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
        yield from simulate_streaming_response(model_out)


def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
    if model_host_type == ModelHostType.HUGGINGFACE.value:
        return HuggingFaceRequestModel()
    if model_host_type == ModelHostType.COLAB_DEMO.value:
        return ColabDemoRequestModel()
    else:
        # TODO support Azure, GCP, AWS
        raise ValueError("Invalid model hosting service selected")


class RequestCompletionQA(QAModel):
    def __init__(
        self,
        endpoint: str = GEN_AI_ENDPOINT,
        model_host_type: str = GEN_AI_HOST_TYPE,
        api_key: str | None = GEN_AI_API_KEY,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        timeout: int | None = None,
    ) -> None:
        self.endpoint = endpoint
        self.api_key = api_key
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.model_class = get_host_specific_model_class(model_host_type)
        self.timeout = timeout

    @property
    def requires_api_key(self) -> bool:
        return self.model_class.requires_api_key

    def _get_request_response(
        self, query: str, context_docs: list[InferenceChunk], stream: bool
    ) -> Response:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, include_metadata=False
        )
        logger.debug(filled_prompt)

        return self.model_class.send_model_request(
            filled_prompt,
            self.endpoint,
            self.api_key,
            self.max_output_tokens,
            stream,
            self.timeout,
        )

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
    ) -> AnswerQuestionReturn:
        model_api_response = self._get_request_response(
            query, context_docs, stream=False
        )

        model_output = self.model_class.extract_model_output_from_response(
            model_api_response
        )
        logger.debug(model_output)

        return process_answer(model_output, context_docs)

    def answer_question_stream(
        self,
        query: str,
        context_docs: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        model_api_response = self._get_request_response(
            query, context_docs, stream=False
        )

        token_generator = self.model_class.generate_model_tokens_from_response(
            model_api_response
        )

        yield from process_model_tokens(
            tokens=token_generator,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
0  backend/danswer/llm/__init__.py  Normal file
@@ -1,46 +0,0 @@
from typing import Any

from langchain.chat_models.azure_openai import AzureChatOpenAI

from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


class AzureGPT(LangChainChatLLM):
    def __init__(
        self,
        api_key: str,
        max_output_tokens: int,
        timeout: int,
        model_version: str,
        api_base: str = API_BASE_OPENAI,
        api_version: str = API_VERSION_OPENAI,
        deployment_name: str = AZURE_DEPLOYMENT_ID,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
        self._llm = AzureChatOpenAI(
            model=model_version,
            openai_api_type="azure",
            openai_api_base=api_base,
            openai_api_version=api_version,
            deployment_name=deployment_name,
            openai_api_key=api_key,
            max_tokens=max_output_tokens,
            temperature=0,
            request_timeout=timeout,
            model_kwargs={
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
            },
            verbose=should_be_verbose(),
            max_retries=0, # retries are handled outside of langchain
        )

    @property
    def llm(self) -> AzureChatOpenAI:
        return self._llm
@@ -1,48 +1,25 @@
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.llm.azure import AzureGPT
from danswer.llm.google_colab_demo import GoogleColabDemo
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.multi_llm import DefaultMultiLLM


def get_default_llm(
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
) -> LLM:
    """NOTE: api_key/timeout must be special args since we may want to check
    if an API key is valid for the default model setup OR we may want to use the
    default model with a different timeout specified."""
    """A single place to fetch the configured LLM for Danswer
    Also allows overriding certain LLM defaults"""
    if api_key is None:
        api_key = get_gen_ai_api_key()

    model_args = {
        # provide a dummy key since LangChain will throw an exception if not
        # given, which would prevent server startup
        "api_key": api_key or "dummy_api_key",
        "timeout": timeout,
        "model_version": GEN_AI_MODEL_VERSION,
        "endpoint": GEN_AI_ENDPOINT,
        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
        "temperature": GEN_AI_TEMPERATURE,
    }
    if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
        if API_TYPE_OPENAI == "azure":
            return AzureGPT(**model_args)  # type: ignore
        return OpenAIGPT(**model_args)  # type: ignore
    if (
        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
    ):
        return GoogleColabDemo(**model_args)  # type: ignore
    if GEN_AI_MODEL_PROVIDER.lower() == "custom":
        return CustomModelServer(api_key=api_key, timeout=timeout)

    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
    if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
        return DanswerGPT4All(timeout=timeout)

    return DefaultMultiLLM(api_key=api_key, timeout=timeout)
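The hunk above replaces the old enum-driven dispatch with a single provider string check. A minimal usage sketch follows; it assumes the function is importable as danswer.llm.build.get_default_llm and that the returned object exposes the same invoke/stream surface shown for the GPT4All wrapper later in this commit, which is not guaranteed by this hunk alone.

# Illustrative sketch, not part of the diff. Assumes GEN_AI_MODEL_PROVIDER /
# GEN_AI_MODEL_VERSION are already configured via environment variables.
from danswer.llm.build import get_default_llm  # assumed module path

llm = get_default_llm(timeout=30)  # api_key falls back to the stored key
print(llm.invoke("In one sentence, what does Danswer do?"))

# streaming variant, token by token
for token in llm.stream("List three use cases for internal document search."):
    print(token, end="", flush=True)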
@@ -1,24 +1,36 @@
import json
from collections.abc import Iterator
from typing import Any

import requests
from langchain.schema.language_model import LanguageModelInput
from requests import Timeout

from danswer.llm.llm import LLM
from danswer.llm.utils import convert_input
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string


class GoogleColabDemo(LLM):
class CustomModelServer(LLM):
    """This class is to provide an example for how to use Danswer
    with any LLM, even servers with custom API definitions.
    To use with your own model server, simply implement the functions
    below to fit your model server expectation"""

    def __init__(
        self,
        endpoint: str,
        max_output_tokens: int,
        # Not used here but you probably want a model server that isn't completely open
        api_key: str | None,
        timeout: int,
        *args: list[Any],
        **kwargs: dict[str, Any],
        endpoint: str | None = GEN_AI_API_ENDPOINT,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
    ):
        if not endpoint:
            raise ValueError(
                "Cannot point Danswer to a custom LLM server without providing the "
                "endpoint for the model server."
            )

        self._endpoint = endpoint
        self._max_output_tokens = max_output_tokens
        self._timeout = timeout
@@ -29,7 +41,7 @@ class GoogleColabDemo(LLM):
        }

        data = {
            "inputs": convert_input(input),
            "inputs": convert_lm_input_to_basic_string(input),
            "parameters": {
                "temperature": 0.0,
                "max_tokens": self._max_output_tokens,
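For orientation, here is a standalone sketch of the request payload the visible hunk appears to build for a custom model server. The endpoint URL is hypothetical, and the response handling is not shown in this hunk, so the reply schema below is only an assumption.

# Illustrative sketch, not part of the diff.
import requests

endpoint = "http://localhost:8000/generate"  # hypothetical model server URL
data = {
    "inputs": "Why did the chicken cross the road?",
    "parameters": {
        "temperature": 0.0,
        "max_tokens": 512,
    },
}
response = requests.post(endpoint, json=data, timeout=60)
response.raise_for_status()
print(response.json())  # the response schema depends on your own model server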
60  backend/danswer/llm/gpt_4_all.py  Normal file
@@ -0,0 +1,60 @@
from collections.abc import Iterator
from typing import Any

from langchain.schema.language_model import LanguageModelInput

from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger


logger = setup_logger()


class DummyGPT4All:
    """In the case of import failure due to architectural incompatibilities,
    this module does not raise exceptions during server startup,
    as long as the module isn't actually used"""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise RuntimeError("GPT4All library not installed.")


try:
    from gpt4all import GPT4All  # type:ignore
except ImportError:
    # Setting a low log level because users get scared when they see this
    logger.debug(
        "GPT4All library not installed. "
        "If you wish to run GPT4ALL (in memory) to power Danswer's "
        "Generative AI features, please install gpt4all==2.0.2."
    )
    GPT4All = DummyGPT4All


class DanswerGPT4All(LLM):
    """Option to run an LLM locally, however this is significantly slower and
    answers tend to be much worse"""

    def __init__(
        self,
        timeout: int,
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        temperature: float = GEN_AI_TEMPERATURE,
    ):
        self.timeout = timeout
        self.max_output_tokens = max_output_tokens
        self.temperature = temperature
        self.gpt4all_model = GPT4All(model_version)

    def invoke(self, prompt: LanguageModelInput) -> str:
        prompt_basic = convert_lm_input_to_basic_string(prompt)
        return self.gpt4all_model.generate(prompt_basic)

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        prompt_basic = convert_lm_input_to_basic_string(prompt)
        return self.gpt4all_model.generate(prompt_basic, streaming=True)
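The DummyGPT4All shim is a general optional-dependency pattern: the ImportError is swallowed at startup and only surfaces if the class is actually instantiated. A generic sketch of the same idea, using hypothetical names rather than the gpt4all package:

# Generic sketch of the deferred-failure pattern; "somepackage" is hypothetical.
from typing import Any


class _MissingDependency:
    """Stands in for an optional package; only fails when actually used."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise RuntimeError("Install the optional 'somepackage' extra to use this feature.")


try:
    from somepackage import SomeClient  # type: ignore  # hypothetical optional import
except ImportError:
    SomeClient = _MissingDependency  # server still starts; failure is deferred


def use_client() -> None:
    client = SomeClient()  # raises only if somepackage is missing AND this code runs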
74  backend/danswer/llm/multi_llm.py  Normal file
@@ -0,0 +1,74 @@
import litellm  # type:ignore
from langchain.chat_models import ChatLiteLLM

from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False


def _get_model_str(
    model_provider: str | None,
    model_version: str | None,
) -> str:
    if model_provider and model_version:
        return model_provider + "/" + model_version

    if model_version:
        # Litellm defaults to openai if no provider specified
        # It's implicit so no need to specify here either
        return model_version

    # User specified something wrong, just use Danswer default
    return GEN_AI_MODEL_VERSION


class DefaultMultiLLM(LangChainChatLLM):
    """Uses Litellm library to allow easy configuration to use a multitude of LLMs
    See https://python.langchain.com/docs/integrations/chat/litellm"""

    DEFAULT_MODEL_PARAMS = {
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }

    def __init__(
        self,
        api_key: str | None,
        timeout: int,
        model_provider: str | None = GEN_AI_MODEL_PROVIDER,
        model_version: str | None = GEN_AI_MODEL_VERSION,
        api_base: str | None = GEN_AI_API_ENDPOINT,
        api_version: str | None = GEN_AI_API_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        temperature: float = GEN_AI_TEMPERATURE,
    ):
        # Litellm Langchain integration currently doesn't take in the api key param
        # Can place this in the call below once integration is in
        litellm.api_key = api_key
        litellm.api_version = api_version

        self._llm = ChatLiteLLM(  # type: ignore
            model=_get_model_str(model_provider, model_version),
            api_base=api_base,
            max_tokens=max_output_tokens,
            temperature=temperature,
            request_timeout=timeout,
            model_kwargs=DefaultMultiLLM.DEFAULT_MODEL_PARAMS,
            verbose=should_be_verbose(),
            max_retries=0,  # retries are handled outside of langchain
        )

    @property
    def llm(self) -> ChatLiteLLM:
        return self._llm
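_get_model_str above follows LiteLLM's "provider/model" naming convention. A small self-contained sketch of that mapping; the fallback default here is a stand-in for the configured GEN_AI_MODEL_VERSION, not a value taken from the diff:

# Illustrative sketch of the provider/model string mapping, not part of the diff.
def model_str(provider: str | None, version: str | None, default: str = "gpt-3.5-turbo") -> str:
    if provider and version:
        return f"{provider}/{version}"  # e.g. "azure/gpt-35-turbo"
    if version:
        return version  # a bare version implies OpenAI for LiteLLM
    return default  # stand-in for the configured default model version


assert model_str("azure", "gpt-35-turbo") == "azure/gpt-35-turbo"
assert model_str(None, "gpt-4") == "gpt-4"
assert model_str(None, None) == "gpt-3.5-turbo"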
@@ -1,50 +0,0 @@
from typing import Any
from typing import cast

from langchain.chat_models import ChatLiteLLM
import litellm  # type:ignore

from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True


class OpenAIGPT(LangChainChatLLM):

    DEFAULT_MODEL_PARAMS = {
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }

    def __init__(
        self,
        api_key: str,
        max_output_tokens: int,
        timeout: int,
        model_version: str,
        temperature: float = GEN_AI_TEMPERATURE,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
        litellm.api_key = api_key

        self._llm = ChatLiteLLM(  # type: ignore
            model=model_version,
            # Prefer using None which is the default value, endpoint could be empty string
            api_base=cast(str, kwargs.get("endpoint")) or None,
            max_tokens=max_output_tokens,
            temperature=temperature,
            request_timeout=timeout,
            model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
            verbose=should_be_verbose(),
            max_retries=0,  # retries are handled outside of langchain
        )

    @property
    def llm(self) -> ChatLiteLLM:
        return self._llm
@@ -22,6 +22,7 @@ _LLM_TOKENIZER: Callable[[str], Any] | None = None


def get_default_llm_tokenizer() -> Callable:
    """Currently only supports the OpenAI default tokenizer: tiktoken"""
    global _LLM_TOKENIZER
    if _LLM_TOKENIZER is None:
        _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode
@@ -71,14 +72,7 @@ def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
    return [HumanMessage(content=message)]


def message_generator_to_string_generator(
    messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
    for message in messages:
        yield message.content


def convert_input(lm_input: LanguageModelInput) -> str:
def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str:
    """Heavily inspired by:
    https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
    """
@@ -99,6 +93,13 @@ def convert_input(lm_input: LanguageModelInput) -> str:
    return prompt_value.to_string()


def message_generator_to_string_generator(
    messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
    for message in messages:
        yield message.content


def should_be_verbose() -> bool:
    return LOG_LEVEL == "debug"
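get_default_llm_tokenizer caches a tiktoken encoder for token counting. A short usage sketch follows; the 4096-token budget is only an example figure, not a limit taken from the diff:

# Illustrative sketch, not part of the diff: counting tokens with the same encoder.
import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base").encode  # same encoder the helper caches

prompt = "Summarize the onboarding docs for a new engineer."
token_count = len(tokenizer(prompt))
print(f"{token_count} tokens")

MAX_CONTEXT_TOKENS = 4096  # example budget only
if token_count > MAX_CONTEXT_TOKENS:
    raise ValueError("Prompt exceeds the assumed context window")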
@@ -22,13 +22,12 @@ from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -152,14 +151,6 @@ def get_application() -> FastAPI:
            warm_up_models,
        )

        if DISABLE_GENERATIVE_AI:
            logger.info("Generative AI Q&A disabled")
        else:
            logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
            logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
            if API_TYPE_OPENAI == "azure":
                logger.info(f"Using Azure OpenAI with Endpoint: {API_BASE_OPENAI}")

        verify_auth = fetch_versioned_implementation(
            "danswer.auth.users", "verify_auth_setting"
        )
@@ -169,6 +160,14 @@ def get_application() -> FastAPI:
        if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
            logger.info("Both OAuth Client ID and Secret are configured.")

        if DISABLE_GENERATIVE_AI:
            logger.info("Generative AI Q&A disabled")
        else:
            logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
            logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
            if GEN_AI_API_ENDPOINT:
                logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

        if SKIP_RERANKING:
            logger.info("Reranking step of search flow is disabled")
@@ -25,7 +25,7 @@ from danswer.db.feedback import update_document_hidden
from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
@@ -19,8 +19,6 @@ from danswer.db.feedback import update_query_event_feedback
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -302,7 +300,7 @@ def stream_direct_qa(

    try:
        qa_model = get_default_qa_model()
    except (UnknownModelError, OpenAIKeyMissing) as e:
    except Exception as e:
        logger.exception("Unable to get QA model")
        error = StreamingError(error=str(e))
        yield get_json_line(error.dict())
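The hunk above broadens the exception handling so any failure to build the QA model is streamed back to the client rather than crashing the generator. A generic sketch of that pattern; get_json_line's actual implementation is not shown in this hunk, so the plain json.dumps-plus-newline version here is an assumption:

# Illustrative sketch, not part of the diff.
import json
from collections.abc import Callable, Iterator


def get_json_line(payload: dict) -> str:
    # assumed shape: one JSON object per newline-delimited chunk
    return json.dumps(payload) + "\n"


def stream_answer(build_model: Callable[[], object]) -> Iterator[str]:
    try:
        model = build_model()
    except Exception as e:
        # Surface the failure to the client as a structured event, then stop.
        yield get_json_line({"error": str(e)})
        return
    yield get_json_line({"status": "model ready", "model": repr(model)})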
@@ -13,9 +13,9 @@ filelock==3.12.0
google-api-python-client==2.86.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
# GPT4All library does not support M1 Mac architecture
# GPT4All library has issues running on Macs and python:3.11.4-slim-bookworm
# will reintroduce this when library version catches up
# gpt4all==1.0.5
# gpt4all==2.0.2
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2