diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 6d12d68df..1f1f15ea7 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -1,5 +1,6 @@ from collections.abc import Iterator from datetime import datetime +from enum import Enum from typing import Any from pydantic import BaseModel @@ -44,6 +45,20 @@ class QADocsResponse(RetrievalDocs): return initial_dict +class StreamStopReason(Enum): + CONTEXT_LENGTH = "context_length" + CANCELLED = "cancelled" + + +class StreamStopInfo(BaseModel): + stop_reason: StreamStopReason + + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore + data = super().model_dump(mode="json", *args, **kwargs) # type: ignore + data["stop_reason"] = self.stop_reason.name + return data + + class LLMRelevanceFilterResponse(BaseModel): relevant_chunk_indices: list[int] @@ -144,6 +159,7 @@ AnswerQuestionPossibleReturn = ( | ImageGenerationDisplay | CustomToolResponse | StreamingError + | StreamStopInfo ) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index a664db217..684b0aab7 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,5 +1,6 @@ from collections.abc import Callable from collections.abc import Iterator +from typing import Any from typing import cast from uuid import uuid4 @@ -12,6 +13,8 @@ from danswer.chat.models import AnswerQuestionPossibleReturn from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import AnswerStyleConfig @@ -35,7 +38,7 @@ from danswer.llm.answering.stream_processing.quotes_processing import ( from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.llm.answering.stream_processing.utils import map_document_id_order from danswer.llm.interfaces import LLM -from danswer.llm.utils import message_generator_to_string_generator +from danswer.llm.interfaces import ToolChoiceOptions from danswer.natural_language_processing.utils import get_tokenizer from danswer.tools.custom.custom_tool_prompt_builder import ( build_user_message_for_custom_tool_for_non_tool_calling_llm, @@ -190,7 +193,9 @@ class Answer: def _raw_output_for_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: + ) -> Iterator[ + str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult + ]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) tool_call_chunk: AIMessageChunk | None = None @@ -225,6 +230,7 @@ class Answer: self.tools, self.force_use_tool ) ] + for message in self.llm.stream( prompt=prompt, tools=final_tool_definitions if final_tool_definitions else None, @@ -298,21 +304,41 @@ class Answer: yield tool_runner.tool_final_result() prompt = prompt_builder.build(tool_call_summary=tool_call_summary) - for token in message_generator_to_string_generator( - self.llm.stream( - prompt=prompt, - tools=[tool.tool_definition() for tool in self.tools], - ) - ): - if self.is_cancelled: - return - yield token + + yield from self._process_llm_stream( + prompt=prompt, + tools=[tool.tool_definition() for tool in self.tools], + ) return + # 
This method processes the LLM stream and yields the content or stop information + def _process_llm_stream( + self, + prompt: Any, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> Iterator[str | StreamStopInfo]: + for message in self.llm.stream( + prompt=prompt, tools=tools, tool_choice=tool_choice + ): + if isinstance(message, AIMessageChunk): + if message.content: + if self.is_cancelled: + return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + yield cast(str, message.content) + + if ( + message.additional_kwargs.get("usage_metadata", {}).get("stop") + == "length" + ): + yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH) + def _raw_output_for_non_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: + ) -> Iterator[ + str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult + ]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) chosen_tool_and_args: tuple[Tool, dict] | None = None @@ -387,13 +413,10 @@ class Answer: ) ) prompt = prompt_builder.build() - for token in message_generator_to_string_generator( - self.llm.stream(prompt=prompt) - ): - if self.is_cancelled: - return - yield token - + yield from self._process_llm_stream( + prompt=prompt, + tools=None, + ) return tool, tool_args = chosen_tool_and_args @@ -447,12 +470,8 @@ class Answer: yield final prompt = prompt_builder.build() - for token in message_generator_to_string_generator( - self.llm.stream(prompt=prompt) - ): - if self.is_cancelled: - return - yield token + + yield from self._process_llm_stream(prompt=prompt, tools=None) @property def processed_streamed_output(self) -> AnswerStream: @@ -470,7 +489,7 @@ class Answer: ) def _process_stream( - stream: Iterator[ToolCallKickoff | ToolResponse | str], + stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo], ) -> AnswerStream: message = None @@ -524,12 +543,15 @@ class Answer: answer_style_configs=self.answer_style_config, ) - def _stream() -> Iterator[str]: - if message: - yield cast(str, message) - yield from cast(Iterator[str], stream) + def _stream() -> Iterator[str | StreamStopInfo]: + yield cast(str | StreamStopInfo, message) + yield from (cast(str | StreamStopInfo, item) for item in stream) - yield from process_answer_stream_fn(_stream()) + for item in _stream(): + if isinstance(item, StreamStopInfo): + yield item + else: + yield from process_answer_stream_fn(iter([item])) processed_stream = [] for processed_packet in _process_stream(output_generator): diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index de80b6f67..a72fc70a8 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -11,7 +11,6 @@ from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -204,7 +203,9 @@ def extract_citations_from_stream( def build_citation_processor( context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping ) -> StreamProcessor: - def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + def stream_processor( + tokens: Iterator[str], + ) -> AnswerQuestionStreamReturn: yield from extract_citations_from_stream( 
tokens=tokens, context_docs=context_docs, diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py index 74f37b852..501a56b5a 100644 --- a/backend/danswer/llm/answering/stream_processing/quotes_processing.py +++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -285,7 +285,9 @@ def process_model_tokens( def build_quotes_processor( context_docs: list[LlmDoc], is_json_prompt: bool ) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: - def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + def stream_processor( + tokens: Iterator[str], + ) -> AnswerQuestionStreamReturn: yield from process_model_tokens( tokens=tokens, context_docs=context_docs, diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 05711044b..08131f581 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -138,7 +138,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_delta_to_message_chunk( - _dict: dict[str, Any], curr_msg: BaseMessage | None + _dict: dict[str, Any], + curr_msg: BaseMessage | None, + stop_reason: str | None = None, ) -> BaseMessageChunk: """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk""" role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None) @@ -163,12 +165,23 @@ def _convert_delta_to_message_chunk( args=tool_call.function.arguments, index=0, # only support a single tool call atm ) + return AIMessageChunk( content=content, - additional_kwargs=additional_kwargs, tool_call_chunks=[tool_call_chunk], + additional_kwargs={ + "usage_metadata": {"stop": stop_reason}, + **additional_kwargs, + }, ) - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + + return AIMessageChunk( + content=content, + additional_kwargs={ + "usage_metadata": {"stop": stop_reason}, + **additional_kwargs, + }, + ) elif role == "system": return SystemMessageChunk(content=content) elif role == "function": @@ -206,7 +219,7 @@ class DefaultMultiLLM(LLM): self._api_version = api_version self._custom_llm_provider = custom_llm_provider - # This can be used to store the maximum output tkoens for this model. + # This can be used to store the maximum output tokens for this model. 
# self._max_output_tokens = ( # max_output_tokens # if max_output_tokens is not None @@ -349,10 +362,16 @@ class DefaultMultiLLM(LLM): ) try: for part in response: - if len(part["choices"]) == 0: + if not part["choices"]: continue - delta = part["choices"][0]["delta"] - message_chunk = _convert_delta_to_message_chunk(delta, output) + + choice = part["choices"][0] + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], + output, + stop_reason=choice["finish_reason"], + ) + if output is None: output = message_chunk else: diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 43503a3af..57c663af2 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -65,7 +65,12 @@ import { FiArrowDown } from "react-icons/fi"; import { ChatIntro } from "./ChatIntro"; import { AIMessage, HumanMessage } from "./message/Messages"; import { StarterMessage } from "./StarterMessage"; -import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces"; +import { + AnswerPiecePacket, + DanswerDocument, + StreamStopInfo, + StreamStopReason, +} from "@/lib/search/interfaces"; import { buildFilters } from "@/lib/search/utils"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; @@ -94,6 +99,7 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; import { SEARCH_TOOL_NAME } from "./tools/constants"; import { useUser } from "@/components/user/UserProvider"; +import { Stop } from "@phosphor-icons/react"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -338,6 +344,7 @@ export function ChatPage({ } return; } + clearSelectedDocuments(); setIsFetchingChatMessages(true); const response = await fetch( @@ -624,6 +631,24 @@ export function ChatPage({ const currentRegenerationState = (): RegenerationState | null => { return regenerationState.get(currentSessionId()) || null; }; + const [canContinue, setCanContinue] = useState>( + new Map([[null, false]]) + ); + + const updateCanContinue = (newState: boolean, sessionId?: number | null) => { + setCanContinue((prevState) => { + const newCanContinueState = new Map(prevState); + newCanContinueState.set( + sessionId !== undefined ? 
sessionId : currentSessionId(), + newState + ); + return newCanContinueState; + }); + }; + + const currentCanContinue = (): boolean => { + return canContinue.get(currentSessionId()) || false; + }; const currentSessionChatState = currentChatState(); const currentSessionRegenerationState = currentRegenerationState(); @@ -864,6 +889,13 @@ export function ChatPage({ } }; + const continueGenerating = () => { + onSubmit({ + messageOverride: + "Continue Generating (pick up exactly where you left off)", + }); + }; + const onSubmit = async ({ messageIdToResend, messageOverride, @@ -884,6 +916,7 @@ export function ChatPage({ regenerationRequest?: RegenerationRequest | null; } = {}) => { let frozenSessionId = currentSessionId(); + updateCanContinue(false, frozenSessionId); if (currentChatState() != "input") { setPopup({ @@ -978,6 +1011,8 @@ export function ChatPage({ let messageUpdates: Message[] | null = null; let answer = ""; + + let stopReason: StreamStopReason | null = null; let query: string | null = null; let retrievalType: RetrievalType = selectedDocuments.length > 0 @@ -1174,6 +1209,11 @@ export function ChatPage({ stackTrace = (packet as StreamingError).stack_trace; } else if (Object.hasOwn(packet, "message_id")) { finalMessage = packet as BackendMessage; + } else if (Object.hasOwn(packet, "stop_reason")) { + const stop_reason = (packet as StreamStopInfo).stop_reason; + if (stop_reason === StreamStopReason.CONTEXT_LENGTH) { + updateCanContinue(true, frozenSessionId); + } } // on initial message send, we insert a dummy system message @@ -1237,6 +1277,7 @@ export function ChatPage({ alternateAssistantID: alternativeAssistant?.id, stackTrace: stackTrace, overridden_model: finalMessage?.overridden_model, + stopReason: stopReason, }, ]); } @@ -1835,6 +1876,12 @@ export function ChatPage({ } > void; +}) { + const [showExplanation, setShowExplanation] = useState(false); + + useEffect(() => { + const timer = setTimeout(() => { + setShowExplanation(true); + }, 1000); + + return () => clearTimeout(timer); + }, []); + + return ( +
+
+ + <> + + Continue Generation + + + {showExplanation && ( +
+ The LLM reached its token limit. Click to continue.
+ )} +
+
+ ); +} diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 09cacd1b9..fa307ec93 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -65,6 +65,8 @@ import GeneratingImageDisplay from "../tools/GeneratingImageDisplay"; import RegenerateOption from "../RegenerateOption"; import { LlmOverride } from "@/lib/hooks"; import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; +import { EmphasizedClickable } from "@/components/BasicClickable"; +import { ContinueGenerating } from "./ContinueMessage"; const TOOLS_WITH_CUSTOM_HANDLING = [ SEARCH_TOOL_NAME, @@ -123,6 +125,7 @@ function FileDisplay({ export const AIMessage = ({ regenerate, overriddenModel, + continueGenerating, shared, isActive, toggleDocumentSelection, @@ -150,6 +153,7 @@ export const AIMessage = ({ }: { shared?: boolean; isActive?: boolean; + continueGenerating?: () => void; otherMessagesCanSwitchTo?: number[]; onMessageSelection?: (messageId: number) => void; selectedDocuments?: DanswerDocument[] | null; @@ -283,11 +287,12 @@ export const AIMessage = ({ size="small" assistant={alternativeAssistant || currentPersona} /> +
- {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && ( + {!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? ( <> {query !== undefined && handleShowRetrieved !== undefined && @@ -315,7 +320,8 @@ export const AIMessage = ({
)} - )} + ) : null} + {toolCall && !TOOLS_WITH_CUSTOM_HANDLING.includes( toolCall.tool_name @@ -633,6 +639,11 @@ export const AIMessage = ({
+ {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && + !query && + continueGenerating && ( + + )} ); diff --git a/web/src/app/chat/message/SkippedSearch.tsx b/web/src/app/chat/message/SkippedSearch.tsx index 62c47b7d9..b339ac784 100644 --- a/web/src/app/chat/message/SkippedSearch.tsx +++ b/web/src/app/chat/message/SkippedSearch.tsx @@ -27,7 +27,7 @@ export function SkippedSearch({ handleForceSearch: () => void; }) { return ( -
+
diff --git a/web/src/lib/search/interfaces.ts b/web/src/lib/search/interfaces.ts index b33879055..6983bd336 100644 --- a/web/src/lib/search/interfaces.ts +++ b/web/src/lib/search/interfaces.ts @@ -19,6 +19,15 @@ export interface AnswerPiecePacket { answer_piece: string; } +export enum StreamStopReason { + CONTEXT_LENGTH = "CONTEXT_LENGTH", + CANCELLED = "CANCELLED", +} + +export interface StreamStopInfo { + stop_reason: StreamStopReason; +} + export interface ErrorMessagePacket { error: string; }
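
For context on how a consumer is expected to use these new packet shapes: the backend's `StreamStopInfo.model_dump` overrides the field with `self.stop_reason.name`, so the value on the wire is the uppercase enum name ("CONTEXT_LENGTH" or "CANCELLED"), which is why the TypeScript enum above uses uppercase string values rather than the lowercase Python enum values. Below is a minimal sketch of a stream consumer that branches on `stop_reason` before treating a packet as answer text, mirroring the `Object.hasOwn(packet, "stop_reason")` check added to `ChatPage.tsx`. `PacketType` and `handlePacket` are illustrative names for this sketch only, not identifiers from this diff, and the real ChatPage handles many more packet types.

```typescript
import {
  AnswerPiecePacket,
  StreamStopInfo,
  StreamStopReason,
} from "@/lib/search/interfaces";

// Illustrative subset of the packet union a chat stream consumer sees.
type PacketType = AnswerPiecePacket | StreamStopInfo;

// Accumulate answer text and flag when the model stopped at its output limit,
// analogous to ChatPage's updateCanContinue(true, sessionId).
function handlePacket(
  packet: PacketType,
  state: { answer: string; canContinue: boolean }
): void {
  if (Object.hasOwn(packet, "stop_reason")) {
    const stopReason = (packet as StreamStopInfo).stop_reason;
    if (stopReason === StreamStopReason.CONTEXT_LENGTH) {
      state.canContinue = true; // surfaces the "Continue Generation" affordance
    }
    return;
  }
  if (Object.hasOwn(packet, "answer_piece")) {
    state.answer += (packet as AnswerPiecePacket).answer_piece;
  }
}
```

When `canContinue` is set for the session, the UI offers the continue action, which simply resubmits with the override message "Continue Generating (pick up exactly where you left off)" rather than introducing a separate API call.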