From 2a556965456d2d7554154df9ea2c15b381033f83 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Wed, 4 Dec 2024 16:30:47 -0800 Subject: [PATCH 1/5] Move Answer (#3339) --- .../danswer/{llm/answering => chat}/answer.py | 26 +-- backend/danswer/chat/chat_utils.py | 2 +- .../llm_response_handler.py | 48 +----- backend/danswer/chat/models.py | 117 +++++++++++++ backend/danswer/chat/process_message.py | 12 +- .../prompts => chat/prompt_builder}/build.py | 26 ++- .../prompt_builder}/citations_prompt.py | 2 +- .../prompt_builder}/quotes_prompt.py | 2 +- backend/danswer/chat/prompt_builder/utils.py | 62 +++++++ .../answering => chat}/prune_and_merge.py | 6 +- .../answer_response_handler.py | 8 +- .../stream_processing/citation_processing.py | 2 +- .../stream_processing/quotes_processing.py | 0 .../stream_processing/utils.py | 0 .../tool_handling}/tool_response_handler.py | 4 +- backend/danswer/context/search/pipeline.py | 8 +- backend/danswer/llm/answering/models.py | 163 ------------------ .../danswer/llm/answering/prompts/utils.py | 20 --- backend/danswer/llm/models.py | 59 +++++++ backend/danswer/llm/utils.py | 39 ----- backend/danswer/prompts/prompt_utils.py | 2 +- .../secondary_llm_flows/choose_search.py | 4 +- .../secondary_llm_flows/query_expansion.py | 2 +- .../danswer/server/features/persona/api.py | 2 +- .../server/query_and_chat/chat_backend.py | 6 +- .../danswer/server/query_and_chat/models.py | 4 + backend/danswer/tools/base_tool.py | 2 +- backend/danswer/tools/tool.py | 4 +- backend/danswer/tools/tool_constructor.py | 8 +- .../custom/custom_tool.py | 4 +- .../images/image_generation_tool.py | 4 +- .../internet_search/internet_search_tool.py | 8 +- .../search/search_tool.py | 20 +-- .../search_like_tool_utils.py | 12 +- backend/danswer/tools/tool_runner.py | 2 +- backend/danswer/tools/tool_selection.py | 2 +- .../{llm/answering => chat}/conftest.py | 8 +- .../test_citation_processing.py | 6 +- .../test_quotes_processing.py} | 8 +- .../{llm/answering => chat}/test_answer.py | 10 +- .../test_prune_and_merge.py | 2 +- .../answering => chat}/test_skip_gen_ai.py | 8 +- 42 files changed, 364 insertions(+), 370 deletions(-) rename backend/danswer/{llm/answering => chat}/answer.py (93%) rename backend/danswer/{llm/answering => chat}/llm_response_handler.py (53%) rename backend/danswer/{llm/answering/prompts => chat/prompt_builder}/build.py (85%) rename backend/danswer/{llm/answering/prompts => chat/prompt_builder}/citations_prompt.py (99%) rename backend/danswer/{llm/answering/prompts => chat/prompt_builder}/quotes_prompt.py (97%) create mode 100644 backend/danswer/chat/prompt_builder/utils.py rename backend/danswer/{llm/answering => chat}/prune_and_merge.py (98%) rename backend/danswer/{llm/answering => chat}/stream_processing/answer_response_handler.py (92%) rename backend/danswer/{llm/answering => chat}/stream_processing/citation_processing.py (98%) rename backend/danswer/{llm/answering => chat}/stream_processing/quotes_processing.py (100%) rename backend/danswer/{llm/answering => chat}/stream_processing/utils.py (100%) rename backend/danswer/{llm/answering/tool => chat/tool_handling}/tool_response_handler.py (98%) delete mode 100644 backend/danswer/llm/answering/models.py delete mode 100644 backend/danswer/llm/answering/prompts/utils.py create mode 100644 backend/danswer/llm/models.py rename backend/tests/unit/danswer/{llm/answering => chat}/conftest.py (93%) rename backend/tests/unit/danswer/{llm/answering => chat}/stream_processing/test_citation_processing.py (98%) rename backend/tests/unit/danswer/{direct_qa/test_qa_utils.py => chat/stream_processing/test_quotes_processing.py} (97%) rename backend/tests/unit/danswer/{llm/answering => chat}/test_answer.py (97%) rename backend/tests/unit/danswer/{llm/answering => chat}/test_prune_and_merge.py (99%) rename backend/tests/unit/danswer/{llm/answering => chat}/test_skip_gen_ai.py (95%) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/chat/answer.py similarity index 93% rename from backend/danswer/llm/answering/answer.py rename to backend/danswer/chat/answer.py index f4aad9a49..d2db03186 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/chat/answer.py @@ -6,27 +6,27 @@ from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk from langchain_core.messages import ToolCall +from danswer.chat.llm_response_handler import LLMResponseHandlerManager from danswer.chat.models import AnswerQuestionPossibleReturn +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.file_store.utils import InMemoryChatFile -from danswer.llm.answering.llm_response_handler import LLMCall -from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.llm.answering.prompts.build import default_build_system_message -from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.stream_processing.answer_response_handler import ( +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder +from danswer.chat.prompt_builder.build import default_build_system_message +from danswer.chat.prompt_builder.build import default_build_user_message +from danswer.chat.prompt_builder.build import LLMCall +from danswer.chat.stream_processing.answer_response_handler import ( CitationResponseHandler, ) -from danswer.llm.answering.stream_processing.answer_response_handler import ( +from danswer.chat.stream_processing.answer_response_handler import ( DummyAnswerResponseHandler, ) -from danswer.llm.answering.stream_processing.utils import map_document_id_order -from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler +from danswer.chat.stream_processing.utils import map_document_id_order +from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler +from danswer.file_store.utils import InMemoryChatFile from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.natural_language_processing.utils import get_tokenizer from danswer.tools.force import ForceUseTool from danswer.tools.models import ToolResponse diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index ccb978e8b..eb63c6875 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -26,7 +26,7 @@ from danswer.db.models import Prompt from danswer.db.models import Tool from danswer.db.models import User from danswer.db.persona import get_prompts_by_ids -from danswer.llm.answering.models import PreviousMessage +from danswer.llm.models import PreviousMessage from danswer.natural_language_processing.utils import BaseTokenizer from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.tools.tool_implementations.custom.custom_tool import ( diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/chat/llm_response_handler.py similarity index 53% rename from backend/danswer/llm/answering/llm_response_handler.py rename to backend/danswer/chat/llm_response_handler.py index bfd83c7b4..ee3d3f930 100644 --- a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/chat/llm_response_handler.py @@ -1,58 +1,22 @@ from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterator -from typing import TYPE_CHECKING from langchain_core.messages import BaseMessage -from pydantic.v1 import BaseModel as BaseModel__v1 -from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import ResponsePart from danswer.chat.models import StreamStopInfo from danswer.chat.models import StreamStopReason -from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.tools.force import ForceUseTool -from danswer.tools.models import ToolCallFinalResult -from danswer.tools.models import ToolCallKickoff -from danswer.tools.models import ToolResponse -from danswer.tools.tool import Tool - - -if TYPE_CHECKING: - from danswer.llm.answering.stream_processing.answer_response_handler import ( - AnswerResponseHandler, - ) - from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler - - -ResponsePart = ( - DanswerAnswerPiece - | CitationInfo - | ToolCallKickoff - | ToolResponse - | ToolCallFinalResult - | StreamStopInfo -) - - -class LLMCall(BaseModel__v1): - prompt_builder: AnswerPromptBuilder - tools: list[Tool] - force_use_tool: ForceUseTool - files: list[InMemoryChatFile] - tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult] - using_tool_calling_llm: bool - - class Config: - arbitrary_types_allowed = True +from danswer.chat.prompt_builder.build import LLMCall +from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler +from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler class LLMResponseHandlerManager: def __init__( self, - tool_handler: "ToolResponseHandler", - answer_handler: "AnswerResponseHandler", + tool_handler: ToolResponseHandler, + answer_handler: AnswerResponseHandler, is_cancelled: Callable[[], bool], ): self.tool_handler = tool_handler diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index b65620516..213a5ed74 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -1,10 +1,14 @@ +from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from enum import Enum from typing import Any +from typing import TYPE_CHECKING from pydantic import BaseModel +from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType @@ -12,8 +16,15 @@ from danswer.context.search.enums import QueryFlow from danswer.context.search.enums import RecencyBiasSetting from danswer.context.search.enums import SearchType from danswer.context.search.models import RetrievalDocs +from danswer.llm.override_models import PromptOverride +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType +if TYPE_CHECKING: + from danswer.db.models import Prompt + class LlmDoc(BaseModel): """This contains the minimal set information for the LLM portion including citations""" @@ -210,3 +221,109 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn] class LLMMetricsContainer(BaseModel): prompt_tokens: int response_tokens: int + + +StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] + + +class DocumentPruningConfig(BaseModel): + max_chunks: int | None = None + max_window_percentage: float | None = None + max_tokens: int | None = None + # different pruning behavior is expected when the + # user manually selects documents they want to chat with + # e.g. we don't want to truncate each document to be no more + # than one chunk long + is_manually_selected_docs: bool = False + # If user specifies to include additional context Chunks for each match, then different pruning + # is used. As many Sections as possible are included, and the last Section is truncated + # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size. + # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be. + use_sections: bool = True + # If using tools, then we need to consider the tool length + tool_num_tokens: int = 0 + # If using a tool message to represent the docs, then we have to JSON serialize + # the document content, which adds to the token count. + using_tool_message: bool = False + + +class ContextualPruningConfig(DocumentPruningConfig): + num_chunk_multiple: int + + @classmethod + def from_doc_pruning_config( + cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig + ) -> "ContextualPruningConfig": + return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict()) + + +class CitationConfig(BaseModel): + all_docs_useful: bool = False + + +class QuotesConfig(BaseModel): + pass + + +class AnswerStyleConfig(BaseModel): + citation_config: CitationConfig | None = None + quotes_config: QuotesConfig | None = None + document_pruning_config: DocumentPruningConfig = Field( + default_factory=DocumentPruningConfig + ) + # forces the LLM to return a structured response, see + # https://platform.openai.com/docs/guides/structured-outputs/introduction + # right now, only used by the simple chat API + structured_response_format: dict | None = None + + @model_validator(mode="after") + def check_quotes_and_citation(self) -> "AnswerStyleConfig": + if self.citation_config is None and self.quotes_config is None: + raise ValueError( + "One of `citation_config` or `quotes_config` must be provided" + ) + + if self.citation_config is not None and self.quotes_config is not None: + raise ValueError( + "Only one of `citation_config` or `quotes_config` must be provided" + ) + + return self + + +class PromptConfig(BaseModel): + """Final representation of the Prompt configuration passed + into the `Answer` object.""" + + system_prompt: str + task_prompt: str + datetime_aware: bool + include_citations: bool + + @classmethod + def from_model( + cls, model: "Prompt", prompt_override: PromptOverride | None = None + ) -> "PromptConfig": + override_system_prompt = ( + prompt_override.system_prompt if prompt_override else None + ) + override_task_prompt = prompt_override.task_prompt if prompt_override else None + + return cls( + system_prompt=override_system_prompt or model.system_prompt, + task_prompt=override_task_prompt or model.task_prompt, + datetime_aware=model.datetime_aware, + include_citations=model.include_citations, + ) + + model_config = ConfigDict(frozen=True) + + +ResponsePart = ( + DanswerAnswerPiece + | CitationInfo + | ToolCallKickoff + | ToolResponse + | ToolCallFinalResult + | StreamStopInfo +) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index e9d29360e..41de3c04d 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -6,19 +6,24 @@ from typing import cast from sqlalchemy.orm import Session +from danswer.chat.answer import Answer from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import create_temporary_persona from danswer.chat.models import AllCitations +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import ChatDanswerBotResponse +from danswer.chat.models import CitationConfig from danswer.chat.models import CitationInfo from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import DanswerContexts +from danswer.chat.models import DocumentPruningConfig from danswer.chat.models import FileChatDisplay from danswer.chat.models import FinalUsedContextDocsResponse from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import MessageResponseIDInfo from danswer.chat.models import MessageSpecificCitations +from danswer.chat.models import PromptConfig from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.chat.models import StreamStopInfo @@ -58,15 +63,10 @@ from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor from danswer.file_store.utils import load_all_chat_files from danswer.file_store.utils import save_files_from_urls -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple +from danswer.llm.models import PreviousMessage from danswer.llm.utils import litellm_exception_to_error_msg from danswer.natural_language_processing.utils import get_tokenizer from danswer.server.query_and_chat.models import ChatMessageDetail diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/chat/prompt_builder/build.py similarity index 85% rename from backend/danswer/llm/answering/prompts/build.py rename to backend/danswer/chat/prompt_builder/build.py index fd44adbe3..b5a90197b 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/chat/prompt_builder/build.py @@ -4,20 +4,26 @@ from typing import cast from langchain_core.messages import BaseMessage from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage +from pydantic.v1 import BaseModel as BaseModel__v1 +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens +from danswer.chat.prompt_builder.utils import translate_history_to_basemessages from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens from danswer.llm.interfaces import LLMConfig +from danswer.llm.models import PreviousMessage from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_message_tokens from danswer.llm.utils import message_to_prompt_and_imgs -from danswer.llm.utils import translate_history_to_basemessages from danswer.natural_language_processing.utils import get_tokenizer from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import drop_messages_history_overflow +from danswer.tools.force import ForceUseTool +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool def default_build_system_message( @@ -139,3 +145,15 @@ class AnswerPromptBuilder: return drop_messages_history_overflow( final_messages_with_tokens, self.max_tokens ) + + +class LLMCall(BaseModel__v1): + prompt_builder: AnswerPromptBuilder + tools: list[Tool] + force_use_tool: ForceUseTool + files: list[InMemoryChatFile] + tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult] + using_tool_calling_llm: bool + + class Config: + arbitrary_types_allowed = True diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/chat/prompt_builder/citations_prompt.py similarity index 99% rename from backend/danswer/llm/answering/prompts/citations_prompt.py rename to backend/danswer/chat/prompt_builder/citations_prompt.py index 1ff48432b..a49dd25ae 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/chat/prompt_builder/citations_prompt.py @@ -2,12 +2,12 @@ from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.context.search.models import InferenceChunk from danswer.db.models import Persona from danswer.db.persona import get_default_prompt__read_only from danswer.db.search_settings import get_multilingual_expansion -from danswer.llm.answering.models import PromptConfig from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple from danswer.llm.interfaces import LLMConfig diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/chat/prompt_builder/quotes_prompt.py similarity index 97% rename from backend/danswer/llm/answering/prompts/quotes_prompt.py rename to backend/danswer/chat/prompt_builder/quotes_prompt.py index 00f22f9e7..fa51b571e 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/chat/prompt_builder/quotes_prompt.py @@ -1,10 +1,10 @@ from langchain.schema.messages import HumanMessage from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.context.search.models import InferenceChunk from danswer.db.search_settings import get_multilingual_expansion -from danswer.llm.answering.models import PromptConfig from danswer.llm.utils import message_to_prompt_and_imgs from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK diff --git a/backend/danswer/chat/prompt_builder/utils.py b/backend/danswer/chat/prompt_builder/utils.py new file mode 100644 index 000000000..6383be534 --- /dev/null +++ b/backend/danswer/chat/prompt_builder/utils.py @@ -0,0 +1,62 @@ +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage + +from danswer.configs.constants import MessageType +from danswer.db.models import ChatMessage +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.models import PreviousMessage +from danswer.llm.utils import build_content_with_imgs +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT + + +def build_dummy_prompt( + system_prompt: str, task_prompt: str, retrieval_disabled: bool +) -> str: + if retrieval_disabled: + return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + return PARAMATERIZED_PROMPT.format( + context_docs_str="", + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + +def translate_danswer_msg_to_langchain( + msg: ChatMessage | PreviousMessage, +) -> BaseMessage: + files: list[InMemoryChatFile] = [] + + # If the message is a `ChatMessage`, it doesn't have the downloaded files + # attached. Just ignore them for now. + if not isinstance(msg, ChatMessage): + files = msg.files + content = build_content_with_imgs(msg.message, files, message_type=msg.message_type) + + if msg.message_type == MessageType.SYSTEM: + raise ValueError("System messages are not currently part of history") + if msg.message_type == MessageType.ASSISTANT: + return AIMessage(content=content) + if msg.message_type == MessageType.USER: + return HumanMessage(content=content) + + raise ValueError(f"New message type {msg.message_type} not handled") + + +def translate_history_to_basemessages( + history: list[ChatMessage] | list["PreviousMessage"], +) -> tuple[list[BaseMessage], list[int]]: + history_basemessages = [ + translate_danswer_msg_to_langchain(msg) + for msg in history + if msg.token_count != 0 + ] + history_token_counts = [msg.token_count for msg in history if msg.token_count != 0] + return history_basemessages, history_token_counts diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/chat/prune_and_merge.py similarity index 98% rename from backend/danswer/llm/answering/prune_and_merge.py rename to backend/danswer/chat/prune_and_merge.py index 21ea2226d..0085793f8 100644 --- a/backend/danswer/llm/answering/prune_and_merge.py +++ b/backend/danswer/chat/prune_and_merge.py @@ -5,16 +5,16 @@ from typing import TypeVar from pydantic import BaseModel +from danswer.chat.models import ContextualPruningConfig from danswer.chat.models import ( LlmDoc, ) +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.context.search.models import InferenceChunk from danswer.context.search.models import InferenceSection -from danswer.llm.answering.models import ContextualPruningConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.llm.interfaces import LLMConfig from danswer.natural_language_processing.utils import get_tokenizer from danswer.natural_language_processing.utils import tokenizer_trim_content diff --git a/backend/danswer/llm/answering/stream_processing/answer_response_handler.py b/backend/danswer/chat/stream_processing/answer_response_handler.py similarity index 92% rename from backend/danswer/llm/answering/stream_processing/answer_response_handler.py rename to backend/danswer/chat/stream_processing/answer_response_handler.py index f0eb86ce2..8a8bda40d 100644 --- a/backend/danswer/llm/answering/stream_processing/answer_response_handler.py +++ b/backend/danswer/chat/stream_processing/answer_response_handler.py @@ -3,13 +3,11 @@ from collections.abc import Generator from langchain_core.messages import BaseMessage +from danswer.chat.llm_response_handler import ResponsePart from danswer.chat.models import CitationInfo from danswer.chat.models import LlmDoc -from danswer.llm.answering.llm_response_handler import ResponsePart -from danswer.llm.answering.stream_processing.citation_processing import ( - CitationProcessor, -) -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping +from danswer.chat.stream_processing.citation_processing import CitationProcessor +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping from danswer.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/chat/stream_processing/citation_processing.py similarity index 98% rename from backend/danswer/llm/answering/stream_processing/citation_processing.py rename to backend/danswer/chat/stream_processing/citation_processing.py index 2acaa98f8..5a50855e9 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/chat/stream_processing/citation_processing.py @@ -4,8 +4,8 @@ from collections.abc import Generator from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/chat/stream_processing/quotes_processing.py similarity index 100% rename from backend/danswer/llm/answering/stream_processing/quotes_processing.py rename to backend/danswer/chat/stream_processing/quotes_processing.py diff --git a/backend/danswer/llm/answering/stream_processing/utils.py b/backend/danswer/chat/stream_processing/utils.py similarity index 100% rename from backend/danswer/llm/answering/stream_processing/utils.py rename to backend/danswer/chat/stream_processing/utils.py diff --git a/backend/danswer/llm/answering/tool/tool_response_handler.py b/backend/danswer/chat/tool_handling/tool_response_handler.py similarity index 98% rename from backend/danswer/llm/answering/tool/tool_response_handler.py rename to backend/danswer/chat/tool_handling/tool_response_handler.py index db35663c4..5438aa225 100644 --- a/backend/danswer/llm/answering/tool/tool_response_handler.py +++ b/backend/danswer/chat/tool_handling/tool_response_handler.py @@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk from langchain_core.messages import BaseMessage from langchain_core.messages import ToolCall -from danswer.llm.answering.llm_response_handler import LLMCall -from danswer.llm.answering.llm_response_handler import ResponsePart +from danswer.chat.models import ResponsePart +from danswer.chat.prompt_builder.build import LLMCall from danswer.llm.interfaces import LLM from danswer.tools.force import ForceUseTool from danswer.tools.message import build_tool_message diff --git a/backend/danswer/context/search/pipeline.py b/backend/danswer/context/search/pipeline.py index 21c518348..527485140 100644 --- a/backend/danswer/context/search/pipeline.py +++ b/backend/danswer/context/search/pipeline.py @@ -5,7 +5,11 @@ from typing import cast from sqlalchemy.orm import Session +from danswer.chat.models import PromptConfig from danswer.chat.models import SectionRelevancePiece +from danswer.chat.prune_and_merge import _merge_sections +from danswer.chat.prune_and_merge import ChunkRange +from danswer.chat.prune_and_merge import merge_chunk_intervals from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE from danswer.context.search.enums import LLMEvaluationType from danswer.context.search.enums import QueryFlow @@ -27,10 +31,6 @@ from danswer.db.models import User from danswer.db.search_settings import get_current_search_settings from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import VespaChunkRequest -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prune_and_merge import _merge_sections -from danswer.llm.answering.prune_and_merge import ChunkRange -from danswer.llm.answering.prune_and_merge import merge_chunk_intervals from danswer.llm.interfaces import LLM from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section from danswer.utils.logger import setup_logger diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py deleted file mode 100644 index 03f72a096..000000000 --- a/backend/danswer/llm/answering/models.py +++ /dev/null @@ -1,163 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterator -from typing import TYPE_CHECKING - -from langchain.schema.messages import AIMessage -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import Field -from pydantic import model_validator - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.configs.constants import MessageType -from danswer.file_store.models import InMemoryChatFile -from danswer.llm.override_models import PromptOverride -from danswer.llm.utils import build_content_with_imgs -from danswer.tools.models import ToolCallFinalResult - -if TYPE_CHECKING: - from danswer.db.models import ChatMessage - from danswer.db.models import Prompt - - -StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] - - -class PreviousMessage(BaseModel): - """Simplified version of `ChatMessage`""" - - message: str - token_count: int - message_type: MessageType - files: list[InMemoryChatFile] - tool_call: ToolCallFinalResult | None - - @classmethod - def from_chat_message( - cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile] - ) -> "PreviousMessage": - message_file_ids = ( - [file["id"] for file in chat_message.files] if chat_message.files else [] - ) - return cls( - message=chat_message.message, - token_count=chat_message.token_count, - message_type=chat_message.message_type, - files=[ - file - for file in available_files - if str(file.file_id) in message_file_ids - ], - tool_call=ToolCallFinalResult( - tool_name=chat_message.tool_call.tool_name, - tool_args=chat_message.tool_call.tool_arguments, - tool_result=chat_message.tool_call.tool_result, - ) - if chat_message.tool_call - else None, - ) - - def to_langchain_msg(self) -> BaseMessage: - content = build_content_with_imgs(self.message, self.files) - if self.message_type == MessageType.USER: - return HumanMessage(content=content) - elif self.message_type == MessageType.ASSISTANT: - return AIMessage(content=content) - else: - return SystemMessage(content=content) - - -class DocumentPruningConfig(BaseModel): - max_chunks: int | None = None - max_window_percentage: float | None = None - max_tokens: int | None = None - # different pruning behavior is expected when the - # user manually selects documents they want to chat with - # e.g. we don't want to truncate each document to be no more - # than one chunk long - is_manually_selected_docs: bool = False - # If user specifies to include additional context Chunks for each match, then different pruning - # is used. As many Sections as possible are included, and the last Section is truncated - # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size. - # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be. - use_sections: bool = True - # If using tools, then we need to consider the tool length - tool_num_tokens: int = 0 - # If using a tool message to represent the docs, then we have to JSON serialize - # the document content, which adds to the token count. - using_tool_message: bool = False - - -class ContextualPruningConfig(DocumentPruningConfig): - num_chunk_multiple: int - - @classmethod - def from_doc_pruning_config( - cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig - ) -> "ContextualPruningConfig": - return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict()) - - -class CitationConfig(BaseModel): - all_docs_useful: bool = False - - -class QuotesConfig(BaseModel): - pass - - -class AnswerStyleConfig(BaseModel): - citation_config: CitationConfig | None = None - quotes_config: QuotesConfig | None = None - document_pruning_config: DocumentPruningConfig = Field( - default_factory=DocumentPruningConfig - ) - # forces the LLM to return a structured response, see - # https://platform.openai.com/docs/guides/structured-outputs/introduction - # right now, only used by the simple chat API - structured_response_format: dict | None = None - - @model_validator(mode="after") - def check_quotes_and_citation(self) -> "AnswerStyleConfig": - if self.citation_config is None and self.quotes_config is None: - raise ValueError( - "One of `citation_config` or `quotes_config` must be provided" - ) - - if self.citation_config is not None and self.quotes_config is not None: - raise ValueError( - "Only one of `citation_config` or `quotes_config` must be provided" - ) - - return self - - -class PromptConfig(BaseModel): - """Final representation of the Prompt configuration passed - into the `Answer` object.""" - - system_prompt: str - task_prompt: str - datetime_aware: bool - include_citations: bool - - @classmethod - def from_model( - cls, model: "Prompt", prompt_override: PromptOverride | None = None - ) -> "PromptConfig": - override_system_prompt = ( - prompt_override.system_prompt if prompt_override else None - ) - override_task_prompt = prompt_override.task_prompt if prompt_override else None - - return cls( - system_prompt=override_system_prompt or model.system_prompt, - task_prompt=override_task_prompt or model.task_prompt, - datetime_aware=model.datetime_aware, - include_citations=model.include_citations, - ) - - model_config = ConfigDict(frozen=True) diff --git a/backend/danswer/llm/answering/prompts/utils.py b/backend/danswer/llm/answering/prompts/utils.py deleted file mode 100644 index bcc8b8918..000000000 --- a/backend/danswer/llm/answering/prompts/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT - - -def build_dummy_prompt( - system_prompt: str, task_prompt: str, retrieval_disabled: bool -) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() diff --git a/backend/danswer/llm/models.py b/backend/danswer/llm/models.py new file mode 100644 index 000000000..182fc97fb --- /dev/null +++ b/backend/danswer/llm/models.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage +from pydantic import BaseModel + +from danswer.configs.constants import MessageType +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.utils import build_content_with_imgs +from danswer.tools.models import ToolCallFinalResult + +if TYPE_CHECKING: + from danswer.db.models import ChatMessage + + +class PreviousMessage(BaseModel): + """Simplified version of `ChatMessage`""" + + message: str + token_count: int + message_type: MessageType + files: list[InMemoryChatFile] + tool_call: ToolCallFinalResult | None + + @classmethod + def from_chat_message( + cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile] + ) -> "PreviousMessage": + message_file_ids = ( + [file["id"] for file in chat_message.files] if chat_message.files else [] + ) + return cls( + message=chat_message.message, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + files=[ + file + for file in available_files + if str(file.file_id) in message_file_ids + ], + tool_call=ToolCallFinalResult( + tool_name=chat_message.tool_call.tool_name, + tool_args=chat_message.tool_call.tool_arguments, + tool_result=chat_message.tool_call.tool_result, + ) + if chat_message.tool_call + else None, + ) + + def to_langchain_msg(self) -> BaseMessage: + content = build_content_with_imgs(self.message, self.files) + if self.message_type == MessageType.USER: + return HumanMessage(content=content) + elif self.message_type == MessageType.ASSISTANT: + return AIMessage(content=content) + else: + return SystemMessage(content=content) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index e5564e88d..5573312b9 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -5,8 +5,6 @@ from collections.abc import Callable from collections.abc import Iterator from typing import Any from typing import cast -from typing import TYPE_CHECKING -from typing import Union import litellm # type: ignore import pandas as pd @@ -36,7 +34,6 @@ from danswer.configs.constants import MessageType from danswer.configs.model_configs import GEN_AI_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS -from danswer.db.models import ChatMessage from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile from danswer.llm.interfaces import LLM @@ -44,9 +41,6 @@ from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.utils.logger import setup_logger from shared_configs.configs import LOG_LEVEL -if TYPE_CHECKING: - from danswer.llm.answering.models import PreviousMessage - logger = setup_logger() @@ -104,39 +98,6 @@ def litellm_exception_to_error_msg( return error_msg -def translate_danswer_msg_to_langchain( - msg: Union[ChatMessage, "PreviousMessage"], -) -> BaseMessage: - files: list[InMemoryChatFile] = [] - - # If the message is a `ChatMessage`, it doesn't have the downloaded files - # attached. Just ignore them for now. - if not isinstance(msg, ChatMessage): - files = msg.files - content = build_content_with_imgs(msg.message, files, message_type=msg.message_type) - - if msg.message_type == MessageType.SYSTEM: - raise ValueError("System messages are not currently part of history") - if msg.message_type == MessageType.ASSISTANT: - return AIMessage(content=content) - if msg.message_type == MessageType.USER: - return HumanMessage(content=content) - - raise ValueError(f"New message type {msg.message_type} not handled") - - -def translate_history_to_basemessages( - history: list[ChatMessage] | list["PreviousMessage"], -) -> tuple[list[BaseMessage], list[int]]: - history_basemessages = [ - translate_danswer_msg_to_langchain(msg) - for msg in history - if msg.token_count != 0 - ] - history_token_counts = [msg.token_count for msg in history if msg.token_count != 0] - return history_basemessages, history_token_counts - - # Processes CSV files to show the first 5 rows and max_columns (default 40) columns def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str: df = pd.read_csv(io.StringIO(file.content.decode("utf-8"))) diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index 4195926db..ec2801783 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -5,11 +5,11 @@ from typing import cast from langchain_core.messages import BaseMessage from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.configs.chat_configs import LANGUAGE_HINT from danswer.configs.constants import DocumentSource from danswer.context.search.models import InferenceChunk from danswer.db.models import Prompt -from danswer.llm.answering.models import PromptConfig from danswer.prompts.chat_prompts import ADDITIONAL_INFO from danswer.prompts.chat_prompts import CITATION_REMINDER from danswer.prompts.constants import CODE_BLOCK_PAT diff --git a/backend/danswer/secondary_llm_flows/choose_search.py b/backend/danswer/secondary_llm_flows/choose_search.py index 5016cf055..36539dd4a 100644 --- a/backend/danswer/secondary_llm_flows/choose_search.py +++ b/backend/danswer/secondary_llm_flows/choose_search.py @@ -3,14 +3,14 @@ from langchain.schema import HumanMessage from langchain.schema import SystemMessage from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.prompt_builder.utils import translate_danswer_msg_to_langchain from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string -from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE from danswer.prompts.chat_prompts import NO_SEARCH from danswer.prompts.chat_prompts import REQUIRE_SEARCH_HINT diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index 585af00bd..07f187e5b 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -4,10 +4,10 @@ from danswer.chat.chat_utils import combine_message_chain from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage -from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llms from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index fd092fb90..f6cb3a2d2 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -13,6 +13,7 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_limited_user from danswer.auth.users import current_user +from danswer.chat.prompt_builder.utils import build_dummy_prompt from danswer.configs.constants import FileOrigin from danswer.configs.constants import NotificationType from danswer.db.engine import get_session @@ -33,7 +34,6 @@ from danswer.db.persona import update_persona_shared_users from danswer.db.persona import update_persona_visibility from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType -from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import ImageGenerationToolStatus from danswer.server.features.persona.models import PersonaCategoryCreate diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 7b1413e13..a6592a907 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -23,6 +23,9 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import extract_headers from danswer.chat.process_message import stream_chat_message +from danswer.chat.prompt_builder.citations_prompt import ( + compute_max_document_tokens_for_persona, +) from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType @@ -51,9 +54,6 @@ from danswer.file_processing.extract_file_text import extract_file_text from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor -from danswer.llm.answering.prompts.citations_prompt import ( - compute_max_document_tokens_for_persona, -) from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llms from danswer.llm.factory import get_llms_for_persona diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 6e2f31b64..34ef556da 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import Any +from typing import TYPE_CHECKING from uuid import UUID from pydantic import BaseModel @@ -22,6 +23,9 @@ from danswer.llm.override_models import LLMOverride from danswer.llm.override_models import PromptOverride from danswer.tools.models import ToolCallFinalResult +if TYPE_CHECKING: + pass + class SourceTag(Tag): source: DocumentSource diff --git a/backend/danswer/tools/base_tool.py b/backend/danswer/tools/base_tool.py index 739025044..ebacf687a 100644 --- a/backend/danswer/tools/base_tool.py +++ b/backend/danswer/tools/base_tool.py @@ -7,7 +7,7 @@ from danswer.llm.utils import message_to_prompt_and_imgs from danswer.tools.tool import Tool if TYPE_CHECKING: - from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.tools.tool_implementations.custom.custom_tool import ( CustomToolCallSummary, ) diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index 6fc9251a1..210a80286 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -3,13 +3,13 @@ from collections.abc import Generator from typing import Any from typing import TYPE_CHECKING -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.utils.special_types import JSON_ro if TYPE_CHECKING: - from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.tools.message import ToolCallSummary from danswer.tools.models import ToolResponse diff --git a/backend/danswer/tools/tool_constructor.py b/backend/danswer/tools/tool_constructor.py index 7ffd5b96b..6f3717935 100644 --- a/backend/danswer/tools/tool_constructor.py +++ b/backend/danswer/tools/tool_constructor.py @@ -5,6 +5,10 @@ from pydantic import BaseModel from pydantic import Field from sqlalchemy.orm import Session +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import CitationConfig +from danswer.chat.models import DocumentPruningConfig +from danswer.chat.models import PromptConfig from danswer.configs.app_configs import AZURE_DALLE_API_BASE from danswer.configs.app_configs import AZURE_DALLE_API_KEY from danswer.configs.app_configs import AZURE_DALLE_API_VERSION @@ -19,10 +23,6 @@ from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import Persona from danswer.db.models import User from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PromptConfig from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLMConfig from danswer.natural_language_processing.utils import get_tokenizer diff --git a/backend/danswer/tools/tool_implementations/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py index c25d61b3c..b874a2164 100644 --- a/backend/danswer/tools/tool_implementations/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -15,14 +15,14 @@ from langchain_core.messages import SystemMessage from pydantic import BaseModel from requests import JSONDecodeError +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.constants import FileOrigin from danswer.db.engine import get_session_with_default_tenant from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.tools.base_tool import BaseTool from danswer.tools.message import ToolCallSummary from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER diff --git a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py index 70763fc78..83a602528 100644 --- a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py @@ -8,10 +8,10 @@ from litellm import image_generation # type: ignore from pydantic import BaseModel from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT diff --git a/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py index 85d93d833..cdd52f763 100644 --- a/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py @@ -7,15 +7,15 @@ from typing import cast import httpx from danswer.chat.chat_utils import combine_message_chain +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.context.search.models import SearchDoc -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE from danswer.prompts.constants import GENERAL_SEP_PAT diff --git a/backend/danswer/tools/tool_implementations/search/search_tool.py b/backend/danswer/tools/tool_implementations/search/search_tool.py index 3d4c14ea6..5bf08e564 100644 --- a/backend/danswer/tools/tool_implementations/search/search_tool.py +++ b/backend/danswer/tools/tool_implementations/search/search_tool.py @@ -7,10 +7,19 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.chat.chat_utils import llm_doc_from_inference_section +from danswer.chat.llm_response_handler import LLMCall +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import ContextualPruningConfig from danswer.chat.models import DanswerContext from danswer.chat.models import DanswerContexts +from danswer.chat.models import DocumentPruningConfig from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.chat.models import SectionRelevancePiece +from danswer.chat.prompt_builder.build import AnswerPromptBuilder +from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens +from danswer.chat.prune_and_merge import prune_and_merge_sections +from danswer.chat.prune_and_merge import prune_sections from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS @@ -25,17 +34,8 @@ from danswer.context.search.models import SearchRequest from danswer.context.search.pipeline import SearchPipeline from danswer.db.models import Persona from danswer.db.models import User -from danswer.llm.answering.llm_response_handler import LLMCall -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import ContextualPruningConfig -from danswer.llm.answering.models import DocumentPruningConfig -from danswer.llm.answering.models import PreviousMessage -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens -from danswer.llm.answering.prune_and_merge import prune_and_merge_sections -from danswer.llm.answering.prune_and_merge import prune_sections from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.tools.message import ToolCallSummary diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py index 55890188d..761e2f9ec 100644 --- a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py +++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py @@ -2,15 +2,15 @@ from typing import cast from langchain_core.messages import HumanMessage +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import LlmDoc -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder -from danswer.llm.answering.prompts.citations_prompt import ( +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder +from danswer.chat.prompt_builder.citations_prompt import ( build_citations_system_message, ) -from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message -from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message +from danswer.chat.prompt_builder.citations_prompt import build_citations_user_message +from danswer.chat.prompt_builder.quotes_prompt import build_quotes_user_message from danswer.tools.message import ToolCallSummary from danswer.tools.models import ToolResponse diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index fb3eb8b99..55ae7022e 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -2,8 +2,8 @@ from collections.abc import Callable from collections.abc import Generator from typing import Any -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.tools.models import ToolCallFinalResult from danswer.tools.models import ToolCallKickoff from danswer.tools.models import ToolResponse diff --git a/backend/danswer/tools/tool_selection.py b/backend/danswer/tools/tool_selection.py index dc8d697c2..f9fbaf9c0 100644 --- a/backend/danswer/tools/tool_selection.py +++ b/backend/danswer/tools/tool_selection.py @@ -3,8 +3,8 @@ from typing import Any from danswer.chat.chat_utils import combine_message_chain from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF -from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.llm.models import PreviousMessage from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.tools.tool import Tool diff --git a/backend/tests/unit/danswer/llm/answering/conftest.py b/backend/tests/unit/danswer/chat/conftest.py similarity index 93% rename from backend/tests/unit/danswer/llm/answering/conftest.py rename to backend/tests/unit/danswer/chat/conftest.py index 46dfc3523..aed94d8fc 100644 --- a/backend/tests/unit/danswer/llm/answering/conftest.py +++ b/backend/tests/unit/danswer/chat/conftest.py @@ -5,12 +5,12 @@ from unittest.mock import MagicMock import pytest from langchain_core.messages import SystemMessage +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import CitationConfig from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig +from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.constants import DocumentSource -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig -from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLMConfig from danswer.tools.models import ToolResponse from danswer.tools.tool_implementations.search.search_tool import SearchTool diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py similarity index 98% rename from backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py rename to backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py index 335c2e7b0..563e26780 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py @@ -5,11 +5,9 @@ import pytest from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.stream_processing.citation_processing import CitationProcessor +from danswer.chat.stream_processing.utils import DocumentIdOrderMapping from danswer.configs.constants import DocumentSource -from danswer.llm.answering.stream_processing.citation_processing import ( - CitationProcessor, -) -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping """ diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/chat/stream_processing/test_quotes_processing.py similarity index 97% rename from backend/tests/unit/danswer/direct_qa/test_qa_utils.py rename to backend/tests/unit/danswer/chat/stream_processing/test_quotes_processing.py index bcbd76f4e..7cb969ab7 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ b/backend/tests/unit/danswer/chat/stream_processing/test_quotes_processing.py @@ -2,14 +2,10 @@ import textwrap import pytest +from danswer.chat.stream_processing.quotes_processing import match_quotes_to_docs +from danswer.chat.stream_processing.quotes_processing import separate_answer_quotes from danswer.configs.constants import DocumentSource from danswer.context.search.models import InferenceChunk -from danswer.llm.answering.stream_processing.quotes_processing import ( - match_quotes_to_docs, -) -from danswer.llm.answering.stream_processing.quotes_processing import ( - separate_answer_quotes, -) def test_passed_in_quotes() -> None: diff --git a/backend/tests/unit/danswer/llm/answering/test_answer.py b/backend/tests/unit/danswer/chat/test_answer.py similarity index 97% rename from backend/tests/unit/danswer/llm/answering/test_answer.py rename to backend/tests/unit/danswer/chat/test_answer.py index 5746c753d..14bbec654 100644 --- a/backend/tests/unit/danswer/llm/answering/test_answer.py +++ b/backend/tests/unit/danswer/chat/test_answer.py @@ -11,21 +11,21 @@ from langchain_core.messages import SystemMessage from langchain_core.messages import ToolCall from langchain_core.messages import ToolCallChunk +from danswer.chat.answer import Answer +from danswer.chat.models import AnswerStyleConfig from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.models import PromptConfig from danswer.chat.models import StreamStopInfo from danswer.chat.models import StreamStopReason -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PromptConfig from danswer.llm.interfaces import LLM from danswer.tools.force import ForceUseTool from danswer.tools.models import ToolCallFinalResult from danswer.tools.models import ToolCallKickoff from danswer.tools.models import ToolResponse -from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS -from tests.unit.danswer.llm.answering.conftest import QUERY +from tests.unit.danswer.chat.conftest import DEFAULT_SEARCH_ARGS +from tests.unit.danswer.chat.conftest import QUERY @pytest.fixture diff --git a/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py b/backend/tests/unit/danswer/chat/test_prune_and_merge.py similarity index 99% rename from backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py rename to backend/tests/unit/danswer/chat/test_prune_and_merge.py index c71d91090..2741a5652 100644 --- a/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py +++ b/backend/tests/unit/danswer/chat/test_prune_and_merge.py @@ -1,9 +1,9 @@ import pytest +from danswer.chat.prune_and_merge import _merge_sections from danswer.configs.constants import DocumentSource from danswer.context.search.models import InferenceChunk from danswer.context.search.models import InferenceSection -from danswer.llm.answering.prune_and_merge import _merge_sections # This large test accounts for all of the following: diff --git a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py b/backend/tests/unit/danswer/chat/test_skip_gen_ai.py similarity index 95% rename from backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py rename to backend/tests/unit/danswer/chat/test_skip_gen_ai.py index 74d178da3..772ec52a6 100644 --- a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py +++ b/backend/tests/unit/danswer/chat/test_skip_gen_ai.py @@ -5,10 +5,10 @@ from unittest.mock import Mock import pytest from pytest_mock import MockerFixture -from danswer.llm.answering.answer import Answer -from danswer.llm.answering.answer import AnswerStream -from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import PromptConfig +from danswer.chat.answer import Answer +from danswer.chat.answer import AnswerStream +from danswer.chat.models import AnswerStyleConfig +from danswer.chat.models import PromptConfig from danswer.tools.force import ForceUseTool from danswer.tools.tool_implementations.search.search_tool import SearchTool from tests.regression.answer_quality.run_qa import _process_and_write_query_results From 69b99056b2e0284ee1c46a3ad514d68504a7aafd Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 4 Dec 2024 16:08:52 -0800 Subject: [PATCH 2/5] Redirect to chat (#3341) * k * nit --- web/src/app/not-found.tsx | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 web/src/app/not-found.tsx diff --git a/web/src/app/not-found.tsx b/web/src/app/not-found.tsx new file mode 100644 index 000000000..6866db7cb --- /dev/null +++ b/web/src/app/not-found.tsx @@ -0,0 +1,5 @@ +import { redirect } from "next/navigation"; + +export default function NotFound() { + redirect("/chat"); +} From fd1999454a1cf96941e49ca25c255d48df6766bb Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 4 Dec 2024 17:10:37 -0800 Subject: [PATCH 3/5] ensure we can order by doc id (#3343) --- .../document_index/vespa/app_config/schemas/danswer_chunk.sd | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd index e712266fa..8789a0534 100644 --- a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -4,6 +4,8 @@ schema DANSWER_CHUNK_NAME { # Not to be confused with the UUID generated for this chunk which is called documentid by default field document_id type string { indexing: summary | attribute + attribute: fast-search + rank: filter } field chunk_id type int { indexing: summary | attribute From 91d44c83d2cbbce5f2baba6fe1c64875995f3cd2 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 4 Dec 2024 18:19:43 -0800 Subject: [PATCH 4/5] fixing chromatic tests (#3344) * wait for the page to load * fix up tests * make sure "Initializing Danswer" is gone --- web/tests/e2e/admin_assistants.spec.ts | 3 +++ web/tests/e2e/home.spec.ts | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/web/tests/e2e/admin_assistants.spec.ts b/web/tests/e2e/admin_assistants.spec.ts index 308d0975e..a5ccf92c7 100644 --- a/web/tests/e2e/admin_assistants.spec.ts +++ b/web/tests/e2e/admin_assistants.spec.ts @@ -12,5 +12,8 @@ test( await expect(page.locator("p.text-sm").nth(0)).toHaveText( /^Assistants are a way to build/ ); + + const generalTextLocator = page.locator("tr.border-b td").nth(1); + await expect(generalTextLocator.locator("p.text")).toHaveText("General"); } ); diff --git a/web/tests/e2e/home.spec.ts b/web/tests/e2e/home.spec.ts index 2b1605fa1..74a9bba96 100644 --- a/web/tests/e2e/home.spec.ts +++ b/web/tests/e2e/home.spec.ts @@ -27,5 +27,11 @@ test( await page.click('button[type="submit"]'); await page.waitForURL("http://localhost:3000/chat"); + + await page.getByPlaceholder("Send a message "); + + await expect(page.locator("body")).not.toContainText( + "Initializing Danswer" + ); } ); From 7e53af18b62a12cf1ecab38d07bbf7608ed4aae8 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 4 Dec 2024 18:24:54 -0800 Subject: [PATCH 5/5] Add b64 image support for image generation (#3342) * Add b64 image support * Fix * enhance * Fix mypy * Fix imports --- backend/danswer/chat/process_message.py | 39 ++++++++---- backend/danswer/configs/tool_configs.py | 2 + backend/danswer/file_store/utils.py | 61 ++++++++++++++++--- backend/danswer/llm/utils.py | 21 ++++++- .../images/image_generation_tool.py | 52 ++++++++++++---- .../tool_implementations/images/prompt.py | 5 +- backend/danswer/utils/b64.py | 25 ++++++++ 7 files changed, 169 insertions(+), 36 deletions(-) create mode 100644 backend/danswer/utils/b64.py diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 41de3c04d..43188748f 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -62,7 +62,7 @@ from danswer.document_index.factory import get_default_document_index from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor from danswer.file_store.utils import load_all_chat_files -from danswer.file_store.utils import save_files_from_urls +from danswer.file_store.utils import save_files from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple @@ -119,6 +119,7 @@ from danswer.utils.logger import setup_logger from danswer.utils.long_term_log import LongTermLogger from danswer.utils.timing import log_function_time from danswer.utils.timing import log_generator_function_time +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() @@ -302,6 +303,7 @@ def stream_chat_message_objects( 3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails 4. [always] Details on the final AI response message that is created """ + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() use_existing_user_message = new_msg_req.use_existing_user_message existing_assistant_message_id = new_msg_req.existing_assistant_message_id @@ -678,7 +680,8 @@ def stream_chat_message_objects( reference_db_search_docs = None qa_docs_response = None - ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images + # any files to associate with the AI message e.g. dall-e generated images + ai_message_files = [] dropped_indices = None tool_result = None @@ -733,8 +736,14 @@ def stream_chat_message_objects( list[ImageGenerationResponse], packet.response ) - file_ids = save_files_from_urls( - [img.url for img in img_generation_response] + file_ids = save_files( + urls=[img.url for img in img_generation_response if img.url], + base64_files=[ + img.image_data + for img in img_generation_response + if img.image_data + ], + tenant_id=tenant_id, ) ai_message_files = [ FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) @@ -760,15 +769,19 @@ def stream_chat_message_objects( or custom_tool_response.response_type == "csv" ): file_ids = custom_tool_response.tool_result.file_ids - ai_message_files = [ - FileDescriptor( - id=str(file_id), - type=ChatFileType.IMAGE - if custom_tool_response.response_type == "image" - else ChatFileType.CSV, - ) - for file_id in file_ids - ] + ai_message_files.extend( + [ + FileDescriptor( + id=str(file_id), + type=( + ChatFileType.IMAGE + if custom_tool_response.response_type == "image" + else ChatFileType.CSV + ), + ) + for file_id in file_ids + ] + ) yield FileChatDisplay( file_ids=[str(file_id) for file_id in file_ids] ) diff --git a/backend/danswer/configs/tool_configs.py b/backend/danswer/configs/tool_configs.py index 3170cb31f..9e1433014 100644 --- a/backend/danswer/configs/tool_configs.py +++ b/backend/danswer/configs/tool_configs.py @@ -2,6 +2,8 @@ import json import os +IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url") + # if specified, will pass through request headers to the call to API calls made by custom tools CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get( diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index e9eea2c26..978bb92e6 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -1,6 +1,6 @@ +import base64 from collections.abc import Callable from io import BytesIO -from typing import Any from typing import cast from uuid import uuid4 @@ -13,8 +13,8 @@ from danswer.db.models import ChatMessage from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import FileDescriptor from danswer.file_store.models import InMemoryChatFile +from danswer.utils.b64 import get_image_type from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR def load_chat_file( @@ -75,11 +75,58 @@ def save_file_from_url(url: str, tenant_id: str) -> str: return unique_id -def save_files_from_urls(urls: list[str]) -> list[str]: - tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() +def save_file_from_base64(base64_string: str, tenant_id: str) -> str: + with get_session_with_tenant(tenant_id) as db_session: + unique_id = str(uuid4()) + file_store = get_default_file_store(db_session) + file_store.save_file( + file_name=unique_id, + content=BytesIO(base64.b64decode(base64_string)), + display_name="GeneratedImage", + file_origin=FileOrigin.CHAT_IMAGE_GEN, + file_type=get_image_type(base64_string), + ) + return unique_id - funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ - (save_file_from_url, (url, tenant_id)) for url in urls + +def save_file( + tenant_id: str, + url: str | None = None, + base64_data: str | None = None, +) -> str: + """Save a file from either a URL or base64 encoded string. + + Args: + tenant_id: The tenant ID to save the file under + url: URL to download file from + base64_data: Base64 encoded file data + + Returns: + The unique ID of the saved file + + Raises: + ValueError: If neither url nor base64_data is provided, or if both are provided + """ + if url is not None and base64_data is not None: + raise ValueError("Cannot specify both url and base64_data") + + if url is not None: + return save_file_from_url(url, tenant_id) + elif base64_data is not None: + return save_file_from_base64(base64_data, tenant_id) + else: + raise ValueError("Must specify either url or base64_data") + + +def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]: + # NOTE: be explicit about typing so that if we change things, we get notified + funcs: list[ + tuple[ + Callable[[str, str | None, str | None], str], + tuple[str, str | None, str | None], + ] + ] = [(save_file, (tenant_id, url, None)) for url in urls] + [ + (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files ] - # Must pass in tenant_id here, since this is called by multithreading + return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 5573312b9..41c6592c6 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -38,6 +38,8 @@ from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile from danswer.llm.interfaces import LLM from danswer.prompts.constants import CODE_BLOCK_PAT +from danswer.utils.b64 import get_image_type +from danswer.utils.b64 import get_image_type_from_bytes from danswer.utils.logger import setup_logger from shared_configs.configs import LOG_LEVEL @@ -151,6 +153,7 @@ def build_content_with_imgs( message: str, files: list[InMemoryChatFile] | None = None, img_urls: list[str] | None = None, + b64_imgs: list[str] | None = None, message_type: MessageType = MessageType.USER, ) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type files = files or [] @@ -163,6 +166,7 @@ def build_content_with_imgs( ) img_urls = img_urls or [] + b64_imgs = b64_imgs or [] message_main_content = _build_content(message, files) @@ -181,11 +185,22 @@ def build_content_with_imgs( { "type": "image_url", "image_url": { - "url": f"data:image/jpeg;base64,{file.to_base64()}", + "url": ( + f"data:{get_image_type_from_bytes(file.content)};" + f"base64,{file.to_base64()}" + ), }, } - for file in files - if file.file_type == "image" + for file in img_files + ] + + [ + { + "type": "image_url", + "image_url": { + "url": f"data:{get_image_type(b64_img)};base64,{b64_img}", + }, + } + for b64_img in b64_imgs ] + [ { diff --git a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py index 83a602528..d8d3d7543 100644 --- a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py @@ -4,12 +4,14 @@ from enum import Enum from typing import Any from typing import cast +import requests from litellm import image_generation # type: ignore from pydantic import BaseModel from danswer.chat.chat_utils import combine_message_chain from danswer.chat.prompt_builder.build import AnswerPromptBuilder from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT from danswer.llm.interfaces import LLM from danswer.llm.models import PreviousMessage from danswer.llm.utils import build_content_with_imgs @@ -56,9 +58,18 @@ Follow Up Input: """.strip() +class ImageFormat(str, Enum): + URL = "url" + BASE64 = "b64_json" + + +_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT) + + class ImageGenerationResponse(BaseModel): revised_prompt: str - url: str + url: str | None + image_data: str | None class ImageShape(str, Enum): @@ -80,6 +91,7 @@ class ImageGenerationTool(Tool): model: str = "dall-e-3", num_imgs: int = 2, additional_headers: dict[str, str] | None = None, + output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT, ) -> None: self.api_key = api_key self.api_base = api_base @@ -89,6 +101,7 @@ class ImageGenerationTool(Tool): self.num_imgs = num_imgs self.additional_headers = additional_headers + self.output_format = output_format @property def name(self) -> str: @@ -168,7 +181,7 @@ class ImageGenerationTool(Tool): ) return build_content_with_imgs( - json.dumps( + message=json.dumps( [ { "revised_prompt": image_generation.revised_prompt, @@ -177,13 +190,10 @@ class ImageGenerationTool(Tool): for image_generation in image_generations ] ), - # NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow - # Tool messages to contain images - # img_urls=[image_generation.url for image_generation in image_generations], ) def _generate_image( - self, prompt: str, shape: ImageShape + self, prompt: str, shape: ImageShape, format: ImageFormat ) -> ImageGenerationResponse: if shape == ImageShape.LANDSCAPE: size = "1792x1024" @@ -197,20 +207,32 @@ class ImageGenerationTool(Tool): prompt=prompt, model=self.model, api_key=self.api_key, - # need to pass in None rather than empty str api_base=self.api_base or None, api_version=self.api_version or None, size=size, n=1, + response_format=format, extra_headers=build_llm_extra_headers(self.additional_headers), ) + + if format == ImageFormat.URL: + url = response.data[0]["url"] + image_data = None + else: + url = None + image_data = response.data[0]["b64_json"] + return ImageGenerationResponse( revised_prompt=response.data[0]["revised_prompt"], - url=response.data[0]["url"], + url=url, + image_data=image_data, ) + except requests.RequestException as e: + logger.error(f"Error fetching or converting image: {e}") + raise ValueError("Failed to fetch or convert the generated image") except Exception as e: - logger.debug(f"Error occured during image generation: {e}") + logger.debug(f"Error occurred during image generation: {e}") error_message = str(e) if "OpenAIException" in str(type(e)): @@ -235,9 +257,8 @@ class ImageGenerationTool(Tool): def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: prompt = cast(str, kwargs["prompt"]) shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE)) + format = self.output_format - # dalle3 only supports 1 image at a time, which is why we have to - # parallelize this via threading results = cast( list[ImageGenerationResponse], run_functions_tuples_in_parallel( @@ -247,6 +268,7 @@ class ImageGenerationTool(Tool): ( prompt, shape, + format, ), ) for _ in range(self.num_imgs) @@ -288,11 +310,17 @@ class ImageGenerationTool(Tool): if img_generation_response is None: raise ValueError("No image generation response found") - img_urls = [img.url for img in img_generation_response] + img_urls = [img.url for img in img_generation_response if img.url is not None] + b64_imgs = [ + img.image_data + for img in img_generation_response + if img.image_data is not None + ] prompt_builder.update_user_prompt( build_image_generation_user_prompt( query=prompt_builder.get_user_message_content(), img_urls=img_urls, + b64_imgs=b64_imgs, ) ) diff --git a/backend/danswer/tools/tool_implementations/images/prompt.py b/backend/danswer/tools/tool_implementations/images/prompt.py index bb729bfcd..e5f11ba62 100644 --- a/backend/danswer/tools/tool_implementations/images/prompt.py +++ b/backend/danswer/tools/tool_implementations/images/prompt.py @@ -11,11 +11,14 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or def build_image_generation_user_prompt( - query: str, img_urls: list[str] | None = None + query: str, + img_urls: list[str] | None = None, + b64_imgs: list[str] | None = None, ) -> HumanMessage: return HumanMessage( content=build_content_with_imgs( message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(), + b64_imgs=b64_imgs, img_urls=img_urls, ) ) diff --git a/backend/danswer/utils/b64.py b/backend/danswer/utils/b64.py new file mode 100644 index 000000000..05a915814 --- /dev/null +++ b/backend/danswer/utils/b64.py @@ -0,0 +1,25 @@ +import base64 + + +def get_image_type_from_bytes(raw_b64_bytes: bytes) -> str: + magic_number = raw_b64_bytes[:4] + + if magic_number.startswith(b"\x89PNG"): + mime_type = "image/png" + elif magic_number.startswith(b"\xFF\xD8"): + mime_type = "image/jpeg" + elif magic_number.startswith(b"GIF8"): + mime_type = "image/gif" + elif magic_number.startswith(b"RIFF") and raw_b64_bytes[8:12] == b"WEBP": + mime_type = "image/webp" + else: + raise ValueError( + "Unsupported image format - only PNG, JPEG, " "GIF, and WEBP are supported." + ) + + return mime_type + + +def get_image_type(raw_b64_string: str) -> str: + binary_data = base64.b64decode(raw_b64_string) + return get_image_type_from_bytes(binary_data)