Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-25 11:16:43 +02:00)
add simple local llm (#202)
A very simple local LLM. Not as good as OpenAI, but it works as a drop-in replacement for on-premise deployments.

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
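For context, a minimal sketch of how a deployment might opt into the new model. INTERNAL_MODEL_VERSION is read via os.environ.get in the config hunk below; whether GEN_AI_MODEL_VERSION is read from the environment the same way is an assumption, since this diff only shows it being imported:

import os

# Select the local extractive QA model instead of the OpenAI default.
os.environ["INTERNAL_MODEL_VERSION"] = "transformers"
# Assumed env wiring: point the model version at the verified checkpoint.
os.environ["GEN_AI_MODEL_VERSION"] = "deepset/deberta-v3-large-squad2"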
@@ -44,6 +44,7 @@ class DanswerGenAIModel(str, Enum):
     HUGGINGFACE = "huggingface-client-completion"
     HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
     REQUEST = "request-completion"
+    TRANSFORMERS = "transformers"


 class ModelHostType(str, Enum):
@@ -49,11 +49,15 @@ VERIFIED_MODELS = {
     # The "chat" model below is actually "instruction finetuned" and does not support conversational
     DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
     DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
+    # Created by Deepset.ai
+    # https://huggingface.co/deepset/deberta-v3-large-squad2
+    # Model provided with no modifications
+    DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"],
 }

 # Sets the internal Danswer model class to use
 INTERNAL_MODEL_VERSION = os.environ.get(
-    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
+    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.TRANSFORMERS.value
 )

 # If the Generative AI model requires an API key for access, otherwise can leave blank
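The verified checkpoint is a standard extractive QA model, so it can be sanity-checked directly with the Hugging Face pipeline API. A quick sketch (the question and context strings are made up; max_seq_len and max_answer_len mirror the values used in the new file below):

from transformers import pipeline

qa = pipeline(
    "question-answering",
    model="deepset/deberta-v3-large-squad2",
    max_seq_len=512,
)
out = qa(
    question="Where can Danswer be deployed?",
    context="Danswer supports fully on-premise deployments for teams with strict data requirements.",
    max_answer_len=128,
)
print(out["answer"], out["score"])  # extracted span plus a confidence in [0, 1]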
@@ -16,6 +16,7 @@ from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
 from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
 from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.local_transformers import TransformerQA
 from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
@@ -79,6 +80,8 @@ def get_default_backend_qa_model(
         return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
     elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
         return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
+    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
+        return TransformerQA()
     elif internal_model == DanswerGenAIModel.REQUEST.value:
         if endpoint is None or model_host_type is None:
             raise ValueError(
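One subtlety in the dispatch above: the new branch compares internal_model against DanswerGenAIModel.TRANSFORMERS itself, while the neighboring branches use .value. Because DanswerGenAIModel mixes in str, both forms match; a self-contained sketch:

from enum import Enum

class Model(str, Enum):  # stand-in for DanswerGenAIModel
    TRANSFORMERS = "transformers"

# str-mixin enum members compare equal to their raw string values,
# so the missing .value is harmless here:
assert "transformers" == Model.TRANSFORMERS
assert "transformers" == Model.TRANSFORMERS.value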
backend/danswer/direct_qa/local_transformers.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import re
from collections.abc import Generator
from typing import Any

from transformers import pipeline  # type:ignore
from transformers import QuestionAnsweringPipeline  # type:ignore

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_utils import structure_quotes_for_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()

TRANSFORMER_DEFAULT_MAX_CONTEXT = 512

_TRANSFORMER_MODEL: QuestionAnsweringPipeline | None = None


def get_default_transformer_model(
    model_version: str = GEN_AI_MODEL_VERSION,
    max_context: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
) -> QuestionAnsweringPipeline:
    global _TRANSFORMER_MODEL
    if _TRANSFORMER_MODEL is None:
        _TRANSFORMER_MODEL = pipeline(
            "question-answering", model=model_version, max_seq_len=max_context
        )

    return _TRANSFORMER_MODEL
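

# Note: the module-level cache above means the Hugging Face pipeline is built
# once per process; TransformerQA.warm_up_model() below triggers it at startup
# so the first user query does not pay the model-loading cost.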
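

# For intuition on the regex in find_extended_answer: with answer "Paris" and
# context "Intro.\nThe capital of France is Paris, which sits on the Seine.",
# it anchors on a newline or period, allows up to 250 non-newline characters
# on each side of the quote, and returns the whole containing sentence.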
def find_extended_answer(answer: str, context: str) -> str:
    """Try to extend the answer by matching across the context text and extending before
    and after the quote to some termination character"""
    result = re.search(
        r"(^|\n\r?|\.)(?P<content>[^\n]{{0,250}}{}[^\n]{{0,250}})(\.|$|\n\r?)".format(
            re.escape(answer)
        ),
        context,
        flags=re.MULTILINE | re.DOTALL,
    )
    if result:
        return result.group("content")

    return answer


class TransformerQA(QAModel):
    @staticmethod
    def _answer_one_chunk(
        query: str,
        chunk: InferenceChunk,
        max_context_len: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
        max_cutoff: float = 0.9,
        min_cutoff: float = 0.5,
    ) -> tuple[str | None, DanswerQuote | None]:
        """Because this type of QA model only takes one chunk of context with a
        fairly small token limit, we have to iterate over the chunks and check
        whether the answer is found in any of them. This approach does not allow
        for interpolating answers across chunks.
        """
        model = get_default_transformer_model()
        model_out = model(question=query, context=chunk.content, max_answer_len=128)

        answer = model_out.get("answer")
        confidence = model_out.get("score")

        if answer is None:
            return None, None

        logger.info(f"Transformer Answer: {answer}")
        logger.debug(f"Transformer Confidence: {confidence}")

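        # For intuition: score_cutoff = max(min(0.9, 1 - len/512), 0.5), so a
        # 64-character chunk needs confidence >= 0.875, while any chunk of 256+
        # characters only needs to clear the 0.5 floor.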
        # Model tends to be overconfident on short chunks
        # so min score required increases as chunk size decreases
        # If it's at least 0.9, then it's good enough regardless
        # Default minimum of 0.5 required
        score_cutoff = max(
            min(max_cutoff, 1 - len(chunk.content) / max_context_len), min_cutoff
        )
        if confidence < score_cutoff:
            return None, None

        extended_answer = find_extended_answer(answer, chunk.content)

        danswer_quote = DanswerQuote(
            quote=answer,
            document_id=chunk.document_id,
            link=chunk.source_links[0] if chunk.source_links else None,
            source_type=chunk.source_type,
            semantic_identifier=chunk.semantic_identifier,
            blurb=chunk.blurb,
        )

        return extended_answer, danswer_quote

    def warm_up_model(self) -> None:
        get_default_transformer_model()

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        danswer_quotes: list[DanswerQuote] = []
        d_answers: list[str] = []
        for chunk in context_docs:
            answer, quote = self._answer_one_chunk(query, chunk)
            if answer is not None and quote is not None:
                d_answers.append(answer)
                danswer_quotes.append(quote)

        answers_list = [
            f"Answer {ind}: {answer.strip()}"
            for ind, answer in enumerate(d_answers, start=1)
        ]
        combined_answer = "\n".join(answers_list)
        return DanswerAnswer(answer=combined_answer), danswer_quotes
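
    # The streaming variant below yields the answer text character by character
    # as {"answer_data": ...} dicts, then an {"answer_finished": True} marker,
    # and finally the structured quotes in a single payload.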
    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        quotes: list[DanswerQuote] = []
        answers: list[str] = []
        for chunk in context_docs:
            answer, quote = self._answer_one_chunk(query, chunk)
            if answer is not None and quote is not None:
                answers.append(answer)
                quotes.append(quote)

        # Delay the output of the answers so there isn't a long gap between the
        # first answer and the quotes
        answer_count = 1
        for answer in answers:
            if answer_count == 1:
                yield {"answer_data": "Source 1: "}
            else:
                yield {"answer_data": f"\nSource {answer_count}: "}
            answer_count += 1
            for char in answer.strip():
                yield {"answer_data": char}

        yield {"answer_finished": True}

        yield structure_quotes_for_response(quotes)
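To tie the pieces together, a minimal consumption sketch; `chunks` stands in for retrieved InferenceChunk objects and is assumed to come from the existing retrieval flow, which this diff does not change:

from danswer.direct_qa.local_transformers import TransformerQA

qa_model = TransformerQA()
qa_model.warm_up_model()  # load the pipeline before serving the first query

# Non-streaming: one combined answer string plus structured quotes
answer, quotes = qa_model.answer_question("How is Danswer deployed?", chunks)

# Streaming: per-character answer packets, a finished marker, then quotes
for packet in qa_model.answer_question_stream("How is Danswer deployed?", chunks):
    print(packet)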