Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-06-03 03:31:09 +02:00
Add LangChain-based LLM
commit 4469447fde
parent 20b6369eea
.gitignore (vendored, 1 change)
@@ -1 +1,2 @@
 .env
+.DS_store
@@ -6,7 +6,7 @@ from danswer.datastores.document_index import get_default_document_index
 from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -73,7 +73,7 @@ def answer_question(
         )

     try:
-        qa_model = get_default_llm(timeout=answer_generation_timeout)
+        qa_model = get_default_qa_model(timeout=answer_generation_timeout)
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -9,6 +9,8 @@ from danswer.configs.constants import ModelHostType
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@@ -17,12 +19,16 @@ from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
 from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.local_transformers import TransformerQA
-from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA
+from danswer.direct_qa.qa_block import JsonChatQAHandler
+from danswer.direct_qa.qa_block import QABlock
+from danswer.direct_qa.qa_block import QAHandler
+from danswer.direct_qa.qa_block import SimpleChatQAHandler
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.direct_qa.request_model import RequestCompletionQA
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.build import get_default_llm
 from danswer.utils.logger import setup_logger


 logger = setup_logger()
@@ -32,7 +38,7 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     if not model_api_key:
         return False

-    qa_model = get_default_llm(api_key=model_api_key, timeout=5)
+    qa_model = get_default_qa_model(api_key=model_api_key, timeout=5)

     # try for up to 2 timeouts (e.g. 10 seconds in total)
     for _ in range(2):
@@ -47,12 +53,21 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     return False


-def get_default_llm(
+def get_default_qa_handler(model: str) -> QAHandler:
+    if model == DanswerGenAIModel.OPENAI_CHAT.value:
+        return JsonChatQAHandler()
+
+    return SimpleChatQAHandler()
+
+
+def get_default_qa_model(
     internal_model: str = INTERNAL_MODEL_VERSION,
+    model_version: str = GEN_AI_MODEL_VERSION,
     endpoint: str | None = GEN_AI_ENDPOINT,
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
+    max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:
@@ -61,6 +76,31 @@ def get_default_llm(
         except ConfigNotFoundError:
             pass

+    try:
+        # un-used arguments will be ignored by the underlying `LLM` class
+        # if any args are missing, a `TypeError` will be thrown
+        llm = get_default_llm(
+            model=internal_model,
+            api_key=api_key,
+            model_version=model_version,
+            endpoint=endpoint,
+            model_host_type=model_host_type,
+            timeout=timeout,
+            max_output_tokens=max_output_tokens,
+            **kwargs,
+        )
+        qa_handler = get_default_qa_handler(model=internal_model)
+
+        return QABlock(
+            llm=llm,
+            qa_handler=qa_handler,
+        )
+    except:
+        logger.exception(
+            "Unable to build a QABlock with the new approach, going back to the "
+            "legacy approach"
+        )
+
     if internal_model in [
         DanswerGenAIModel.GPT4ALL.value,
         DanswerGenAIModel.GPT4ALL_CHAT.value,
@@ -70,8 +110,6 @@ def get_default_llm(

     if internal_model == DanswerGenAIModel.OPENAI.value:
         return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == DanswerGenAIModel.GPT4ALL.value:
         return GPT4AllCompletionQA(**kwargs)
     elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
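Not part of the diff: a minimal usage sketch of the renamed factory above. It assumes the configured INTERNAL_MODEL_VERSION is an OpenAI chat model, and retrieved_chunks stands in for a list of InferenceChunk objects coming out of retrieval.

    from danswer.direct_qa.llm_utils import get_default_qa_model

    # With the new code path this returns a QABlock wrapping a LangChain-backed LLM;
    # if that construction fails it falls back to the legacy QAModel classes.
    qa_model = get_default_qa_model(timeout=10)
    answer, quotes = qa_model.answer_question("How are connectors configured?", retrieved_chunks)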
@@ -25,9 +25,6 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_prompts import ChatPromptProcessor
-from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
-from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -207,107 +204,3 @@ class OpenAICompletionQA(OpenAIQAModel):
             context_docs=context_docs,
             is_json_prompt=self.prompt_processor.specifies_json_output,
         )
-
-
-class OpenAIChatCompletionQA(OpenAIQAModel):
-    def __init__(
-        self,
-        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        timeout: int | None = None,
-        reflexion_try_count: int = 0,
-        api_key: str | None = None,
-        include_metadata: bool = INCLUDE_METADATA,
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.model_version = model_version
-        self.max_output_tokens = max_output_tokens
-        self.reflexion_try_count = reflexion_try_count
-        self.timeout = timeout
-        self.include_metadata = include_metadata
-        self.api_key = api_key
-
-    @staticmethod
-    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
-        for event in response:
-            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
-            if (
-                "content" not in event_dict
-            ):  # could be a role message or empty termination
-                continue
-            yield event_dict["content"]
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-    ) -> AnswerQuestionReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        messages = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(json.dumps(messages, indent=4))
-        model_output = ""
-        for _ in range(self.reflexion_try_count + 1):
-            openai_call = _handle_openai_exceptions_wrapper(
-                openai_call=openai.ChatCompletion.create,
-                query=query,
-            )
-            response = openai_call(
-                **_build_openai_settings(
-                    api_key=_ensure_openai_api_key(self.api_key),
-                    messages=messages,
-                    model=self.model_version,
-                    max_tokens=self.max_output_tokens,
-                    request_timeout=self.timeout,
-                ),
-            )
-            model_output = cast(
-                str, response["choices"][0]["message"]["content"]
-            ).strip()
-            assistant_msg = {"content": model_output, "role": "assistant"}
-            messages.extend([assistant_msg, get_json_chat_reflexion_msg()])
-            logger.info(
-                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
-            )
-
-        logger.debug(model_output)
-
-        answer, quotes = process_answer(model_output, context_docs)
-        return answer, quotes
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        messages = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(json.dumps(messages, indent=4))
-
-        openai_call = _handle_openai_exceptions_wrapper(
-            openai_call=openai.ChatCompletion.create,
-            query=query,
-        )
-        response = openai_call(
-            **_build_openai_settings(
-                api_key=_ensure_openai_api_key(self.api_key),
-                messages=messages,
-                model=self.model_version,
-                max_tokens=self.max_output_tokens,
-                request_timeout=self.timeout,
-                stream=True,
-            ),
-        )
-
-        tokens = self._generate_tokens_from_response(response)
-
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
backend/danswer/direct_qa/qa_block.py (new file, 176 lines)
@@ -0,0 +1,176 @@
+import abc
+import json
+from collections.abc import Iterator
+from copy import copy
+
+import tiktoken
+from langchain.schema.messages import AIMessage
+from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import HumanMessage
+from langchain.schema.messages import SystemMessage
+
+from danswer.chunking.models import InferenceChunk
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
+from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.interfaces import DanswerQuotes
+from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.qa_prompts import JsonChatProcessor
+from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
+from danswer.direct_qa.qa_utils import process_model_tokens
+from danswer.llm.llm import LLM
+
+
+def _dict_based_prompt_to_langchain_prompt(
+    messages: list[dict[str, str]]
+) -> list[BaseMessage]:
+    prompt: list[BaseMessage] = []
+    for message in messages:
+        role = message.get("role")
+        content = message.get("content")
+        if not role:
+            raise ValueError(f"Message missing `role`: {message}")
+        if not content:
+            raise ValueError(f"Message missing `content`: {message}")
+        elif role == "user":
+            prompt.append(HumanMessage(content=content))
+        elif role == "system":
+            prompt.append(SystemMessage(content=content))
+        elif role == "assistant":
+            prompt.append(AIMessage(content=content))
+        else:
+            raise ValueError(f"Unknown role: {role}")
+    return prompt
+
+
+def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
+    return [HumanMessage(content=message)]
+
+
+class QAHandler(abc.ABC):
+    """Evolution of the `PromptProcessor` - handles both building the prompt and
+    processing the response. These are neccessarily coupled, since the prompt determines
+    the response format (and thus how it should be parsed into an answer + quotes)."""
+
+    @abc.abstractmethod
+    def build_prompt(
+        self, query: str, context_chunks: list[InferenceChunk]
+    ) -> list[BaseMessage]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def process_response(
+        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
+    ) -> AnswerQuestionStreamReturn:
+        raise NotImplementedError
+
+
+class JsonChatQAHandler(QAHandler):
+    def build_prompt(
+        self, query: str, context_chunks: list[InferenceChunk]
+    ) -> list[BaseMessage]:
+        return _dict_based_prompt_to_langchain_prompt(
+            JsonChatProcessor.fill_prompt(
+                question=query, chunks=context_chunks, include_metadata=False
+            )
+        )
+
+    def process_response(
+        self,
+        tokens: Iterator[str],
+        context_chunks: list[InferenceChunk],
+    ) -> AnswerQuestionStreamReturn:
+        yield from process_model_tokens(
+            tokens=tokens,
+            context_docs=context_chunks,
+            is_json_prompt=True,
+        )
+
+
+class SimpleChatQAHandler(QAHandler):
+    def build_prompt(
+        self, query: str, context_chunks: list[InferenceChunk]
+    ) -> list[BaseMessage]:
+        return _str_prompt_to_langchain_prompt(
+            WeakModelFreeformProcessor.fill_prompt(
+                question=query,
+                chunks=context_chunks,
+                include_metadata=False,
+            )
+        )
+
+    def process_response(
+        self,
+        tokens: Iterator[str],
+        context_chunks: list[InferenceChunk],
+    ) -> AnswerQuestionStreamReturn:
+        yield from process_model_tokens(
+            tokens=tokens,
+            context_docs=context_chunks,
+            is_json_prompt=False,
+        )
+
+
+def _tiktoken_trim_chunks(
+    chunks: list[InferenceChunk], max_chunk_toks: int = 512
+) -> list[InferenceChunk]:
+    """Edit chunks that have too high token count. Generally due to parsing issues or
+    characters from another language that are 1 char = 1 token
+    Trimming by tokens leads to information loss but currently no better way of handling
+    NOTE: currently gpt-3.5 / gpt-4 tokenizer across all LLMs currently
+    TODO: make "chunk modification" its own step in the pipeline
+    """
+    encoder = tiktoken.get_encoding("cl100k_base")
+    new_chunks = copy(chunks)
+    for ind, chunk in enumerate(new_chunks):
+        tokens = encoder.encode(chunk.content)
+        if len(tokens) > max_chunk_toks:
+            new_chunk = copy(chunk)
+            new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
+            new_chunks[ind] = new_chunk
+    return new_chunks
+
+
+class QABlock(QAModel):
+    def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
+        self._llm = llm
+        self._qa_handler = qa_handler
+
+    def warm_up_model(self) -> None:
+        """This is called during server start up to load the models into memory
+        in case the chosen LLM is not accessed via API"""
+        self._llm.stream("Ignore this!")
+
+    def answer_question(
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+    ) -> AnswerQuestionReturn:
+        trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
+        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
+        tokens = self._llm.stream(prompt)
+
+        final_answer = ""
+        quotes = DanswerQuotes([])
+        for output in self._qa_handler.process_response(tokens, trimmed_context_docs):
+            if output is None:
+                continue
+
+            if isinstance(output, DanswerAnswerPiece):
+                if output.answer_piece:
+                    final_answer += output.answer_piece
+            elif isinstance(output, DanswerQuotes):
+                quotes = output
+
+        return DanswerAnswer(final_answer), quotes
+
+    def answer_question_stream(
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+    ) -> AnswerQuestionStreamReturn:
+        trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
+        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
+        tokens = self._llm.stream(prompt)
+        yield from self._qa_handler.process_response(tokens, trimmed_context_docs)
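A hedged sketch (not in the commit) of composing a QABlock by hand, mirroring what get_default_qa_model does internally; the API key and model version are placeholder values, and retrieved_chunks again stands in for the InferenceChunk list produced by retrieval.

    from danswer.direct_qa.qa_block import JsonChatQAHandler, QABlock
    from danswer.llm.openai import OpenAIGPT

    llm = OpenAIGPT(
        api_key="<openai-key>",  # placeholder
        max_output_tokens=512,
        timeout=10,
        model_version="gpt-3.5-turbo",
    )
    qa_block = QABlock(llm=llm, qa_handler=JsonChatQAHandler())

    # The streaming variant yields DanswerAnswerPiece objects followed by DanswerQuotes.
    for packet in qa_block.answer_question_stream("How are connectors configured?", retrieved_chunks):
        print(packet)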
@@ -2,6 +2,7 @@ import json
 import math
 import re
 from collections.abc import Generator
+from collections.abc import Iterator
 from typing import cast
 from typing import Optional
 from typing import Tuple
@@ -191,7 +192,7 @@ def extract_quotes_from_completed_token_stream(


 def process_model_tokens(
-    tokens: Generator[str, None, None],
+    tokens: Iterator[str],
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
backend/danswer/llm/azure.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+from typing import Any
+
+from langchain.chat_models.azure_openai import AzureChatOpenAI
+
+from danswer.configs.model_configs import API_BASE_OPENAI
+from danswer.configs.model_configs import API_VERSION_OPENAI
+from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
+from danswer.llm.llm import LangChainChatLLM
+from danswer.llm.utils import should_be_verbose
+
+
+class AzureGPT(LangChainChatLLM):
+    def __init__(
+        self,
+        api_key: str,
+        max_output_tokens: int,
+        timeout: int,
+        model_version: str,
+        api_base: str = API_BASE_OPENAI,
+        api_version: str = API_VERSION_OPENAI,
+        deployment_name: str = AZURE_DEPLOYMENT_ID,
+        *args: list[Any],
+        **kwargs: dict[str, Any]
+    ):
+        self._llm = AzureChatOpenAI(
+            model=model_version,
+            openai_api_type="azure",
+            openai_api_base=api_base,
+            openai_api_version=api_version,
+            deployment_name=deployment_name,
+            openai_api_key=api_key,
+            max_tokens=max_output_tokens,
+            temperature=0,
+            request_timeout=timeout,
+            model_kwargs={
+                "top_p": 1,
+                "frequency_penalty": 0,
+                "presence_penalty": 0,
+            },
+            verbose=should_be_verbose(),
+        )
+
+    @property
+    def llm(self) -> AzureChatOpenAI:
+        return self._llm
backend/danswer/llm/build.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+from typing import Any
+
+from danswer.configs.constants import DanswerGenAIModel
+from danswer.configs.model_configs import API_TYPE_OPENAI
+from danswer.llm.azure import AzureGPT
+from danswer.llm.llm import LLM
+from danswer.llm.openai import OpenAIGPT
+
+
+def get_default_llm(model: str, **kwargs: Any) -> LLM:
+    if model == DanswerGenAIModel.OPENAI_CHAT.value:
+        if API_TYPE_OPENAI == "azure":
+            return AzureGPT(**kwargs)
+        return OpenAIGPT(**kwargs)
+
+    raise ValueError(f"Unknown LLM model: {model}")
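Not from the diff: a sketch of calling the low-level factory directly. Since OpenAIGPT and AzureGPT accept *args/**kwargs, surplus keyword arguments such as endpoint or model_host_type are silently ignored, which is what the try/except in llm_utils relies on; the key and model values below are placeholders.

    from danswer.configs.constants import DanswerGenAIModel
    from danswer.llm.build import get_default_llm

    llm = get_default_llm(
        model=DanswerGenAIModel.OPENAI_CHAT.value,
        api_key="<openai-key>",      # placeholder
        model_version="gpt-3.5-turbo",
        max_output_tokens=512,
        timeout=10,
        endpoint=None,               # ignored by the OpenAI/Azure wrappers
        model_host_type=None,        # ignored by the OpenAI/Azure wrappers
    )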
backend/danswer/llm/google_colab_demo.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import json
+from collections.abc import Iterator
+from typing import Any
+
+import requests
+from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.messages import BaseMessageChunk
+from requests import Timeout
+
+from danswer.llm.llm import LLM
+from danswer.llm.utils import convert_input
+
+
+class GoogleColabDemo(LLM):
+    def __init__(
+        self,
+        endpoint: str,
+        max_output_tokens: int,
+        timeout: int,
+        *args: list[Any],
+        **kwargs: dict[str, Any],
+    ):
+        self._endpoint = endpoint
+        self._max_output_tokens = max_output_tokens
+        self._timeout = timeout
+
+    def _execute(self, input: LanguageModelInput) -> str:
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "inputs": convert_input(input),
+            "parameters": {
+                "temperature": 0.0,
+                "max_tokens": self._max_output_tokens,
+            },
+        }
+        try:
+            response = requests.post(
+                self._endpoint, headers=headers, json=data, timeout=self._timeout
+            )
+        except Timeout as error:
+            raise Timeout(f"Model inference to {self._endpoint} timed out") from error
+
+        response.raise_for_status()
+        return json.loads(response.content).get("generated_text", "")
+
+    def invoke(self, input: LanguageModelInput) -> str:
+        return self._execute(input)
+
+    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+        yield self._execute(input)
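Not part of the diff: the HTTP contract GoogleColabDemo assumes, shown as a usage sketch; the tunnel URL is a placeholder.

    from danswer.llm.google_colab_demo import GoogleColabDemo

    llm = GoogleColabDemo(
        endpoint="https://<your-colab-tunnel>",  # placeholder URL
        max_output_tokens=256,
        timeout=60,
    )
    # POSTs {"inputs": "<prompt text>", "parameters": {"temperature": 0.0, "max_tokens": 256}}
    # and expects a JSON reply containing {"generated_text": "..."}.
    print(llm.invoke("Hello from Danswer"))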
backend/danswer/llm/llm.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+import abc
+from collections.abc import Iterator
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema.language_model import LanguageModelInput
+
+from danswer.llm.utils import message_generator_to_string_generator
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()
+
+
+class LLM(abc.ABC):
+    """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
+    to use these implementations to connect to a variety of LLM providers."""
+
+    @abc.abstractmethod
+    def invoke(self, input: LanguageModelInput) -> str:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+        raise NotImplementedError
+
+
+class LangChainChatLLM(LLM, abc.ABC):
+    @property
+    @abc.abstractmethod
+    def llm(self) -> BaseChatModel:
+        raise NotImplementedError
+
+    def _log_model_config(self) -> None:
+        logger.debug(
+            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
+        )
+
+    def invoke(self, input: LanguageModelInput) -> str:
+        self._log_model_config()
+        return self.llm.invoke(input).content
+
+    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+        self._log_model_config()
+        yield from message_generator_to_string_generator(self.llm.stream(input))
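A hypothetical example, not in the commit: because LLM only requires invoke and stream, a trivial implementation can serve as a test double for anything that consumes the interface (such as QABlock).

    from collections.abc import Iterator

    from langchain.schema.language_model import LanguageModelInput

    from danswer.llm.llm import LLM
    from danswer.llm.utils import convert_input


    class EchoLLM(LLM):
        """Hypothetical stand-in that just echoes the rendered prompt back."""

        def invoke(self, input: LanguageModelInput) -> str:
            return convert_input(input)

        def stream(self, input: LanguageModelInput) -> Iterator[str]:
            yield convert_input(input)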
backend/danswer/llm/openai.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from typing import Any
+
+from langchain.chat_models.openai import ChatOpenAI
+
+from danswer.llm.llm import LangChainChatLLM
+from danswer.llm.utils import should_be_verbose
+
+
+class OpenAIGPT(LangChainChatLLM):
+    def __init__(
+        self,
+        api_key: str,
+        max_output_tokens: int,
+        timeout: int,
+        model_version: str,
+        *args: list[Any],
+        **kwargs: dict[str, Any]
+    ):
+        self._llm = ChatOpenAI(
+            model=model_version,
+            openai_api_key=api_key,
+            max_tokens=max_output_tokens,
+            temperature=0,
+            request_timeout=timeout,
+            model_kwargs={
+                "top_p": 1,
+                "frequency_penalty": 0,
+                "presence_penalty": 0,
+            },
+            verbose=should_be_verbose(),
+        )
+
+    @property
+    def llm(self) -> ChatOpenAI:
+        return self._llm
backend/danswer/llm/utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from collections.abc import Iterator
+
+from langchain.prompts.base import StringPromptValue
+from langchain.prompts.chat import ChatPromptValue
+from langchain.schema import (
+    PromptValue,
+)
+from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.messages import BaseMessageChunk
+
+from danswer.configs.app_configs import LOG_LEVEL
+
+
+def message_generator_to_string_generator(
+    messages: Iterator[BaseMessageChunk],
+) -> Iterator[str]:
+    for message in messages:
+        yield message.content
+
+
+def convert_input(input: LanguageModelInput) -> str:
+    """Heavily inspired by:
+    https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
+    """
+    prompt_value = None
+    if isinstance(input, PromptValue):
+        prompt_value = input
+    elif isinstance(input, str):
+        prompt_value = StringPromptValue(text=input)
+    elif isinstance(input, list):
+        prompt_value = ChatPromptValue(messages=input)
+
+    if prompt_value is None:
+        raise ValueError(
+            f"Invalid input type {type(input)}. "
+            "Must be a PromptValue, str, or list of BaseMessages."
+        )
+
+    return prompt_value.to_string()
+
+
+def should_be_verbose() -> bool:
+    return LOG_LEVEL == "debug"
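Not from the diff: a quick illustration of what convert_input does with the accepted input shapes; the exact rendered strings come from LangChain's to_string() implementations.

    from langchain.schema.messages import HumanMessage, SystemMessage

    from danswer.llm.utils import convert_input

    convert_input("plain prompt")  # returned unchanged as a string
    convert_input([SystemMessage(content="Be terse."), HumanMessage(content="Hi")])
    # rendered as a single "Role: content" transcript string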
@@ -29,7 +29,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.datastores.document_index import get_default_document_index
 from danswer.db.credentials import create_initial_public_credential
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.server.credential import router as credential_router
 from danswer.server.event_loading import router as event_processing_router
 from danswer.server.health import router as health_router
@@ -178,7 +178,7 @@ def get_application() -> FastAPI:

         logger.info("Warming up local NLP models.")
         warm_up_models()
-        qa_model = get_default_llm()
+        qa_model = get_default_qa_model()
         qa_model.warm_up_model()

         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
@@ -57,7 +57,7 @@ from danswer.db.index_attempt import get_latest_index_attempts
 from danswer.db.models import DeletionAttempt
 from danswer.db.models import User
 from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.direct_qa.open_ai import get_gen_ai_api_key
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -423,7 +423,7 @@ def validate_existing_genai_api_key(
 ) -> None:
     # OpenAI key is only used for generative QA, so no need to validate this
     # if it's turned off or if a non-OpenAI model is being used
-    if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key:
+    if DISABLE_GENERATIVE_AI or not get_default_qa_model().requires_api_key:
         return

     # Only validate every so often
@@ -15,7 +15,7 @@ from danswer.db.models import User
 from danswer.direct_qa.answer_question import answer_question
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -174,7 +174,7 @@ def stream_direct_qa(
             return

         try:
-            qa_model = get_default_llm()
+            qa_model = get_default_qa_model()
         except (UnknownModelError, OpenAIKeyMissing) as e:
             logger.exception("Unable to get QA model")
             yield get_json_line({"error": str(e)})
@@ -199,6 +199,7 @@
         except Exception as e:
             # exception is logged in the answer_question method, no need to re-log
             yield get_json_line({"error": str(e)})
+            logger.exception("Failed to run QA")

         return

@@ -20,6 +20,7 @@ httpx==0.23.3
 httpx-oauth==0.11.2
 huggingface-hub==0.16.4
 jira==3.5.1
+langchain==0.0.273
 Mako==1.2.4
 nltk==3.8.1
 docx2txt==0.8