Mirror of https://github.com/danswer-ai/danswer.git
Add LangChain-based LLM
commit 4469447fde (parent: 20b6369eea)
.gitignore (vendored, 3 changes)

@@ -1 +1,2 @@
 .env
+.DS_store
@@ -6,7 +6,7 @@ from danswer.datastores.document_index import get_default_document_index
 from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -73,7 +73,7 @@ def answer_question(
     )

     try:
-        qa_model = get_default_llm(timeout=answer_generation_timeout)
+        qa_model = get_default_qa_model(timeout=answer_generation_timeout)
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -9,6 +9,8 @@ from danswer.configs.constants import ModelHostType
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@@ -17,12 +19,16 @@ from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
 from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.local_transformers import TransformerQA
-from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA
+from danswer.direct_qa.qa_block import JsonChatQAHandler
+from danswer.direct_qa.qa_block import QABlock
+from danswer.direct_qa.qa_block import QAHandler
+from danswer.direct_qa.qa_block import SimpleChatQAHandler
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.direct_qa.request_model import RequestCompletionQA
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.build import get_default_llm
 from danswer.utils.logger import setup_logger

 logger = setup_logger()
@@ -32,7 +38,7 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     if not model_api_key:
         return False

-    qa_model = get_default_llm(api_key=model_api_key, timeout=5)
+    qa_model = get_default_qa_model(api_key=model_api_key, timeout=5)

     # try for up to 2 timeouts (e.g. 10 seconds in total)
     for _ in range(2):
@@ -47,12 +53,21 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     return False


-def get_default_llm(
+def get_default_qa_handler(model: str) -> QAHandler:
+    if model == DanswerGenAIModel.OPENAI_CHAT.value:
+        return JsonChatQAHandler()
+
+    return SimpleChatQAHandler()
+
+
+def get_default_qa_model(
     internal_model: str = INTERNAL_MODEL_VERSION,
     model_version: str = GEN_AI_MODEL_VERSION,
     endpoint: str | None = GEN_AI_ENDPOINT,
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
+    max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:
@@ -61,6 +76,31 @@ def get_default_llm(
         except ConfigNotFoundError:
             pass

+    try:
+        # un-used arguments will be ignored by the underlying `LLM` class
+        # if any args are missing, a `TypeError` will be thrown
+        llm = get_default_llm(
+            model=internal_model,
+            api_key=api_key,
+            model_version=model_version,
+            endpoint=endpoint,
+            model_host_type=model_host_type,
+            timeout=timeout,
+            max_output_tokens=max_output_tokens,
+            **kwargs,
+        )
+        qa_handler = get_default_qa_handler(model=internal_model)
+
+        return QABlock(
+            llm=llm,
+            qa_handler=qa_handler,
+        )
+    except:
+        logger.exception(
+            "Unable to build a QABlock with the new approach, going back to the "
+            "legacy approach"
+        )
+
     if internal_model in [
         DanswerGenAIModel.GPT4ALL.value,
         DanswerGenAIModel.GPT4ALL_CHAT.value,
@@ -70,8 +110,6 @@ def get_default_llm(

     if internal_model == DanswerGenAIModel.OPENAI.value:
         return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == DanswerGenAIModel.GPT4ALL.value:
         return GPT4AllCompletionQA(**kwargs)
     elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
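
For reference, the new code path in `get_default_qa_model` amounts to wrapping a LangChain-backed `LLM` in a `QABlock` together with a prompt/response handler. A minimal sketch of that wiring, assuming the OpenAI chat model is selected (the API key, model version, and timeout below are placeholder values, not defaults from this commit):

    # Illustrative sketch only: this mirrors what get_default_qa_model() assembles internally.
    from danswer.direct_qa.qa_block import JsonChatQAHandler
    from danswer.direct_qa.qa_block import QABlock
    from danswer.llm.openai import OpenAIGPT

    llm = OpenAIGPT(
        api_key="sk-...",  # placeholder
        max_output_tokens=512,
        timeout=10,
        model_version="gpt-3.5-turbo",  # placeholder
    )
    qa_model = QABlock(llm=llm, qa_handler=JsonChatQAHandler())

    # context_docs would normally come from the retrieval pipeline:
    # answer, quotes = qa_model.answer_question(query="What is Danswer?", context_docs=docs)
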
@@ -25,9 +25,6 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_prompts import ChatPromptProcessor
-from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
-from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -207,107 +204,3 @@ class OpenAICompletionQA(OpenAIQAModel):
             context_docs=context_docs,
             is_json_prompt=self.prompt_processor.specifies_json_output,
         )
-
-
-class OpenAIChatCompletionQA(OpenAIQAModel):
-    def __init__(
-        self,
-        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        timeout: int | None = None,
-        reflexion_try_count: int = 0,
-        api_key: str | None = None,
-        include_metadata: bool = INCLUDE_METADATA,
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.model_version = model_version
-        self.max_output_tokens = max_output_tokens
-        self.reflexion_try_count = reflexion_try_count
-        self.timeout = timeout
-        self.include_metadata = include_metadata
-        self.api_key = api_key
-
-    @staticmethod
-    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
-        for event in response:
-            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
-            if (
-                "content" not in event_dict
-            ):  # could be a role message or empty termination
-                continue
-            yield event_dict["content"]
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-    ) -> AnswerQuestionReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        messages = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(json.dumps(messages, indent=4))
-        model_output = ""
-        for _ in range(self.reflexion_try_count + 1):
-            openai_call = _handle_openai_exceptions_wrapper(
-                openai_call=openai.ChatCompletion.create,
-                query=query,
-            )
-            response = openai_call(
-                **_build_openai_settings(
-                    api_key=_ensure_openai_api_key(self.api_key),
-                    messages=messages,
-                    model=self.model_version,
-                    max_tokens=self.max_output_tokens,
-                    request_timeout=self.timeout,
-                ),
-            )
-            model_output = cast(
-                str, response["choices"][0]["message"]["content"]
-            ).strip()
-            assistant_msg = {"content": model_output, "role": "assistant"}
-            messages.extend([assistant_msg, get_json_chat_reflexion_msg()])
-            logger.info(
-                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
-            )
-
-        logger.debug(model_output)
-
-        answer, quotes = process_answer(model_output, context_docs)
-        return answer, quotes
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        messages = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(json.dumps(messages, indent=4))
-
-        openai_call = _handle_openai_exceptions_wrapper(
-            openai_call=openai.ChatCompletion.create,
-            query=query,
-        )
-        response = openai_call(
-            **_build_openai_settings(
-                api_key=_ensure_openai_api_key(self.api_key),
-                messages=messages,
-                model=self.model_version,
-                max_tokens=self.max_output_tokens,
-                request_timeout=self.timeout,
-                stream=True,
-            ),
-        )
-
-        tokens = self._generate_tokens_from_response(response)
-
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
backend/danswer/direct_qa/qa_block.py (new file, 176 lines)

@@ -0,0 +1,176 @@
import abc
import json
from collections.abc import Iterator
from copy import copy

import tiktoken
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.llm.llm import LLM


def _dict_based_prompt_to_langchain_prompt(
    messages: list[dict[str, str]]
) -> list[BaseMessage]:
    prompt: list[BaseMessage] = []
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if not role:
            raise ValueError(f"Message missing `role`: {message}")
        if not content:
            raise ValueError(f"Message missing `content`: {message}")
        elif role == "user":
            prompt.append(HumanMessage(content=content))
        elif role == "system":
            prompt.append(SystemMessage(content=content))
        elif role == "assistant":
            prompt.append(AIMessage(content=content))
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt


def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
    return [HumanMessage(content=message)]


class QAHandler(abc.ABC):
    """Evolution of the `PromptProcessor` - handles both building the prompt and
    processing the response. These are necessarily coupled, since the prompt determines
    the response format (and thus how it should be parsed into an answer + quotes)."""

    @abc.abstractmethod
    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        raise NotImplementedError

    @abc.abstractmethod
    def process_response(
        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        raise NotImplementedError


class JsonChatQAHandler(QAHandler):
    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        return _dict_based_prompt_to_langchain_prompt(
            JsonChatProcessor.fill_prompt(
                question=query, chunks=context_chunks, include_metadata=False
            )
        )

    def process_response(
        self,
        tokens: Iterator[str],
        context_chunks: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        yield from process_model_tokens(
            tokens=tokens,
            context_docs=context_chunks,
            is_json_prompt=True,
        )


class SimpleChatQAHandler(QAHandler):
    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        return _str_prompt_to_langchain_prompt(
            WeakModelFreeformProcessor.fill_prompt(
                question=query,
                chunks=context_chunks,
                include_metadata=False,
            )
        )

    def process_response(
        self,
        tokens: Iterator[str],
        context_chunks: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        yield from process_model_tokens(
            tokens=tokens,
            context_docs=context_chunks,
            is_json_prompt=False,
        )


def _tiktoken_trim_chunks(
    chunks: list[InferenceChunk], max_chunk_toks: int = 512
) -> list[InferenceChunk]:
    """Edit chunks that have too high token count. Generally due to parsing issues or
    characters from another language that are 1 char = 1 token
    Trimming by tokens leads to information loss but currently no better way of handling
    NOTE: currently uses the gpt-3.5 / gpt-4 tokenizer across all LLMs
    TODO: make "chunk modification" its own step in the pipeline
    """
    encoder = tiktoken.get_encoding("cl100k_base")
    new_chunks = copy(chunks)
    for ind, chunk in enumerate(new_chunks):
        tokens = encoder.encode(chunk.content)
        if len(tokens) > max_chunk_toks:
            new_chunk = copy(chunk)
            new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
            new_chunks[ind] = new_chunk
    return new_chunks


class QABlock(QAModel):
    def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
        self._llm = llm
        self._qa_handler = qa_handler

    def warm_up_model(self) -> None:
        """This is called during server start up to load the models into memory
        in case the chosen LLM is not accessed via API"""
        self._llm.stream("Ignore this!")

    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
    ) -> AnswerQuestionReturn:
        trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
        tokens = self._llm.stream(prompt)

        final_answer = ""
        quotes = DanswerQuotes([])
        for output in self._qa_handler.process_response(tokens, trimmed_context_docs):
            if output is None:
                continue

            if isinstance(output, DanswerAnswerPiece):
                if output.answer_piece:
                    final_answer += output.answer_piece
            elif isinstance(output, DanswerQuotes):
                quotes = output

        return DanswerAnswer(final_answer), quotes

    def answer_question_stream(
        self,
        query: str,
        context_docs: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
        prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
        tokens = self._llm.stream(prompt)
        yield from self._qa_handler.process_response(tokens, trimmed_context_docs)
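
The `QAHandler` abstraction makes the prompt format and its parser pluggable, so `QABlock` never needs to know whether the model returns JSON or free text. As a rough illustration of extending it (not part of this commit; `EchoQAHandler` is a made-up name and the constructor calls below are inferred from how `QABlock` consumes these types):

    from collections.abc import Iterator

    from langchain.schema.messages import BaseMessage
    from langchain.schema.messages import HumanMessage

    from danswer.chunking.models import InferenceChunk
    from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
    from danswer.direct_qa.interfaces import DanswerAnswerPiece
    from danswer.direct_qa.interfaces import DanswerQuotes
    from danswer.direct_qa.qa_block import QAHandler


    class EchoQAHandler(QAHandler):
        """Hypothetical handler: free-form prompt, streams raw model output
        as answer pieces and reports no quotes."""

        def build_prompt(
            self, query: str, context_chunks: list[InferenceChunk]
        ) -> list[BaseMessage]:
            context = "\n\n".join(chunk.content for chunk in context_chunks)
            return [HumanMessage(content=f"{context}\n\nQuestion: {query}")]

        def process_response(
            self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
        ) -> AnswerQuestionStreamReturn:
            for token in tokens:
                yield DanswerAnswerPiece(answer_piece=token)
            yield DanswerQuotes([])
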
@@ -2,6 +2,7 @@ import json
 import math
 import re
 from collections.abc import Generator
+from collections.abc import Iterator
 from typing import cast
 from typing import Optional
 from typing import Tuple
@@ -191,7 +192,7 @@ def extract_quotes_from_completed_token_stream(


 def process_model_tokens(
-    tokens: Generator[str, None, None],
+    tokens: Iterator[str],
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
backend/danswer/llm/azure.py (new file, 45 lines)

@@ -0,0 +1,45 @@
from typing import Any

from langchain.chat_models.azure_openai import AzureChatOpenAI

from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


class AzureGPT(LangChainChatLLM):
    def __init__(
        self,
        api_key: str,
        max_output_tokens: int,
        timeout: int,
        model_version: str,
        api_base: str = API_BASE_OPENAI,
        api_version: str = API_VERSION_OPENAI,
        deployment_name: str = AZURE_DEPLOYMENT_ID,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
        self._llm = AzureChatOpenAI(
            model=model_version,
            openai_api_type="azure",
            openai_api_base=api_base,
            openai_api_version=api_version,
            deployment_name=deployment_name,
            openai_api_key=api_key,
            max_tokens=max_output_tokens,
            temperature=0,
            request_timeout=timeout,
            model_kwargs={
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
            },
            verbose=should_be_verbose(),
        )

    @property
    def llm(self) -> AzureChatOpenAI:
        return self._llm
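
A hedged construction example for the Azure path (every value below is a placeholder; in the real flow these come from `GEN_AI_API_KEY`, `API_BASE_OPENAI`, `API_VERSION_OPENAI`, and `AZURE_DEPLOYMENT_ID` via `get_default_llm`):

    from danswer.llm.azure import AzureGPT

    llm = AzureGPT(
        api_key="<azure-openai-key>",  # placeholder
        max_output_tokens=512,
        timeout=10,
        model_version="gpt-35-turbo",  # placeholder
        api_base="https://<your-resource>.openai.azure.com",  # placeholder
        api_version="2023-05-15",  # placeholder
        deployment_name="<your-deployment>",  # placeholder
    )
    print(llm.invoke("Say hello"))
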
backend/danswer/llm/build.py (new file, 16 lines)

@@ -0,0 +1,16 @@
from typing import Any

from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.llm.azure import AzureGPT
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT


def get_default_llm(model: str, **kwargs: Any) -> LLM:
    if model == DanswerGenAIModel.OPENAI_CHAT.value:
        if API_TYPE_OPENAI == "azure":
            return AzureGPT(**kwargs)
        return OpenAIGPT(**kwargs)

    raise ValueError(f"Unknown LLM model: {model}")
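
A short usage sketch of this factory (keyword values are placeholders; per the note in `llm_utils.py`, keyword arguments that a given `LLM` subclass does not need are simply swallowed by its `**kwargs`):

    from danswer.configs.constants import DanswerGenAIModel
    from danswer.llm.build import get_default_llm

    llm = get_default_llm(
        model=DanswerGenAIModel.OPENAI_CHAT.value,
        api_key="sk-...",  # placeholder
        model_version="gpt-3.5-turbo",  # placeholder
        timeout=10,
        max_output_tokens=512,
    )

    print(llm.invoke("Say hello"))         # single blocking completion
    for token in llm.stream("Say hello"):  # streamed completion
        print(token, end="")
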
backend/danswer/llm/google_colab_demo.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import json
from collections.abc import Iterator
from typing import Any

import requests
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk
from requests import Timeout

from danswer.llm.llm import LLM
from danswer.llm.utils import convert_input


class GoogleColabDemo(LLM):
    def __init__(
        self,
        endpoint: str,
        max_output_tokens: int,
        timeout: int,
        *args: list[Any],
        **kwargs: dict[str, Any],
    ):
        self._endpoint = endpoint
        self._max_output_tokens = max_output_tokens
        self._timeout = timeout

    def _execute(self, input: LanguageModelInput) -> str:
        headers = {
            "Content-Type": "application/json",
        }

        data = {
            "inputs": convert_input(input),
            "parameters": {
                "temperature": 0.0,
                "max_tokens": self._max_output_tokens,
            },
        }
        try:
            response = requests.post(
                self._endpoint, headers=headers, json=data, timeout=self._timeout
            )
        except Timeout as error:
            raise Timeout(f"Model inference to {self._endpoint} timed out") from error

        response.raise_for_status()
        return json.loads(response.content).get("generated_text", "")

    def invoke(self, input: LanguageModelInput) -> str:
        return self._execute(input)

    def stream(self, input: LanguageModelInput) -> Iterator[str]:
        yield self._execute(input)
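
A hedged usage sketch (the endpoint URL is a placeholder, e.g. an ngrok tunnel exposed from a Colab notebook). The POST body follows `_execute` above, and the reply is expected to contain a `generated_text` field:

    from danswer.llm.google_colab_demo import GoogleColabDemo

    llm = GoogleColabDemo(
        endpoint="https://example-1234.ngrok.io",  # placeholder
        max_output_tokens=256,
        timeout=60,
    )
    print(llm.invoke("What does Danswer do?"))
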
backend/danswer/llm/llm.py (new file, 44 lines)

@@ -0,0 +1,44 @@
import abc
from collections.abc import Iterator

from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput

from danswer.llm.utils import message_generator_to_string_generator
from danswer.utils.logger import setup_logger


logger = setup_logger()


class LLM(abc.ABC):
    """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
    to use these implementations to connect to a variety of LLM providers."""

    @abc.abstractmethod
    def invoke(self, input: LanguageModelInput) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def stream(self, input: LanguageModelInput) -> Iterator[str]:
        raise NotImplementedError


class LangChainChatLLM(LLM, abc.ABC):
    @property
    @abc.abstractmethod
    def llm(self) -> BaseChatModel:
        raise NotImplementedError

    def _log_model_config(self) -> None:
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    def invoke(self, input: LanguageModelInput) -> str:
        self._log_model_config()
        return self.llm.invoke(input).content

    def stream(self, input: LanguageModelInput) -> Iterator[str]:
        self._log_model_config()
        yield from message_generator_to_string_generator(self.llm.stream(input))
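
Any backend only has to satisfy `invoke` and `stream`. As a hedged illustration (not in this commit), a canned-response implementation that could serve as a test double:

    from collections.abc import Iterator

    from langchain.schema.language_model import LanguageModelInput

    from danswer.llm.llm import LLM


    class CannedLLM(LLM):
        """Hypothetical test double: ignores the prompt and replays a fixed answer."""

        def __init__(self, response: str) -> None:
            self._response = response

        def invoke(self, input: LanguageModelInput) -> str:
            return self._response

        def stream(self, input: LanguageModelInput) -> Iterator[str]:
            # emit word by word to exercise streaming consumers
            for word in self._response.split(" "):
                yield word + " "
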
backend/danswer/llm/openai.py (new file, 35 lines)

@@ -0,0 +1,35 @@
from typing import Any

from langchain.chat_models.openai import ChatOpenAI

from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose


class OpenAIGPT(LangChainChatLLM):
    def __init__(
        self,
        api_key: str,
        max_output_tokens: int,
        timeout: int,
        model_version: str,
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
        self._llm = ChatOpenAI(
            model=model_version,
            openai_api_key=api_key,
            max_tokens=max_output_tokens,
            temperature=0,
            request_timeout=timeout,
            model_kwargs={
                "top_p": 1,
                "frequency_penalty": 0,
                "presence_penalty": 0,
            },
            verbose=should_be_verbose(),
        )

    @property
    def llm(self) -> ChatOpenAI:
        return self._llm
backend/danswer/llm/utils.py (new file, 43 lines)

@@ -0,0 +1,43 @@
from collections.abc import Iterator

from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
    PromptValue,
)
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk

from danswer.configs.app_configs import LOG_LEVEL


def message_generator_to_string_generator(
    messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
    for message in messages:
        yield message.content


def convert_input(input: LanguageModelInput) -> str:
    """Heavily inspired by:
    https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
    """
    prompt_value = None
    if isinstance(input, PromptValue):
        prompt_value = input
    elif isinstance(input, str):
        prompt_value = StringPromptValue(text=input)
    elif isinstance(input, list):
        prompt_value = ChatPromptValue(messages=input)

    if prompt_value is None:
        raise ValueError(
            f"Invalid input type {type(input)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )

    return prompt_value.to_string()


def should_be_verbose() -> bool:
    return LOG_LEVEL == "debug"
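
A small example of `convert_input` flattening the accepted input shapes into the single prompt string used by non-chat backends such as `GoogleColabDemo` (message contents are placeholders):

    from langchain.schema.messages import HumanMessage
    from langchain.schema.messages import SystemMessage

    from danswer.llm.utils import convert_input

    # a plain string passes through unchanged
    print(convert_input("What is Danswer?"))

    # a chat-style message list is flattened into a single transcript string
    print(
        convert_input(
            [
                SystemMessage(content="You are a helpful assistant."),
                HumanMessage(content="What is Danswer?"),
            ]
        )
    )
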
@@ -29,7 +29,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.datastores.document_index import get_default_document_index
 from danswer.db.credentials import create_initial_public_credential
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.server.credential import router as credential_router
 from danswer.server.event_loading import router as event_processing_router
 from danswer.server.health import router as health_router
@@ -178,7 +178,7 @@ def get_application() -> FastAPI:

         logger.info("Warming up local NLP models.")
         warm_up_models()
-        qa_model = get_default_llm()
+        qa_model = get_default_qa_model()
         qa_model.warm_up_model()

         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
@@ -57,7 +57,7 @@ from danswer.db.index_attempt import get_latest_index_attempts
 from danswer.db.models import DeletionAttempt
 from danswer.db.models import User
 from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.direct_qa.open_ai import get_gen_ai_api_key
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -423,7 +423,7 @@ def validate_existing_genai_api_key(
 ) -> None:
     # OpenAI key is only used for generative QA, so no need to validate this
     # if it's turned off or if a non-OpenAI model is being used
-    if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key:
+    if DISABLE_GENERATIVE_AI or not get_default_qa_model().requires_api_key:
         return

     # Only validate every so often
@@ -15,7 +15,7 @@ from danswer.db.models import User
 from danswer.direct_qa.answer_question import answer_question
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -174,7 +174,7 @@ def stream_direct_qa(
             return

         try:
-            qa_model = get_default_llm()
+            qa_model = get_default_qa_model()
         except (UnknownModelError, OpenAIKeyMissing) as e:
             logger.exception("Unable to get QA model")
             yield get_json_line({"error": str(e)})
@@ -199,6 +199,7 @@ def stream_direct_qa(
         except Exception as e:
             # exception is logged in the answer_question method, no need to re-log
             yield get_json_line({"error": str(e)})
+            logger.exception("Failed to run QA")

         return
@@ -20,6 +20,7 @@ httpx==0.23.3
 httpx-oauth==0.11.2
 huggingface-hub==0.16.4
 jira==3.5.1
+langchain==0.0.273
 Mako==1.2.4
 nltk==3.8.1
 docx2txt==0.8