diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index a4aa72fb342c..508b8169cebe 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -44,6 +44,7 @@ class DanswerGenAIModel(str, Enum):
     HUGGINGFACE = "huggingface-client-completion"
     HUGGINGFACE_CHAT = "huggingface-client-chat-completion"
     REQUEST = "request-completion"
+    TRANSFORMERS = "transformers"
 
 
 class ModelHostType(str, Enum):
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index ba7d55247cd6..4de79252f75c 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -49,11 +49,15 @@ VERIFIED_MODELS = {
     # The "chat" model below is actually "instruction finetuned" and does not support conversational
     DanswerGenAIModel.HUGGINGFACE.value: ["meta-llama/Llama-2-70b-chat-hf"],
     DanswerGenAIModel.HUGGINGFACE_CHAT.value: ["meta-llama/Llama-2-70b-hf"],
+    # Created by Deepset.ai
+    # https://huggingface.co/deepset/deberta-v3-large-squad2
+    # Model provided with no modifications
+    DanswerGenAIModel.TRANSFORMERS.value: ["deepset/deberta-v3-large-squad2"],
 }
 
 # Sets the internal Danswer model class to use
 INTERNAL_MODEL_VERSION = os.environ.get(
-    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.OPENAI_CHAT.value
+    "INTERNAL_MODEL_VERSION", DanswerGenAIModel.TRANSFORMERS.value
 )
 
 # If the Generative AI model requires an API key for access, otherwise can leave blank
diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py
index cd25ae5aa4f7..92df801fa2ba 100644
--- a/backend/danswer/direct_qa/__init__.py
+++ b/backend/danswer/direct_qa/__init__.py
@@ -16,6 +16,7 @@ from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
 from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
 from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.local_transformers import TransformerQA
 from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
@@ -79,6 +80,8 @@ def get_default_backend_qa_model(
         return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
     elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
         return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
+    elif internal_model == DanswerGenAIModel.TRANSFORMERS.value:
+        return TransformerQA()
     elif internal_model == DanswerGenAIModel.REQUEST.value:
         if endpoint is None or model_host_type is None:
             raise ValueError(
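For reference, the new backend is selected the same way as the existing ones: INTERNAL_MODEL_VERSION picks the QAModel class and GEN_AI_MODEL_VERSION names the Hugging Face checkpoint it loads. A minimal configuration sketch, assuming GEN_AI_MODEL_VERSION is read from the environment by model_configs.py just like INTERNAL_MODEL_VERSION (its definition is not shown in this diff):

import os

# Select the local extractive QA backend added in this change
# ("transformers" is DanswerGenAIModel.TRANSFORMERS.value).
# Must be set before danswer.configs.model_configs is imported.
os.environ["INTERNAL_MODEL_VERSION"] = "transformers"

# Checkpoint loaded by the question-answering pipeline; matches the entry
# added to VERIFIED_MODELS above. Assumes this setting is also sourced from
# the environment.
os.environ["GEN_AI_MODEL_VERSION"] = "deepset/deberta-v3-large-squad2"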
diff --git a/backend/danswer/direct_qa/local_transformers.py b/backend/danswer/direct_qa/local_transformers.py
new file mode 100644
index 000000000000..17e986a847e3
--- /dev/null
+++ b/backend/danswer/direct_qa/local_transformers.py
@@ -0,0 +1,147 @@
+import re
+from collections.abc import Generator
+from typing import Any
+
+from transformers import pipeline  # type:ignore
+from transformers import QuestionAnsweringPipeline  # type:ignore
+
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.qa_utils import structure_quotes_for_response
+from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time
+
+
+logger = setup_logger()
+
+TRANSFORMER_DEFAULT_MAX_CONTEXT = 512
+
+_TRANSFORMER_MODEL: QuestionAnsweringPipeline | None = None
+
+
+def get_default_transformer_model(
+    model_version: str = GEN_AI_MODEL_VERSION,
+    max_context: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
+) -> QuestionAnsweringPipeline:
+    global _TRANSFORMER_MODEL
+    if _TRANSFORMER_MODEL is None:
+        _TRANSFORMER_MODEL = pipeline(
+            "question-answering", model=model_version, max_seq_len=max_context
+        )
+
+    return _TRANSFORMER_MODEL
+
+
+def find_extended_answer(answer: str, context: str) -> str:
+    """Try to extend the answer by matching it within the context text and extending
+    before and after the quote up to a termination character"""
+    result = re.search(
+        r"(^|\n\r?|\.)(?P<content>[^\n]{{0,250}}{}[^\n]{{0,250}})(\.|$|\n\r?)".format(
+            re.escape(answer)
+        ),
+        context,
+        flags=re.MULTILINE | re.DOTALL,
+    )
+    if result:
+        return result.group("content")
+
+    return answer
+
+
+class TransformerQA(QAModel):
+    @staticmethod
+    def _answer_one_chunk(
+        query: str,
+        chunk: InferenceChunk,
+        max_context_len: int = TRANSFORMER_DEFAULT_MAX_CONTEXT,
+        max_cutoff: float = 0.9,
+        min_cutoff: float = 0.5,
+    ) -> tuple[str | None, DanswerQuote | None]:
+        """Because this type of QA model only takes a single chunk of context with a fairly
+        small token limit, we have to iterate through the chunks and check whether an answer
+        is found in any one of them. This approach does not allow for interpolating answers
+        across chunks."""
+        model = get_default_transformer_model()
+        model_out = model(question=query, context=chunk.content, max_answer_len=128)
+
+        answer = model_out.get("answer")
+        confidence = model_out.get("score")
+
+        if answer is None:
+            return None, None
+
+        logger.info(f"Transformer Answer: {answer}")
+        logger.debug(f"Transformer Confidence: {confidence}")
+
+        # The model tends to be overconfident on short chunks,
+        # so the minimum score required increases as chunk size decreases.
+        # If the score is at least 0.9, it's good enough regardless.
+        # A default minimum of 0.5 is always required.
+        score_cutoff = max(
+            min(max_cutoff, 1 - len(chunk.content) / max_context_len), min_cutoff
+        )
+        if confidence < score_cutoff:
+            return None, None
+
+        extended_answer = find_extended_answer(answer, chunk.content)
+
+        danswer_quote = DanswerQuote(
+            quote=answer,
+            document_id=chunk.document_id,
+            link=chunk.source_links[0] if chunk.source_links else None,
+            source_type=chunk.source_type,
+            semantic_identifier=chunk.semantic_identifier,
+            blurb=chunk.blurb,
+        )
+
+        return extended_answer, danswer_quote
+
+    def warm_up_model(self) -> None:
+        get_default_transformer_model()
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        danswer_quotes: list[DanswerQuote] = []
+        d_answers: list[str] = []
+        for chunk in context_docs:
+            answer, quote = self._answer_one_chunk(query, chunk)
+            if answer is not None and quote is not None:
+                d_answers.append(answer)
+                danswer_quotes.append(quote)
+
+        answers_list = [
+            f"Answer {ind}: {answer.strip()}"
+            for ind, answer in enumerate(d_answers, start=1)
+        ]
+        combined_answer = "\n".join(answers_list)
+        return DanswerAnswer(answer=combined_answer), danswer_quotes
+
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        quotes: list[DanswerQuote] = []
+        answers: list[str] = []
+        for chunk in context_docs:
+            answer, quote = self._answer_one_chunk(query, chunk)
+            if answer is not None and quote is not None:
+                answers.append(answer)
+                quotes.append(quote)
+
+        # Delay the output of the answers so there isn't a long gap between the first answer and the quotes
+        answer_count = 1
+        for answer in answers:
+            if answer_count == 1:
+                yield {"answer_data": "Source 1: "}
+            else:
+                yield {"answer_data": f"\nSource {answer_count}: "}
+            answer_count += 1
+            for char in answer.strip():
+                yield {"answer_data": char}
+
+        yield {"answer_finished": True}
+
+        yield structure_quotes_for_response(quotes)
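The dynamic confidence cutoff in _answer_one_chunk can be checked in isolation. A standalone sketch of the same arithmetic, using illustrative chunk lengths rather than real model output:

max_cutoff, min_cutoff, max_context_len = 0.9, 0.5, 512

for chunk_len in (64, 256, 512):
    # Same formula as _answer_one_chunk: shorter chunks demand higher confidence,
    # capped at max_cutoff and floored at min_cutoff.
    score_cutoff = max(min(max_cutoff, 1 - chunk_len / max_context_len), min_cutoff)
    print(chunk_len, score_cutoff)
# 64  -> 0.875  (short chunk, higher confidence required)
# 256 -> 0.5    (1 - 256/512)
# 512 -> 0.5    (min_cutoff floor applies)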
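A hypothetical consumer of the streaming protocol, to illustrate the packets the generator yields. The query and retrieved_chunks values are placeholders, and the shape of the final quotes packet is whatever structure_quotes_for_response returns (defined in qa_utils, not in this diff):

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.local_transformers import TransformerQA

query = "How do I configure the local transformers backend?"
retrieved_chunks: list[InferenceChunk] = []  # placeholder; normally produced by retrieval

qa = TransformerQA()
qa.warm_up_model()  # loads the question-answering pipeline once

for packet in qa.answer_question_stream(query, retrieved_chunks):
    if packet is None:
        continue
    if "answer_data" in packet:
        # "Source N: " headers and the per-character answer text
        print(packet["answer_data"], end="")
    elif packet.get("answer_finished"):
        print()  # all answers streamed; the quotes packet follows
    else:
        quote_packet = packet  # output of structure_quotes_for_response(quotes)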