Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 12:03:54 +02:00)

Support for Request-accessed GenAI Models (#270)
@@ -138,12 +138,6 @@ CHUNK_WORD_OVERLAP = 5
CHUNK_MAX_CHAR_OVERLAP = 50

#####
# Other API Keys
#####
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")


#####
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
#####
@@ -12,7 +12,7 @@ SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
METADATA = "metadata"
OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"

@@ -30,3 +30,26 @@ class DocumentSource(str, Enum):
    PRODUCTBOARD = "productboard"
    FILE = "file"
    NOTION = "notion"


class DanswerGenAIModel(str, Enum):
    """This represents the internal Danswer GenAI model, which determines the class that is used
    to generate responses to the user query. Different models/services require different internal
    handling; this allows for modularity of implementation within Danswer"""

    OPENAI = "openai-completion"
    OPENAI_CHAT = "openai-chat-completion"
    GPT4ALL = "gpt4all-completion"
    GPT4ALL_CHAT = "gpt4all-chat-completion"
    HUGGINGFACE = "huggingface-inference-completion"
    HUGGINGFACE_CHAT = "huggingface-inference-chat-completion"
    REQUEST = "request-completion"


class ModelHostType(str, Enum):
    """For GenAI models interfaced via requests, different services have different
    expectations for what fields are included in the request"""

    # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
    HUGGINGFACE = "huggingface"  # HuggingFace text-generation Inference API
    # TODO support for Azure, AWS, GCP GenAI model hosting
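The enum values above are the strings accepted from configuration, so a quick round-trip check (a small sketch, not part of the commit) shows how they resolve:

from danswer.configs.constants import DanswerGenAIModel, ModelHostType

# Config strings map directly onto enum members and back
assert DanswerGenAIModel("request-completion") is DanswerGenAIModel.REQUEST
assert DanswerGenAIModel.OPENAI_CHAT.value == "openai-chat-completion"
assert ModelHostType("huggingface") is ModelHostType.HUGGINGFACE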
@@ -1,4 +1,8 @@
import os
from enum import Enum

from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType

# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
@@ -30,35 +34,46 @@ CROSS_EMBED_CONTEXT_SIZE = 512
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8

# QA Model API Configs
# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
# Valid list:
# - openai-completion
# - openai-chat-completion
# - gpt4all-completion -> Due to M1 Macs not having a compatible gpt4all version, please install the dependency yourself
# - gpt4all-chat-completion -> Due to M1 Macs not having a compatible gpt4all version, please install the dependency yourself
# To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
# These support the HuggingFace Inference API, Inference Endpoints, and servers running the text-generation-inference backend
# - huggingface-inference-completion
# - huggingface-inference-chat-completion

#####
# Generative AI Model Configs
#####
# Other models should work as well; check the library/API compatibility.
# But these are the models that have been verified to work with the existing prompts.
# Using a different model may require some prompt tuning. See qa_prompts.py
VERIFIED_MODELS = {
    DanswerGenAIModel.OPENAI: ["text-davinci-003"],
    DanswerGenAIModel.OPENAI_CHAT: ["gpt-3.5-turbo", "gpt-4"],
    DanswerGenAIModel.GPT4ALL: ["ggml-model-gpt4all-falcon-q4_0.bin"],
    DanswerGenAIModel.GPT4ALL_CHAT: ["ggml-model-gpt4all-falcon-q4_0.bin"],
    # The "chat" model below is actually "instruction finetuned" and does not support conversational
    DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
    DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
}

# Sets the internal Danswer model class to use
INTERNAL_MODEL_VERSION = os.environ.get(
    "INTERNAL_MODEL_VERSION", "openai-chat-completion"
    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
)
# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" below for a tested model
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")

# API key for the Generative AI model, if one is required for access; otherwise it can be left blank
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", "")

# If using GPT4All or OpenAI, specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(
    "GEN_AI_MODEL_VERSION",
    VERIFIED_MODELS.get(DanswerGenAIModel(INTERNAL_MODEL_VERSION), [""])[0],
)

# If the Generative Model is hosted to accept requests (DanswerGenAIModel.REQUEST) then
# set the two below to specify
# - Where to hit the endpoint
# - How the request should be formed
GEN_AI_ENDPOINT = os.environ.get("GEN_AI_ENDPOINT", "")
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)

# Set this to be enough for an answer + quotes
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
# Use the HuggingFace API token for the HuggingFace inference client
GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
# Use the conversational API with the huggingface-inference-chat-completion internal model
# Note - this only works with models that support conversational interfaces
GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
    os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
)
# Disable streaming responses. Set this to true to "polyfill" streaming for models that don't support streaming
GEN_AI_HUGGINGFACE_DISABLE_STREAM = (
    os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true"
)

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
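For reference, a minimal sketch (not part of the commit) of the environment variables these configs read when pointing Danswer at a request-hosted model; the endpoint URL and token are made-up placeholders:

import os

# Select the request-based internal model class (DanswerGenAIModel.REQUEST)
os.environ["INTERNAL_MODEL_VERSION"] = "request-completion"
# Where to send requests and how to shape them (ModelHostType.HUGGINGFACE)
os.environ["GEN_AI_ENDPOINT"] = "https://my-llm.example.com/generate"  # hypothetical URL
os.environ["GEN_AI_HOST_TYPE"] = "huggingface"
# Optional auth token for the hosted endpoint
os.environ["GEN_AI_API_KEY"] = "hf_xxx"  # placeholder
# Leave enough room for an answer plus quotes
os.environ["GEN_AI_MAX_OUTPUT_TOKENS"] = "512"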
@@ -1,25 +1,30 @@
from typing import Any

import pkg_resources
from openai.error import AuthenticationError
from openai.error import Timeout

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import (
    GEN_AI_HUGGINGFACE_API_TOKEN,
    INTERNAL_MODEL_VERSION,
)
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.huggingface_inference import (
    HuggingFaceInferenceChatCompletionQA,
    HuggingFaceInferenceCompletionQA,
)
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger

# Imports commented out temporarily due to incompatibility of gpt4all with M1 Mac hardware
# from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
# from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
logger = setup_logger()


def check_model_api_key_is_valid(model_api_key: str) -> bool:
@@ -35,32 +40,66 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
            return True
        except AuthenticationError:
            return False
        except Timeout:
            pass
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")

    return False


def get_default_backend_qa_model(
    internal_model: str = INTERNAL_MODEL_VERSION,
    api_key: str | None = None,
    endpoint: str | None = GEN_AI_ENDPOINT,
    model_host_type: str | None = GEN_AI_HOST_TYPE,
    api_key: str | None = GEN_AI_API_KEY,
    timeout: int = QA_TIMEOUT,
    **kwargs: Any
    **kwargs: Any,
) -> QAModel:
    if internal_model == "openai-completion":
    if not api_key:
        try:
            api_key = get_gen_ai_api_key()
        except ConfigNotFoundError:
            pass

    if internal_model in [
        DanswerGenAIModel.GPT4ALL.value,
        DanswerGenAIModel.GPT4ALL_CHAT.value,
    ]:
        # gpt4all is not compatible with M1 Mac hardware as of Aug 2023
        pkg_resources.get_distribution("gpt4all")

    if internal_model == DanswerGenAIModel.OPENAI.value:
        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == "openai-chat-completion":
    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == "huggingface-inference-completion":
        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
        return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == "huggingface-inference-chat-completion":
        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
        return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs)
    # Note GPT4All is not supported for M1 Mac machines currently, removing until support is added
    # elif internal_model == "gpt4all-completion":
    #     return GPT4AllCompletionQA(**kwargs)
    # elif internal_model == "gpt4all-chat-completion":
    #     return GPT4AllChatCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
        return GPT4AllCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
        return GPT4AllChatCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.REQUEST.value:
        if endpoint is None or model_host_type is None:
            raise ValueError(
                "Request based GenAI model requires an endpoint and host type"
            )
        if model_host_type == ModelHostType.HUGGINGFACE.value:
            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
            # This is the sum of input and output tokens, so cannot take in full Danswer context
            return RequestCompletionQA(
                endpoint=endpoint,
                model_host_type=model_host_type,
                api_key=api_key,
                prompt_processor=WeakModelFreeformProcessor(),
                timeout=timeout,
            )
        return RequestCompletionQA(
            endpoint=endpoint,
            model_host_type=model_host_type,
            api_key=api_key,
            timeout=timeout,
        )
    else:
        raise UnknownModelError(internal_model)
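For illustration, a hedged sketch (not in the commit; the endpoint URL is an invented placeholder) of how the new REQUEST branch of the factory might be exercised:

from danswer.configs.constants import DanswerGenAIModel, ModelHostType
from danswer.direct_qa import get_default_backend_qa_model

# Any reachable text-generation endpoint would do here
qa_model = get_default_backend_qa_model(
    internal_model=DanswerGenAIModel.REQUEST.value,
    endpoint="https://my-llama2-7b.example.com/generate",  # hypothetical
    model_host_type=ModelHostType.HUGGINGFACE.value,
    api_key=None,  # only needed if the hosted endpoint enforces auth
)
# For the HuggingFace host type this returns a RequestCompletionQA configured
# with the weak-model prompt processor, per the branch above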
@@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any

from gpt4all import GPT4All  # type:ignore

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -18,9 +16,30 @@ from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time


logger = setup_logger()


class DummyGPT4All:
    """Stand-in used when the gpt4all import fails (e.g. due to M1 Mac incompatibility)
    so that this module does not raise exceptions during server startup,
    as long as GPT4All isn't actually used"""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        raise RuntimeError("GPT4All library not installed.")


try:
    from gpt4all import GPT4All  # type:ignore
except ImportError:
    logger.warning(
        "GPT4All library not installed. "
        "If you wish to run GPT4ALL (in memory) to power Danswer's "
        "Generative AI features, please install gpt4all==1.0.5. "
        "As of Aug 2023, this library is not compatible with M1 Mac."
    )
    GPT4All = DummyGPT4All


GPT4ALL_MODEL: GPT4All | None = None


@@ -56,6 +75,10 @@ class GPT4AllCompletionQA(QAModel):
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    @property
    def requires_api_key(self) -> bool:
        return False

    def warm_up_model(self) -> None:
        get_gpt_4_all_model(self.model_version)

@@ -117,6 +140,13 @@ class GPT4AllChatCompletionQA(QAModel):
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    @property
    def requires_api_key(self) -> bool:
        return False

    def warm_up_model(self) -> None:
        get_gpt_4_all_model(self.model_version)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
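A small sketch (an assumption, not part of the commit) of what the fallback buys: the module imports cleanly without gpt4all installed, and the failure is deferred until the in-memory model is actually constructed.

# With gpt4all absent, GPT4All is bound to DummyGPT4All at import time,
# so any attempt to build the in-memory model surfaces a clear error instead
# of the server failing at startup:
GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
# -> RuntimeError: GPT4All library not installed.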
backend/danswer/direct_qa/huggingface.py (new file, 189 lines)
@@ -0,0 +1,189 @@
from collections.abc import Generator
from typing import Any

from huggingface_hub import InferenceClient  # type:ignore
from huggingface_hub.utils import HfHubHTTPError  # type:ignore

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()


def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "do_sample": False,
        "seed": 69,  # For reproducibility
        **kwargs,
    }


class HuggingFaceCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = FreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        model_output = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
        )
        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        model_stream = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(
                max_new_tokens=self.max_output_tokens, stream=True
            ),
        )

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class HuggingFaceChatCompletionQA(QAModel):
    """Chat in this class refers to the HuggingFace Conversational API.
    Not to be confused with Chat/Instruction finetuned models.
    Llama2-chat... means it is an Instruction finetuned model, not necessarily that
    it supports the Conversational API"""

    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @staticmethod
    def _convert_chat_to_hf_conversational_format(
        dialog: list[dict[str, str]]
    ) -> tuple[str, list[str], list[str]]:
        if dialog[-1]["role"] != "user":
            raise Exception(
                "Last message in conversational dialog must be User message"
            )
        user_message = dialog[-1]["content"]
        dialog = dialog[0:-1]
        generated_responses = []
        past_user_inputs = []
        for message in dialog:
            # HuggingFace inference client doesn't support system messages today,
            # so lumping them in with user messages
            if message["role"] in ["user", "system"]:
                past_user_inputs += [message["content"]]
            else:
                generated_responses += [message["content"]]
        return user_message, generated_responses, past_user_inputs

    def _get_hf_model_output(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> str:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )

        (
            query,
            past_responses,
            past_inputs,
        ) = self._convert_chat_to_hf_conversational_format(filled_prompt)

        logger.debug(f"Last Input: {query}")
        logger.debug(f"Past Inputs: {past_inputs}")
        logger.debug(f"Past Responses: {past_responses}")
        try:
            model_output = self.client.conversational(
                query,
                generated_responses=past_responses,
                past_user_inputs=past_inputs,
                parameters={"max_length": self.max_output_tokens},
            )
        except HfHubHTTPError as model_error:
            if model_error.response.status_code == 422:
                raise ValueError(
                    "Selected HuggingFace Model does not support the HuggingFace Conversational API, "
                    "try using the huggingface-inference-completion in Danswer instead"
                )
            raise
        logger.debug(model_output)

        return model_output

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        model_output = self._get_hf_model_output(query, context_docs)

        answer, quotes_dict = process_answer(model_output, context_docs)

        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        """As of Aug 2023, HF conversational (chat) endpoints do not support streaming,
        so here it is faked by streaming the characters of the model output within Danswer
        """
        model_output = self._get_hf_model_output(query, context_docs)

        model_stream = simulate_streaming_response(model_output)

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
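To make the dialog conversion concrete, a small illustrative sketch (not in the commit; the messages are invented) of the tuple _convert_chat_to_hf_conversational_format produces:

dialog = [
    {"role": "system", "content": "Answer with a quote."},
    {"role": "user", "content": "Where are the docs?"},
    {"role": "assistant", "content": "See the wiki."},
    {"role": "user", "content": "Which page exactly?"},
]
# The last user message is split off; system messages are lumped in with user inputs:
# -> ("Which page exactly?",
#     ["See the wiki."],                                  # generated_responses
#     ["Answer with a quote.", "Where are the docs?"])    # past_user_inputs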
@@ -1,248 +0,0 @@
from collections.abc import Generator
from typing import Any, Iterator

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import (
    GEN_AI_HUGGINGFACE_DISABLE_STREAM,
    GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL,
    GEN_AI_MAX_OUTPUT_TOKENS,
)
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import (
    ChatPromptProcessor,
    JsonChatProcessor,
    JsonProcessor,
)
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from huggingface_hub import InferenceClient

logger = setup_logger()


def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "do_sample": False,
        "seed": 69,  # For reproducibility
        **kwargs,
    }


def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str:
    """
    Utility to convert chat dialog to a text-generation prompt for models tuned for chat.
    Note - This is a "best guess" attempt at a generic completions prompt for chat
    completion models. It isn't optimized for all chat trained models, but tries
    to serialize to a format that most models understand.
    Models like Llama2-chat have been optimized for certain formatting of chat
    completions, and this function doesn't take that into account, so you won't
    always get the best possible outcome.
    TODO - Add the ability to pass custom formatters for chat dialogue
    """
    DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.
If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly.
If you don't know the answer to a question, don't share false information."""
    prompt = ""
    if dialog[0]["role"] != "system":
        dialog = [
            {
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            }
        ] + dialog
    for message in dialog:
        prompt += f"{message['role'].upper()}: {message['content']}\n"
    prompt += "ASSISTANT:"
    return prompt


def _mock_streaming_response(tokens: str) -> Generator[str, None, None]:
    """Utility to mock a streaming response"""
    for token in tokens:
        yield token


class HuggingFaceInferenceCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        model_output = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
        )
        logger.debug(model_output)
        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
            model_stream = self.client.text_generation(
                filled_prompt,
                **_build_hf_inference_settings(
                    max_new_tokens=self.max_output_tokens, stream=True
                ),
            )
        else:
            model_output = self.client.text_generation(
                filled_prompt,
                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
            )
            logger.debug(model_output)
            model_stream = _mock_streaming_response(model_output)
        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class HuggingFaceInferenceChatCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @staticmethod
    def convert_dialog_to_conversational_format(
        dialog: list[dict[str, str]]
    ) -> tuple[str, list[str], list[str]]:
        if dialog[-1]["role"] != "user":
            raise Exception(
                "Last message in conversational dialog must be User message"
            )
        user_message = dialog[-1]["content"]
        dialog = dialog[0:-1]
        generated_responses = []
        past_user_inputs = []
        for message in dialog:
            # HuggingFace inference client doesn't support system messages today
            # so lumping them in with user messages
            if message["role"] in ["user", "system"]:
                past_user_inputs += [message["content"]]
            else:
                generated_responses += [message["content"]]
        return user_message, generated_responses, past_user_inputs

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
            (
                message,
                generated_responses,
                past_user_inputs,
            ) = self.convert_dialog_to_conversational_format(filled_prompt)
            model_output = self.client.conversational(
                message,
                generated_responses=generated_responses,
                past_user_inputs=past_user_inputs,
                parameters={"max_length": self.max_output_tokens},
            )
        else:
            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
            logger.debug(chat_prompt)
            model_output = self.client.text_generation(
                chat_prompt,
                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
            )
        logger.debug(model_output)
        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
                raise Exception(
                    "Conversational API is not available with streaming enabled. Please either "
                    + "disable streaming, or disable using conversational API."
                )
            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
            logger.debug(chat_prompt)
            model_stream = self.client.text_generation(
                chat_prompt,
                **_build_hf_inference_settings(
                    max_new_tokens=self.max_output_tokens, stream=True
                ),
            )
        else:
            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
                (
                    message,
                    generated_responses,
                    past_user_inputs,
                ) = self.convert_dialog_to_conversational_format(filled_prompt)
                model_output = self.client.conversational(
                    message,
                    generated_responses=generated_responses,
                    past_user_inputs=past_user_inputs,
                    parameters={"max_length": self.max_output_tokens},
                )
            else:
                chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
                logger.debug(chat_prompt)
                model_output = self.client.text_generation(
                    chat_prompt,
                    **_build_hf_inference_settings(
                        max_new_tokens=self.max_output_tokens
                    ),
                )
            logger.debug(model_output)
            model_stream = _mock_streaming_response(model_output)
        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
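For context on the removed fallback formatter, a tiny sketch (the message is invented) of the prompt _generic_chat_dialog_to_prompt_formatter produced:

dialog = [{"role": "user", "content": "Where are the docs?"}]
# A default SYSTEM turn is prepended, roles are upper-cased, and an ASSISTANT cue is appended:
# "SYSTEM: You are a helpful, respectful and honest assistant. ...\n"
# "USER: Where are the docs?\n"
# "ASSISTANT:"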
@@ -23,6 +23,12 @@ class DanswerQuote:


class QAModel:
    @property
    def requires_api_key(self) -> bool:
        """Is this model protected by security features?
        Does it need an API key to access the model for inference?"""
        return True

    def warm_up_model(self) -> None:
        """This is called during server startup to load the models into memory;
        pass if the model is accessed via API"""
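The default says a model needs a key; local models opt out by overriding the property, as the GPT4All classes above do. A minimal sketch of the pattern with a hypothetical subclass:

class LocalOnlyQAModel(QAModel):
    """Hypothetical in-memory model: no API key needed."""

    @property
    def requires_api_key(self) -> bool:
        return False

    def warm_up_model(self) -> None:
        pass  # an in-memory model would load its weights here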
@@ -15,8 +15,6 @@ from openai.error import Timeout

from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
@@ -28,9 +26,9 @@ from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -41,15 +39,9 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable)


def get_openai_api_key() -> str:
    return OPENAI_API_KEY or cast(
        str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
    )


def _ensure_openai_api_key(api_key: str | None) -> str:
    try:
        return api_key or get_openai_api_key()
        return api_key or get_gen_ai_api_key()
    except ConfigNotFoundError:
        raise OpenAIKeyMissing()

@@ -131,7 +123,7 @@ class OpenAICompletionQA(OpenAIQAModel):
        self.timeout = timeout
        self.include_metadata = include_metadata
        try:
            self.api_key = api_key or get_openai_api_key()
            self.api_key = api_key or get_gen_ai_api_key()
        except ConfigNotFoundError:
            raise OpenAIKeyMissing()

@@ -3,6 +3,7 @@ import math
import re
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import Optional
from typing import Tuple

@@ -12,14 +13,17 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
@@ -27,6 +31,12 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()


def get_gen_ai_api_key() -> str:
    return GEN_AI_API_KEY or cast(
        str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
    )


def structure_quotes_for_response(
    quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
@@ -246,3 +256,9 @@ def process_model_tokens(

    quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
    yield structure_quotes_for_response(quotes)


def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
    """Mock streaming by generating the passed-in model output, character by character"""
    for token in model_out:
        yield token
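Two behaviors worth pinning down from this hunk, shown as a small sketch (not part of the commit): key resolution prefers the GEN_AI_API_KEY setting over the dynamic config store, and simulated streaming is simply character-by-character generation.

from danswer.direct_qa.qa_utils import get_gen_ai_api_key, simulate_streaming_response

# Falls back to the dynamic config store only when GEN_AI_API_KEY is empty
key = get_gen_ai_api_key()

# "Streaming" for non-streaming backends is one character per token
assert list(simulate_streaming_response("Hi")) == ["H", "i"]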
backend/danswer/direct_qa/request_model.py (new file, 201 lines)
@@ -0,0 +1,201 @@
import abc
import json
from collections.abc import Generator
from typing import Any

import requests
from requests.models import Response

from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.direct_qa.qa_utils import simulate_streaming_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time


logger = setup_logger()


class HostSpecificRequestModel(abc.ABC):
    """Provides a minimal set of methods to implement when extending to new models
    hosted behind REST APIs. The calling class abstracts away all Danswer-internal specifics
    """

    @staticmethod
    @abc.abstractmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        """Given a filled-out prompt, send it to the model API with the
        correct request format and parameters"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        """Extract the full model output text from a response.
        This is for non-streaming endpoints"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        """Generate tokens from a streaming response.
        This is for streaming endpoints"""
        raise NotImplementedError


class HuggingFaceRequestModel(HostSpecificRequestModel):
    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,  # Not supported by Inference Endpoints (as of Aug 2023)
        timeout: int | None,
    ) -> Response:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "inputs": filled_prompt,
            "parameters": {
                # HuggingFace requires temperature to be strictly positive, in (0.0, 100.0) exclusive
                "temperature": 0.01,
                # Skip the long tail
                "top_p": 0.9,
                "max_new_tokens": max_output_tokens,
            },
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except TimeoutError as error:
            raise TimeoutError(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _hf_extract_model_output(
        response: Response,
    ) -> str:
        if response.status_code != 200:
            response.raise_for_status()

        return json.loads(response.content)[0].get("generated_text", "")

    @staticmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        return HuggingFaceRequestModel._hf_extract_model_output(response)

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        """HF endpoints do not do streaming currently, so this function
        simulates streaming in the meantime; it will need to be replaced
        once streaming is enabled."""
        model_out = HuggingFaceRequestModel._hf_extract_model_output(response)
        yield from simulate_streaming_response(model_out)


def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
    if model_host_type == ModelHostType.HUGGINGFACE.value:
        return HuggingFaceRequestModel()
    else:
        # TODO support Azure, GCP, AWS
        raise ValueError(
            "Invalid model hosting service selected, currently supports only huggingface"
        )


class RequestCompletionQA(QAModel):
    def __init__(
        self,
        endpoint: str = GEN_AI_ENDPOINT,
        model_host_type: str = GEN_AI_HOST_TYPE,
        api_key: str | None = GEN_AI_API_KEY,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        timeout: int | None = None,
    ) -> None:
        self.endpoint = endpoint
        self.api_key = api_key
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.model_class = get_host_specific_model_class(model_host_type)
        self.timeout = timeout

    def _get_request_response(
        self, query: str, context_docs: list[InferenceChunk], stream: bool
    ) -> Response:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, include_metadata=False
        )
        logger.debug(filled_prompt)

        return self.model_class.send_model_request(
            filled_prompt,
            self.endpoint,
            self.api_key,
            self.max_output_tokens,
            stream,
            self.timeout,
        )

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        model_api_response = self._get_request_response(
            query, context_docs, stream=False
        )

        model_output = self.model_class.extract_model_output_from_response(
            model_api_response
        )
        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self,
        query: str,
        context_docs: list[InferenceChunk],
    ) -> Generator[dict[str, Any] | None, None, None]:
        model_api_response = self._get_request_response(
            query, context_docs, stream=False
        )

        token_generator = self.model_class.generate_model_tokens_from_response(
            model_api_response
        )

        yield from process_model_tokens(
            tokens=token_generator,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
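Supporting another hosting service means implementing the three static methods of HostSpecificRequestModel and wiring it into get_host_specific_model_class. A hedged sketch, not part of the commit, assuming a hypothetical host that accepts {"prompt": ...} and returns {"text": ...}; it reuses the module's imports above:

class ExampleJsonHostRequestModel(HostSpecificRequestModel):
    """Hypothetical host; request/response shapes are invented for illustration."""

    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        headers = {"Authorization": f"Bearer {api_key}"}
        data = {"prompt": filled_prompt, "max_tokens": max_output_tokens}
        return requests.post(endpoint, headers=headers, json=data, timeout=timeout)

    @staticmethod
    def extract_model_output_from_response(response: Response) -> str:
        response.raise_for_status()
        return json.loads(response.content).get("text", "")

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        # No real streaming here either; reuse the character-level simulation
        yield from simulate_streaming_response(
            ExampleJsonHostRequestModel.extract_model_output_from_response(response)
        )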
@@ -18,7 +18,7 @@ from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.connectors.file.utils import write_temp_files
from danswer.connectors.google_drive.connector_auth import DB_CREDENTIALS_DICT_KEY
from danswer.connectors.google_drive.connector_auth import get_auth_url
@@ -50,7 +50,7 @@ from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import User
from danswer.direct_qa import check_model_api_key_is_valid
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.open_ai import get_openai_api_key
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.open_ai import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -293,19 +293,17 @@ def connector_run_once(
    )


@router.head("/admin/openai-api-key/validate")
def validate_existing_openai_api_key(
@router.head("/admin/genai-api-key/validate")
def validate_existing_genai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    # OpenAI key is only used for generative QA, so no need to validate this
    # if it's turned off or if a non-OpenAI model is being used
    if DISABLE_GENERATIVE_AI or not isinstance(
        get_default_backend_qa_model(), OpenAIQAModel
    ):
    if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
        return

    # Only validate every so often
    check_key_time = "openai_api_key_last_check_time"
    check_key_time = "genai_api_key_last_check_time"
    kv_store = get_dynamic_config_store()
    curr_time = datetime.now()
    try:
@@ -318,7 +316,7 @@ def validate_existing_openai_api_key(
        pass

    try:
        openai_api_key = get_openai_api_key()
        genai_api_key = get_gen_ai_api_key()
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")
    except ValueError as e:
@@ -327,7 +325,7 @@ def validate_existing_openai_api_key(
    get_dynamic_config_store().store(check_key_time, curr_time.timestamp())

    try:
        is_valid = check_model_api_key_is_valid(openai_api_key)
        is_valid = check_model_api_key_is_valid(genai_api_key)
    except ValueError:
        # this is the case where they aren't using an OpenAI-based model
        is_valid = True
@@ -336,8 +334,8 @@ def validate_existing_openai_api_key(
        raise HTTPException(status_code=400, detail="Invalid API key provided")


@router.get("/admin/openai-api-key", response_model=ApiKey)
def get_openai_api_key_from_dynamic_config_store(
@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(
    _: User = Depends(current_admin_user),
) -> ApiKey:
    """
@@ -347,15 +345,15 @@ def get_openai_api_key_from_dynamic_config_store(
        # only get last 4 characters of key to not expose full key
        return ApiKey(
            api_key=cast(
                str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
                str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
            )[-4:]
        )
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")


@router.put("/admin/openai-api-key")
def store_openai_api_key(
@router.put("/admin/genai-api-key")
def store_genai_api_key(
    request: ApiKey,
    _: User = Depends(current_admin_user),
) -> None:
@@ -363,16 +361,16 @@ def store_openai_api_key(
        is_valid = check_model_api_key_is_valid(request.api_key)
        if not is_valid:
            raise HTTPException(400, "Invalid API key provided")
        get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
        get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
    except RuntimeError as e:
        raise HTTPException(400, str(e))


@router.delete("/admin/openai-api-key")
def delete_openai_api_key(
@router.delete("/admin/genai-api-key")
def delete_genai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY)
    get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)


"""Endpoints for basic users"""
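The admin key endpoints move from /admin/openai-api-key to /admin/genai-api-key. A hedged sketch (not in the commit; the base URL is assumed and the admin auth cookie/header as well as any router prefix are omitted) of how a client would hit the renamed routes:

import requests

BASE = "http://localhost:8080"  # assumed API host

requests.put(f"{BASE}/admin/genai-api-key", json={"api_key": "hf_xxx"})  # store (validated first)
requests.head(f"{BASE}/admin/genai-api-key/validate")                    # re-check the stored key
requests.get(f"{BASE}/admin/genai-api-key")                              # returns only the last 4 chars
requests.delete(f"{BASE}/admin/genai-api-key")                           # remove it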
@@ -6,6 +6,7 @@ reorder-python-imports==3.9.0
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-setuptools==68.0.0.3
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-regex==2023.3.23.1