Support for Request accessed GenAI Models (#270)

This commit is contained in:
Yuhong Sun
2023-08-06 18:31:47 -07:00
committed by GitHub
parent 0e667d3384
commit 3bfc72484d
19 changed files with 613 additions and 351 deletions

View File

@@ -138,12 +138,6 @@ CHUNK_WORD_OVERLAP = 5
CHUNK_MAX_CHAR_OVERLAP = 50
#####
# Other API Keys
#####
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
#####
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
#####

View File

@@ -12,7 +12,7 @@ SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
METADATA = "metadata"
OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
@@ -30,3 +30,26 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
NOTION = "notion"
class DanswerGenAIModel(str, Enum):
"""This represents the internal Danswer GenAI model which determines the class that is used
to generate responses to the user query. Different models/services require different internal
handling; this allows for modularity of implementation within Danswer"""
OPENAI = "openai-completion"
OPENAI_CHAT = "openai-chat-completion"
GPT4ALL = "gpt4all-completion"
GPT4ALL_CHAT = "gpt4all-chat-completion"
HUGGINGFACE = "huggingface-inference-completion"
HUGGINGFACE_CHAT = "huggingface-inference-chat-completion"
REQUEST = "request-completion"
class ModelHostType(str, Enum):
"""For GenAI models interfaced via requests, different services have different
expectations for what fields are included in the request"""
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
HUGGINGFACE = "huggingface" # HuggingFace test-generation Inference API
# TODO support for Azure, AWS, GCP GenAI model hosting
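A minimal sketch of how these str-backed enums behave when constructed from configuration strings (the values below are illustrative, not a recommendation):
from danswer.configs.constants import DanswerGenAIModel, ModelHostType
internal_model = DanswerGenAIModel("request-completion")  # -> DanswerGenAIModel.REQUEST
host_type = ModelHostType("huggingface")  # -> ModelHostType.HUGGINGFACE
# Because the enums subclass str, members compare equal to their raw values
assert internal_model == "request-completion"
# Unknown strings fail fast with a ValueError
try:
    DanswerGenAIModel("not-a-model")
except ValueError:
    pass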

View File

@@ -1,4 +1,8 @@
import os
from enum import Enum
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
@@ -30,35 +34,46 @@ CROSS_EMBED_CONTEXT_SIZE = 512
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# QA Model API Configs
# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
# Valid list:
# - openai-completion
# - openai-chat-completion
# - gpt4all-completion -> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
# - gpt4all-chat-completion-> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
# To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
# These support HuggingFace Inference API, Inference Endpoints and servers running the text-generation-inference backend
# - huggingface-inference-completion
# - huggingface-inference-chat-completion
#####
# Generative AI Model Configs
#####
# Other models should work as well, check the library/API compatibility.
# But these are the models that have been verified to work with the existing prompts.
# Using a different model may require some prompt tuning. See qa_prompts.py
VERIFIED_MODELS = {
DanswerGenAIModel.OPENAI: ["text-davinci-003"],
DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
# The "chat" model below is actually "instruction finetuned" and does not support conversational
DanswerGenAIModel.HUGGINGFACE: ["meta-llama/Llama-2-70b-chat-hf"],
DanswerGenAIModel.HUGGINGFACE_CHAT: ["meta-llama/Llama-2-70b-hf"],
}
# Sets the internal Danswer model class to use
INTERNAL_MODEL_VERSION = os.environ.get(
"INTERNAL_MODEL_VERSION", "openai-chat-completion"
"INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
)
# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
# API key for the Generative AI model, if one is required for access; otherwise leave blank
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", "")
# If using GPT4All or OpenAI, specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(
"GEN_AI_MODEL_VERSION",
VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
)
# If the Generative AI model is accessed via requests (DanswerGenAIModel.REQUEST) then
# set the two values below to specify
# - where to hit the endpoint
# - how the request should be formed
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)
# Set this to be enough for an answer + quotes
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
# Use HuggingFace API Token for Huggingface inference client
GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
# Use the conversational API with the huggingface-inference-chat-completion internal model
# Note - this only works with models that support conversational interfaces
GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
)
# Disable streaming responses. Set this to true to "polyfill" streaming for models that don't support streaming
GEN_AI_HUGGINGFACE_DISABLE_STREAM = (
os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true"
)
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
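A small sketch of the GEN_AI_MODEL_VERSION fallback above: when the env var is unset, the first verified model for the selected internal model class is used (the dict below copies a subset of VERIFIED_MODELS purely for illustration):
from danswer.configs.constants import DanswerGenAIModel
verified = {DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"]}
internal_model_version = "openai-chat-completion"
default_model = verified.get(DanswerGenAIModel(internal_model_version), [""])[0]
assert default_model == "gpt-3.5-turbo"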

View File

@@ -1,25 +1,30 @@
from typing import Any
import pkg_resources
from openai.error import AuthenticationError
from openai.error import Timeout
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import (
GEN_AI_HUGGINGFACE_API_TOKEN,
INTERNAL_MODEL_VERSION,
)
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.huggingface_inference import (
HuggingFaceInferenceChatCompletionQA,
HuggingFaceInferenceCompletionQA,
)
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
# Imports commented out temporarily due to incompatibility of gpt4all with M1 Mac hardware currently
# from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
# from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
logger = setup_logger()
def check_model_api_key_is_valid(model_api_key: str) -> bool:
@@ -35,32 +40,66 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
return True
except AuthenticationError:
return False
except Timeout:
pass
except Exception as e:
logger.warning(f"GenAI API key failed for the following reason: {e}")
return False
def get_default_backend_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
api_key: str | None = None,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
**kwargs: Any
**kwargs: Any,
) -> QAModel:
if internal_model == "openai-completion":
if not api_key:
try:
api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
pass
if internal_model in [
DanswerGenAIModel.GPT4ALL.value,
DanswerGenAIModel.GPT4ALL_CHAT.value,
]:
# gpt4all is not compatible with M1 Mac hardware as of Aug 2023
pkg_resources.get_distribution("gpt4all")
if internal_model == DanswerGenAIModel.OPENAI.value:
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == "openai-chat-completion":
elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == "huggingface-inference-completion":
api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == "huggingface-inference-chat-completion":
api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs)
# Note GPT4All is not supported for M1 Mac machines currently, removing until support is added
# elif internal_model == "gpt4all-completion":
# return GPT4AllCompletionQA(**kwargs)
# elif internal_model == "gpt4all-chat-completion":
# return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
return GPT4AllCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.REQUEST.value:
if endpoint is None or model_host_type is None:
raise ValueError(
"Request based GenAI model requires an endpoint and host type"
)
if model_host_type == ModelHostType.HUGGINGFACE.value:
# Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
# With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
# This is the sum of input and output tokens, so cannot take in full Danswer context
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
prompt_processor=WeakModelFreeformProcessor(),
timeout=timeout,
)
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
timeout=timeout,
)
else:
raise UnknownModelError(internal_model)
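A hedged usage sketch of the new REQUEST path through this factory; the endpoint URL and token below are placeholders, not real values:
from danswer.direct_qa import get_default_backend_qa_model
qa_model = get_default_backend_qa_model(
    internal_model="request-completion",
    endpoint="https://my-llama2-endpoint.example.com",  # placeholder endpoint
    model_host_type="huggingface",
    api_key="hf_xxx",  # placeholder token
)
answer, quotes = qa_model.answer_question(query="What is Danswer?", context_docs=[])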

View File

@@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any
from gpt4all import GPT4All # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -18,9 +16,30 @@ from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
class DummyGPT4All:
"""In the case of import failure due to M1 Mac incompatibility,
so this module does not raise exceptions during server startup,
as long as this module isn't actually used"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("GPT4All library not installed.")
try:
from gpt4all import GPT4All # type:ignore
except ImportError:
logger.warning(
"GPT4All library not installed. "
"If you wish to run GPT4ALL (in memory) to power Danswer's "
"Generative AI features, please install gpt4all==1.0.5. "
"As of Aug 2023, this library is not compatible with M1 Mac."
)
GPT4All = DummyGPT4All
GPT4ALL_MODEL: GPT4All | None = None
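With the fallback above, the module imports cleanly even when gpt4all is missing; only code paths that actually construct the underlying model fail. A rough sketch (behavior assumed from DummyGPT4All raising on construction):
qa = GPT4AllCompletionQA()  # fine: nothing is loaded yet
qa.warm_up_model()  # raises RuntimeError("GPT4All library not installed.") if gpt4all is absent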
@@ -56,6 +75,10 @@ class GPT4AllCompletionQA(QAModel):
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@property
def requires_api_key(self) -> bool:
return False
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@@ -117,6 +140,13 @@ class GPT4AllChatCompletionQA(QAModel):
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@property
def requires_api_key(self) -> bool:
return False
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]

View File

@@ -0,0 +1,189 @@
from collections.abc import Generator
from typing import Any
from huggingface_hub import InferenceClient # type:ignore
from huggingface_hub.utils import HfHubHTTPError # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"do_sample": False,
"seed": 69, # For reproducibility
**kwargs,
}
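# For illustration, caller overrides are merged on top of the deterministic defaults:
#   _build_hf_inference_settings(max_new_tokens=512, stream=True)
#   -> {"do_sample": False, "seed": 69, "max_new_tokens": 512, "stream": True}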
class HuggingFaceCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = FreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_stream = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class HuggingFaceChatCompletionQA(QAModel):
"""Chat in this class refers to the HuggingFace Conversational API.
Not to be confused with Chat/Instruction finetuned models.
A name like Llama2-chat means the model is instruction finetuned, not necessarily that
it supports the Conversational API"""
def __init__(
self,
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@staticmethod
def _convert_chat_to_hf_conversational_format(
dialog: list[dict[str, str]]
) -> tuple[str, list[str], list[str]]:
if dialog[-1]["role"] != "user":
raise Exception(
"Last message in conversational dialog must be User message"
)
user_message = dialog[-1]["content"]
dialog = dialog[0:-1]
generated_responses = []
past_user_inputs = []
for message in dialog:
# HuggingFace inference client doesn't support system messages today
# so lumping them in with user messages
if message["role"] in ["user", "system"]:
past_user_inputs += [message["content"]]
else:
generated_responses += [message["content"]]
return user_message, generated_responses, past_user_inputs
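# Worked example of the conversion above (hypothetical dialog):
#   dialog = [
#       {"role": "system", "content": "Answer from the provided docs."},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello! How can I help?"},
#       {"role": "user", "content": "What is Danswer?"},
#   ]
#   -> ("What is Danswer?",
#       ["Hello! How can I help?"],
#       ["Answer from the provided docs.", "Hi"])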
def _get_hf_model_output(
self, query: str, context_docs: list[InferenceChunk]
) -> str:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
(
query,
past_responses,
past_inputs,
) = self._convert_chat_to_hf_conversational_format(filled_prompt)
logger.debug(f"Last Input: {query}")
logger.debug(f"Past Inputs: {past_inputs}")
logger.debug(f"Past Responses: {past_responses}")
try:
model_output = self.client.conversational(
query,
generated_responses=past_responses,
past_user_inputs=past_inputs,
parameters={"max_length": self.max_output_tokens},
)
except HfHubHTTPError as model_error:
if model_error.response.status_code == 422:
raise ValueError(
"Selected HuggingFace Model does not support HuggingFace Conversational API,"
"try using the huggingface-inference-completion in Danswer instead"
)
raise
logger.debug(model_output)
return model_output
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
model_output = self._get_hf_model_output(query, context_docs)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
"""As of Aug 2023, HF conversational (chat) endpoints do not support streaming
So here it is faked by streaming characters within Danswer from the model output
"""
model_output = self._get_hf_model_output(query, context_docs)
model_stream = simulate_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -1,248 +0,0 @@
from collections.abc import Generator
from typing import Any, Iterator
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import (
GEN_AI_HUGGINGFACE_DISABLE_STREAM,
GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL,
GEN_AI_MAX_OUTPUT_TOKENS,
)
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import (
ChatPromptProcessor,
JsonChatProcessor,
JsonProcessor,
)
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from huggingface_hub import InferenceClient
logger = setup_logger()
def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"do_sample": False,
"seed": 69, # For reproducibility
**kwargs,
}
def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str:
"""
Utility to convert chat dialog to a text-generation prompt for models tuned for chat.
Note - This is a "best guess" attempt at a generic completions prompt for chat
completion models. It isn't optimized for all chat trained models, but tries
to serialize to a format that most models understand.
Models like Llama2-chat have been optimized for certain formatting of chat
completions, and this function doesn't take that into account, so you won't
always get the best possible outcome.
TODO - Add the ability to pass custom formatters for chat dialogue
"""
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.
If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly.
If you don't know the answer to a question, don't share false information."""
prompt = ""
if dialog[0]["role"] != "system":
dialog = [
{
"role": "system",
"content": DEFAULT_SYSTEM_PROMPT,
}
] + dialog
for message in dialog:
prompt += f"{message['role'].upper()}: {message['content']}\n"
prompt += "ASSISTANT:"
return prompt
def _mock_streaming_response(tokens: str) -> Generator[str, None, None]:
"""Utility to mock a streaming response"""
for token in tokens:
yield token
class HuggingFaceInferenceCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
model_stream = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
else:
model_output = self.client.text_generation(
filled_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
model_stream = _mock_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class HuggingFaceInferenceChatCompletionQA(QAModel):
def __init__(
self,
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False,
api_key: str | None = None,
) -> None:
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
self.client = InferenceClient(model=model_version, token=api_key)
@staticmethod
def convert_dialog_to_conversational_format(
dialog: list[dict[str, str]]
) -> tuple[str, list[str], list[str]]:
if dialog[-1]["role"] != "user":
raise Exception(
"Last message in conversational dialog must be User message"
)
user_message = dialog[-1]["content"]
dialog = dialog[0:-1]
generated_responses = []
past_user_inputs = []
for message in dialog:
# HuggingFace inference client doesn't support system messages today
# so lumping them in with user messages
if message["role"] in ["user", "system"]:
past_user_inputs += [message["content"]]
else:
generated_responses += [message["content"]]
return user_message, generated_responses, past_user_inputs
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
(
message,
generated_responses,
past_user_inputs,
) = self.convert_dialog_to_conversational_format(filled_prompt)
model_output = self.client.conversational(
message,
generated_responses=generated_responses,
past_user_inputs=past_user_inputs,
parameters={"max_length": self.max_output_tokens},
)
else:
chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
logger.debug(chat_prompt)
model_output = self.client.text_generation(
chat_prompt,
**_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
raise Exception(
"Conversational API is not available with streaming enabled. Please either "
+ "disable streaming, or disable using conversational API."
)
chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
logger.debug(chat_prompt)
model_stream = self.client.text_generation(
chat_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens, stream=True
),
)
else:
if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
(
message,
generated_responses,
past_user_inputs,
) = self.convert_dialog_to_conversational_format(filled_prompt)
model_output = self.client.conversational(
message,
generated_responses=generated_responses,
past_user_inputs=past_user_inputs,
parameters={"max_length": self.max_output_tokens},
)
else:
chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
logger.debug(chat_prompt)
model_output = self.client.text_generation(
chat_prompt,
**_build_hf_inference_settings(
max_new_tokens=self.max_output_tokens
),
)
logger.debug(model_output)
model_stream = _mock_streaming_response(model_output)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)

View File

@@ -23,6 +23,12 @@ class DanswerQuote:
class QAModel:
@property
def requires_api_key(self) -> bool:
"""Is this model protected by security features
Does it need an api key to access the model for inference"""
return True
def warm_up_model(self) -> None:
"""This is called during server start up to load the models into memory
pass if model is accessed via API"""
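A minimal sketch of how a new backend would use these two hooks; the class and its loading step are hypothetical, and only the hooks shown above are implemented:
from danswer.direct_qa.interfaces import QAModel
class InMemoryQAModel(QAModel):
    """Hypothetical locally hosted model: no API key needed, weights loaded at startup."""
    @property
    def requires_api_key(self) -> bool:
        return False
    def warm_up_model(self) -> None:
        # load model weights into memory here so the first query isn't slow
        pass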

View File

@@ -15,8 +15,6 @@ from openai.error import Timeout
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
@@ -28,9 +26,9 @@ from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -41,15 +39,9 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable)
def get_openai_api_key() -> str:
return OPENAI_API_KEY or cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
)
def _ensure_openai_api_key(api_key: str | None) -> str:
try:
return api_key or get_openai_api_key()
return api_key or get_gen_ai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
@@ -131,7 +123,7 @@ class OpenAICompletionQA(OpenAIQAModel):
self.timeout = timeout
self.include_metadata = include_metadata
try:
self.api_key = api_key or get_openai_api_key()
self.api_key = api_key or get_gen_ai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()

View File

@@ -3,6 +3,7 @@ import math
import re
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import Optional
from typing import Tuple
@@ -12,14 +13,17 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
@@ -27,6 +31,12 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def get_gen_ai_api_key() -> str:
return GEN_AI_API_KEY or cast(
str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
)
def structure_quotes_for_response(
quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
@@ -246,3 +256,9 @@ def process_model_tokens(
quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
yield structure_quotes_for_response(quotes)
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character"""
for token in model_out:
yield token
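For illustration, the helper simply re-emits a finished model output one character at a time:
assert list(simulate_streaming_response("42")) == ["4", "2"]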

View File

@@ -0,0 +1,201 @@
import abc
import json
from collections.abc import Generator
from typing import Any
import requests
from requests.models import Response
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
class HostSpecificRequestModel(abc.ABC):
"""Provides a more minimal implementation requirement for extending to new models
hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
"""
@staticmethod
@abc.abstractmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None,
max_output_tokens: int,
stream: bool,
timeout: int | None,
) -> Response:
"""Given a filled out prompt, how to send it to the model API with the
correct request format with the correct parameters"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def extract_model_output_from_response(
response: Response,
) -> str:
"""Extract the full model output text from a response.
This is for nonstreaming endpoints"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
"""Generate tokens from a streaming response
This is for streaming endpoints"""
raise NotImplementedError
class HuggingFaceRequestModel(HostSpecificRequestModel):
@staticmethod
def send_model_request(
filled_prompt: str,
endpoint: str,
api_key: str | None,
max_output_tokens: int,
stream: bool, # Not supported by Inference Endpoints (as of Aug 2023)
timeout: int | None,
) -> Response:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
data = {
"inputs": filled_prompt,
"parameters": {
# HuggingFace requires this to be strictly between 0.0 and 100.0 (exclusive)
"temperature": 0.01,
# Skip the long tail
"top_p": 0.9,
"max_new_tokens": max_output_tokens,
},
}
try:
return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
# requests raises its own Timeout exception rather than the builtin TimeoutError
except requests.exceptions.Timeout as error:
raise TimeoutError(f"Model inference to {endpoint} timed out") from error
@staticmethod
def _hf_extract_model_output(
response: Response,
) -> str:
if response.status_code != 200:
response.raise_for_status()
return json.loads(response.content)[0].get("generated_text", "")
@staticmethod
def extract_model_output_from_response(
response: Response,
) -> str:
return HuggingFaceRequestModel._hf_extract_model_output(response)
@staticmethod
def generate_model_tokens_from_response(
response: Response,
) -> Generator[str, None, None]:
"""HF endpoints do not do streaming currently so this function will
simulate streaming for the meantime but will need to be replaced in
the future once streaming is enabled."""
model_out = HuggingFaceRequestModel._hf_extract_model_output(response)
yield from simulate_streaming_response(model_out)
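# For reference, the HuggingFace text-generation endpoint replies with a JSON list shaped like
#   [{"generated_text": "<completion>"}]
# which is why _hf_extract_model_output indexes [0] and reads "generated_text".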
def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
if model_host_type == ModelHostType.HUGGINGFACE.value:
return HuggingFaceRequestModel()
else:
# TODO support Azure, GCP, AWS
raise ValueError(
"Invalid model hosting service selected, currently supports only huggingface"
)
class RequestCompletionQA(QAModel):
def __init__(
self,
endpoint: str = GEN_AI_ENDPOINT,
model_host_type: str = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
timeout: int | None = None,
) -> None:
self.endpoint = endpoint
self.api_key = api_key
self.prompt_processor = prompt_processor
self.max_output_tokens = max_output_tokens
self.model_class = get_host_specific_model_class(model_host_type)
self.timeout = timeout
def _get_request_response(
self, query: str, context_docs: list[InferenceChunk], stream: bool
) -> Response:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, include_metadata=False
)
logger.debug(filled_prompt)
return self.model_class.send_model_request(
filled_prompt,
self.endpoint,
self.api_key,
self.max_output_tokens,
stream,
self.timeout,
)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
model_output = self.model_class.extract_model_output_from_response(
model_api_response
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self,
query: str,
context_docs: list[InferenceChunk],
) -> Generator[dict[str, Any] | None, None, None]:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
token_generator = self.model_class.generate_model_tokens_from_response(
model_api_response
)
yield from process_model_tokens(
tokens=token_generator,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
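Per the TODO above, adding another hosting service only requires implementing the three hooks. A rough hypothetical sketch; the request/response schema here is invented for illustration and would need to match the real service:
class ExampleCloudRequestModel(HostSpecificRequestModel):
    """Hypothetical host type showing the three hooks a new service must provide."""
    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        headers = {"Authorization": f"Bearer {api_key}"}
        data = {"prompt": filled_prompt, "max_tokens": max_output_tokens}  # assumed schema
        return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
    @staticmethod
    def extract_model_output_from_response(response: Response) -> str:
        return response.json().get("text", "")  # assumed response shape
    @staticmethod
    def generate_model_tokens_from_response(response: Response) -> Generator[str, None, None]:
        yield from simulate_streaming_response(
            ExampleCloudRequestModel.extract_model_output_from_response(response)
        )
Such a class would also need a branch in get_host_specific_model_class and, if user-selectable, a new ModelHostType member.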

View File

@@ -18,7 +18,7 @@ from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.connectors.file.utils import write_temp_files
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_KEY
from danswer.connectors.google_drive.connector_auth import get_auth_url
@@ -50,7 +50,7 @@ from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import User
from danswer.direct_qa import check_model_api_key_is_valid
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.open_ai import get_openai_api_key
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.open_ai import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -293,19 +293,17 @@ def connector_run_once(
)
@router.head("/admin/openai-api-key/validate")
def validate_existing_openai_api_key(
@router.head("/admin/genai-api-key/validate")
def validate_existing_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# GenAI API key is only used for generative QA, so no need to validate this
# if generative QA is turned off or if the selected model does not require an API key
if DISABLE_GENERATIVE_AI or not isinstance(
get_default_backend_qa_model(), OpenAIQAModel
):
if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
return
# Only validate every so often
check_key_time = "openai_api_key_last_check_time"
check_key_time = "genai_api_key_last_check_time"
kv_store = get_dynamic_config_store()
curr_time = datetime.now()
try:
@@ -318,7 +316,7 @@ def validate_existing_openai_api_key(
pass
try:
openai_api_key = get_openai_api_key()
genai_api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
@@ -327,7 +325,7 @@ def validate_existing_openai_api_key(
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
try:
is_valid = check_model_api_key_is_valid(openai_api_key)
is_valid = check_model_api_key_is_valid(genai_api_key)
except ValueError:
# this is the case where they aren't using an OpenAI-based model
is_valid = True
@@ -336,8 +334,8 @@ def validate_existing_openai_api_key(
raise HTTPException(status_code=400, detail="Invalid API key provided")
@router.get("/admin/openai-api-key", response_model=ApiKey)
def get_openai_api_key_from_dynamic_config_store(
@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(
_: User = Depends(current_admin_user),
) -> ApiKey:
"""
@@ -347,15 +345,15 @@ def get_openai_api_key_from_dynamic_config_store(
# only get last 4 characters of key to not expose full key
return ApiKey(
api_key=cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
)[-4:]
)
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
@router.put("/admin/openai-api-key")
def store_openai_api_key(
@router.put("/admin/genai-api-key")
def store_genai_api_key(
request: ApiKey,
_: User = Depends(current_admin_user),
) -> None:
@@ -363,16 +361,16 @@ def store_openai_api_key(
is_valid = check_model_api_key_is_valid(request.api_key)
if not is_valid:
raise HTTPException(400, "Invalid API key provided")
get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
except RuntimeError as e:
raise HTTPException(400, str(e))
@router.delete("/admin/openai-api-key")
def delete_openai_api_key(
@router.delete("/admin/genai-api-key")
def delete_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY)
get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
"""Endpoints for basic users"""

View File

@@ -6,6 +6,7 @@ reorder-python-imports==3.9.0
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-setuptools==68.0.0.3
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-regex==2023.3.23.1