Mirror of https://github.com/danswer-ai/danswer.git
Huggingface Inference backend internal models (#265)
@@ -38,12 +38,27 @@ BATCH_SIZE_ENCODE_CHUNKS = 8
 # - gpt4all-completion -> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
 # - gpt4all-chat-completion-> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
 # To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
+# These support HuggingFace Inference API, Inference Endpoints and servers running the text-generation-inference backend
+# - huggingface-inference-completion
+# - huggingface-inference-chat-completion
+
 INTERNAL_MODEL_VERSION = os.environ.get(
     "INTERNAL_MODEL_VERSION", "openai-chat-completion"
 )
 # For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
 GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
-GEN_AI_MAX_OUTPUT_TOKENS = 512
+GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
+# Use HuggingFace API Token for Huggingface inference client
+GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
+# Use the conversational API with the huggingface-inference-chat-completion internal model
+# Note - this only works with models that support conversational interfaces
+GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
+    os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
+)
+# Disable streaming responses. Set this to true to "polyfill" streaming for models that don't support streaming
+GEN_AI_HUGGINGFACE_DISABLE_STREAM = (
+    os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true"
+)

 # Danswer custom Deep Learning Models
 INTENT_MODEL_VERSION = "danswer/intent-model"
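For reference, a minimal sketch of how the new settings above could be supplied. The model name and token are placeholders (not values from this commit), and in practice these would be set as process/container environment variables before Danswer starts, since the config module reads them once at import time:

import os

# Select one of the new backends added in this commit.
os.environ["INTERNAL_MODEL_VERSION"] = "huggingface-inference-chat-completion"
# Any model served by the HuggingFace Inference API, an Inference Endpoint,
# or a text-generation-inference server (placeholder repo id shown).
os.environ["GEN_AI_MODEL_VERSION"] = "tiiuae/falcon-7b-instruct"
os.environ["GEN_AI_HUGGINGFACE_API_TOKEN"] = "hf_..."  # placeholder token
os.environ["GEN_AI_MAX_OUTPUT_TOKENS"] = "512"
# Optional: route the chat-completion backend through the conversational API,
# and/or polyfill streaming for models that don't support it.
os.environ["GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL"] = "true"
os.environ["GEN_AI_HUGGINGFACE_DISABLE_STREAM"] = "true"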
@@ -4,9 +4,16 @@ from openai.error import AuthenticationError
 from openai.error import Timeout

 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
+from danswer.configs.model_configs import (
+    GEN_AI_HUGGINGFACE_API_TOKEN,
+    INTERNAL_MODEL_VERSION,
+)
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.huggingface_inference import (
+    HuggingFaceInferenceChatCompletionQA,
+    HuggingFaceInferenceCompletionQA,
+)
 from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA

@@ -44,6 +51,12 @@ def get_default_backend_qa_model(
         return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == "openai-chat-completion":
         return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
+    elif internal_model == "huggingface-inference-completion":
+        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
+        return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs)
+    elif internal_model == "huggingface-inference-chat-completion":
+        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
+        return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs)
     # Note GPT4All is not supported for M1 Mac machines currently, removing until support is added
     # elif internal_model == "gpt4all-completion":
     #     return GPT4AllCompletionQA(**kwargs)
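Equivalently to the new factory branches above, the QA classes can be constructed directly; a minimal sketch with a placeholder model and token (note that the factory, not the class constructor, is what applies the GEN_AI_HUGGINGFACE_API_TOKEN fallback):

from danswer.direct_qa.huggingface_inference import (
    HuggingFaceInferenceChatCompletionQA,
    HuggingFaceInferenceCompletionQA,
)

# Placeholder model/token for illustration only.
completion_qa = HuggingFaceInferenceCompletionQA(
    model_version="tiiuae/falcon-7b-instruct",
    api_key="hf_...",
)
chat_qa = HuggingFaceInferenceChatCompletionQA(
    model_version="tiiuae/falcon-7b-instruct",
    api_key="hf_...",
)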
backend/danswer/direct_qa/huggingface_inference.py (new file, 248 lines)
@@ -0,0 +1,248 @@
from collections.abc import Generator
from typing import Any, Iterator

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import (
    GEN_AI_HUGGINGFACE_DISABLE_STREAM,
    GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL,
    GEN_AI_MAX_OUTPUT_TOKENS,
)
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import (
    ChatPromptProcessor,
    JsonChatProcessor,
    JsonProcessor,
)
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from huggingface_hub import InferenceClient

logger = setup_logger()


def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "do_sample": False,
        "seed": 69,  # For reproducibility
        **kwargs,
    }


def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str:
    """
    Utility to convert chat dialog to a text-generation prompt for models tuned for chat.

    Note - This is a "best guess" attempt at a generic completions prompt for chat
    completion models. It isn't optimized for all chat trained models, but tries
    to serialize to a format that most models understand.

    Models like Llama2-chat have been optimized for certain formatting of chat
    completions, and this function doesn't take that into account, so you won't
    always get the best possible outcome.

    TODO - Add the ability to pass custom formatters for chat dialogue
    """
    DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.
If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly.
If you don't know the answer to a question, don't share false information."""

    prompt = ""
    if dialog[0]["role"] != "system":
        dialog = [
            {
                "role": "system",
                "content": DEFAULT_SYSTEM_PROMPT,
            }
        ] + dialog
    for message in dialog:
        prompt += f"{message['role'].upper()}: {message['content']}\n"
    prompt += "ASSISTANT:"
    return prompt


def _mock_streaming_response(tokens: str) -> Generator[str, None, None]:
    """Utility to mock a streaming response"""
    for token in tokens:
        yield token


class HuggingFaceInferenceCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        model_output = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
        )
        logger.debug(model_output)
        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
            model_stream = self.client.text_generation(
                filled_prompt,
                **_build_hf_inference_settings(
                    max_new_tokens=self.max_output_tokens, stream=True
                ),
            )
        else:
            model_output = self.client.text_generation(
                filled_prompt,
                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
            )
            logger.debug(model_output)
            model_stream = _mock_streaming_response(model_output)
        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class HuggingFaceInferenceChatCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @staticmethod
    def convert_dialog_to_conversational_format(
        dialog: list[dict[str, str]]
    ) -> tuple[str, list[str], list[str]]:
        if dialog[-1]["role"] != "user":
            raise Exception(
                "Last message in conversational dialog must be User message"
            )
        user_message = dialog[-1]["content"]
        dialog = dialog[0:-1]
        generated_responses = []
        past_user_inputs = []
        for message in dialog:
            # HuggingFace inference client doesn't support system messages today
            # so lumping them in with user messages
            if message["role"] in ["user", "system"]:
                past_user_inputs += [message["content"]]
            else:
                generated_responses += [message["content"]]
        return user_message, generated_responses, past_user_inputs

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
            (
                message,
                generated_responses,
                past_user_inputs,
            ) = self.convert_dialog_to_conversational_format(filled_prompt)
            model_output = self.client.conversational(
                message,
                generated_responses=generated_responses,
                past_user_inputs=past_user_inputs,
                parameters={"max_length": self.max_output_tokens},
            )
        else:
            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
            logger.debug(chat_prompt)
            model_output = self.client.text_generation(
                chat_prompt,
                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
            )
        logger.debug(model_output)
        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
                raise Exception(
                    "Conversational API is not available with streaming enabled. Please either "
                    + "disable streaming, or disable using conversational API."
                )
            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
            logger.debug(chat_prompt)
            model_stream = self.client.text_generation(
                chat_prompt,
                **_build_hf_inference_settings(
                    max_new_tokens=self.max_output_tokens, stream=True
                ),
            )
        else:
            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
                (
                    message,
                    generated_responses,
                    past_user_inputs,
                ) = self.convert_dialog_to_conversational_format(filled_prompt)
                model_output = self.client.conversational(
                    message,
                    generated_responses=generated_responses,
                    past_user_inputs=past_user_inputs,
                    parameters={"max_length": self.max_output_tokens},
                )
            else:
                chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
                logger.debug(chat_prompt)
                model_output = self.client.text_generation(
                    chat_prompt,
                    **_build_hf_inference_settings(
                        max_new_tokens=self.max_output_tokens
                    ),
                )
            logger.debug(model_output)
            model_stream = _mock_streaming_response(model_output)
        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
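For reference, a minimal sketch of the raw huggingface_hub calls that the two classes above wrap. The model and token are placeholders; the parameters mirror _build_hf_inference_settings plus the max_new_tokens/stream arguments used in the file:

from huggingface_hub import InferenceClient

# Placeholder model/token; an Inference Endpoint or text-generation-inference
# server URL can be passed as `model` instead of a repo id.
client = InferenceClient(model="tiiuae/falcon-7b-instruct", token="hf_...")

prompt = "USER: What does Danswer do?\nASSISTANT:"

# Non-streaming call (roughly what answer_question does, minus prompt building
# and answer/quote parsing).
output = client.text_generation(prompt, do_sample=False, seed=69, max_new_tokens=512)
print(output)

# Streaming call (what answer_question_stream uses when streaming is enabled);
# yields generated tokens one at a time.
for token in client.text_generation(
    prompt, do_sample=False, seed=69, max_new_tokens=512, stream=True
):
    print(token, end="")

The conversational path (GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL) goes through client.conversational(...) instead, which is why the code above raises when that flag is combined with streaming.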
@@ -16,6 +16,7 @@ google-auth-oauthlib==1.0.0
 httpcore==0.16.3
 httpx==0.23.3
 httpx-oauth==0.11.2
+huggingface-hub==0.16.4
 jira==3.5.1
 Mako==1.2.4
 nltk==3.8.1