Continue Generating (#2286)

* add stop reason

* add initial propagation

* add continue generating full functionality

* proper continue across chat session

* add new look

* propagate proper types

* fix typing

* cleaner continue generating functionality

* update types

* remove unused imports

* proper infodump

* temp

* add standardized stream handling

* validating chosen tool args

* properly handle tools

* proper ports

* remove logs + build

* minor typing fix

* fix more minor typing issues

* add stashed reversion for tool call chunks

* ignore model dump types

* remove stop stream

* fix typing
pablodanswer 2024-09-02 15:49:56 -07:00 committed by GitHub
parent 812ca69949
commit 6afcaafe54
12 changed files with 214 additions and 46 deletions

View File

@@ -1,5 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -44,6 +45,20 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
class StreamStopInfo(BaseModel):
stop_reason: StreamStopReason
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
@@ -144,6 +159,7 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
)
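
For reference, a self-contained sketch of the two models introduced above and what the overridden model_dump actually emits (assuming pydantic v2, which the model_dump override implies). The packet serializes the enum name rather than its value, which is the uppercase form the frontend StreamStopReason enum added later in this commit expects.

from enum import Enum
from typing import Any

from pydantic import BaseModel


class StreamStopReason(Enum):
    CONTEXT_LENGTH = "context_length"
    CANCELLED = "cancelled"


class StreamStopInfo(BaseModel):
    stop_reason: StreamStopReason

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        # Serialize normally, then swap the enum value for its name so the
        # wire format carries "CONTEXT_LENGTH" / "CANCELLED".
        data = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
        data["stop_reason"] = self.stop_reason.name
        return data


if __name__ == "__main__":
    packet = StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
    print(packet.model_dump())  # {'stop_reason': 'CONTEXT_LENGTH'}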

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from uuid import uuid4
@@ -12,6 +13,8 @@ from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
@@ -35,7 +38,7 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
@@ -190,7 +193,9 @@ class Answer:
def _raw_output_for_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
@@ -225,6 +230,7 @@
self.tools, self.force_use_tool
)
]
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
@@ -298,21 +304,41 @@
yield tool_runner.tool_final_result()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
for token in message_generator_to_string_generator(
self.llm.stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
):
if self.is_cancelled:
return
yield token
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
return
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
self,
prompt: Any,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
prompt=prompt, tools=tools, tool_choice=tool_choice
):
if isinstance(message, AIMessageChunk):
if message.content:
if self.is_cancelled:
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
@@ -387,13 +413,10 @@ class Answer:
)
)
prompt = prompt_builder.build()
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
):
if self.is_cancelled:
return
yield token
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool, tool_args = chosen_tool_and_args
@@ -447,12 +470,8 @@
yield final
prompt = prompt_builder.build()
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
):
if self.is_cancelled:
return
yield token
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -470,7 +489,7 @@
)
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str],
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
message = None
@@ -524,12 +543,15 @@
answer_style_configs=self.answer_style_config,
)
def _stream() -> Iterator[str]:
if message:
yield cast(str, message)
yield from cast(Iterator[str], stream)
def _stream() -> Iterator[str | StreamStopInfo]:
yield cast(str | StreamStopInfo, message)
yield from (cast(str | StreamStopInfo, item) for item in stream)
yield from process_answer_stream_fn(_stream())
for item in _stream():
if isinstance(item, StreamStopInfo):
yield item
else:
yield from process_answer_stream_fn(iter([item]))
processed_stream = []
for processed_packet in _process_stream(output_generator):
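
As a rough illustration of how the mixed stream from _process_llm_stream is meant to be consumed (a sketch only; drain_stream is a hypothetical helper, not part of this diff): plain strings are answer tokens, while a StreamStopInfo packet tells the caller why generation stopped, which is what ultimately drives the frontend's "Continue Generating" button.

from collections.abc import Iterator

from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason


def drain_stream(
    stream: Iterator[str | StreamStopInfo],
) -> tuple[str, StreamStopReason | None]:
    """Collect answer text and report why the stream stopped, if it said."""
    pieces: list[str] = []
    stop_reason: StreamStopReason | None = None
    for packet in stream:
        if isinstance(packet, StreamStopInfo):
            # CONTEXT_LENGTH: the model hit its output limit, so the answer
            # is resumable. CANCELLED: the user stopped generation.
            stop_reason = packet.stop_reason
        else:
            pieces.append(packet)
    return "".join(pieces), stop_reason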

View File

@@ -11,7 +11,6 @@ from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -204,7 +203,9 @@ def extract_citations_from_stream(
def build_citation_processor(
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
) -> StreamProcessor:
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,

View File

@@ -285,7 +285,9 @@ def process_model_tokens(
def build_quotes_processor(
context_docs: list[LlmDoc], is_json_prompt: bool
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
def stream_processor(
tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,

View File

@@ -138,7 +138,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_delta_to_message_chunk(
_dict: dict[str, Any], curr_msg: BaseMessage | None
_dict: dict[str, Any],
curr_msg: BaseMessage | None,
stop_reason: str | None = None,
) -> BaseMessageChunk:
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
@@ -163,12 +165,23 @@
args=tool_call.function.arguments,
index=0, # only support a single tool call atm
)
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=[tool_call_chunk],
additional_kwargs={
"usage_metadata": {"stop": stop_reason},
**additional_kwargs,
},
)
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
return AIMessageChunk(
content=content,
additional_kwargs={
"usage_metadata": {"stop": stop_reason},
**additional_kwargs,
},
)
elif role == "system":
return SystemMessageChunk(content=content)
elif role == "function":
@@ -206,7 +219,7 @@ class DefaultMultiLLM(LLM):
self._api_version = api_version
self._custom_llm_provider = custom_llm_provider
# This can be used to store the maximum output tkoens for this model.
# This can be used to store the maximum output tokens for this model.
# self._max_output_tokens = (
# max_output_tokens
# if max_output_tokens is not None
@@ -349,10 +362,16 @@
)
try:
for part in response:
if len(part["choices"]) == 0:
if not part["choices"]:
continue
delta = part["choices"][0]["delta"]
message_chunk = _convert_delta_to_message_chunk(delta, output)
choice = part["choices"][0]
message_chunk = _convert_delta_to_message_chunk(
choice["delta"],
output,
stop_reason=choice["finish_reason"],
)
if output is None:
output = message_chunk
else:
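
Tying the two halves together, a hedged sketch of the detection path wired up here (the helper name and the langchain_core import path are assumptions for illustration): litellm reports finish_reason == "length" on a chunk when the model runs out of output tokens, _convert_delta_to_message_chunk stashes that under additional_kwargs["usage_metadata"]["stop"], and Answer._process_llm_stream turns it into a CONTEXT_LENGTH stop packet.

from langchain_core.messages import AIMessageChunk

from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason


def stop_info_from_chunk(chunk: AIMessageChunk) -> StreamStopInfo | None:
    """Mirror the length check from Answer._process_llm_stream for one chunk."""
    stop = chunk.additional_kwargs.get("usage_metadata", {}).get("stop")
    if stop == "length":
        # The model was cut off by its max output tokens; the UI can offer
        # "Continue Generating" for this answer.
        return StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)
    return None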

View File

@@ -65,7 +65,12 @@ import { FiArrowDown } from "react-icons/fi";
import { ChatIntro } from "./ChatIntro";
import { AIMessage, HumanMessage } from "./message/Messages";
import { StarterMessage } from "./StarterMessage";
import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces";
import {
AnswerPiecePacket,
DanswerDocument,
StreamStopInfo,
StreamStopReason,
} from "@/lib/search/interfaces";
import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
@@ -94,6 +99,7 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { SEARCH_TOOL_NAME } from "./tools/constants";
import { useUser } from "@/components/user/UserProvider";
import { Stop } from "@phosphor-icons/react";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -338,6 +344,7 @@ export function ChatPage({
}
return;
}
clearSelectedDocuments();
setIsFetchingChatMessages(true);
const response = await fetch(
@@ -624,6 +631,24 @@ export function ChatPage({
const currentRegenerationState = (): RegenerationState | null => {
return regenerationState.get(currentSessionId()) || null;
};
const [canContinue, setCanContinue] = useState<Map<number | null, boolean>>(
new Map([[null, false]])
);
const updateCanContinue = (newState: boolean, sessionId?: number | null) => {
setCanContinue((prevState) => {
const newCanContinueState = new Map(prevState);
newCanContinueState.set(
sessionId !== undefined ? sessionId : currentSessionId(),
newState
);
return newCanContinueState;
});
};
const currentCanContinue = (): boolean => {
return canContinue.get(currentSessionId()) || false;
};
const currentSessionChatState = currentChatState();
const currentSessionRegenerationState = currentRegenerationState();
@@ -864,6 +889,13 @@
}
};
const continueGenerating = () => {
onSubmit({
messageOverride:
"Continue Generating (pick up exactly where you left off)",
});
};
const onSubmit = async ({
messageIdToResend,
messageOverride,
@@ -884,6 +916,7 @@
regenerationRequest?: RegenerationRequest | null;
} = {}) => {
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
if (currentChatState() != "input") {
setPopup({
@@ -978,6 +1011,8 @@
let messageUpdates: Message[] | null = null;
let answer = "";
let stopReason: StreamStopReason | null = null;
let query: string | null = null;
let retrievalType: RetrievalType =
selectedDocuments.length > 0
@@ -1174,6 +1209,11 @@
stackTrace = (packet as StreamingError).stack_trace;
} else if (Object.hasOwn(packet, "message_id")) {
finalMessage = packet as BackendMessage;
} else if (Object.hasOwn(packet, "stop_reason")) {
const stop_reason = (packet as StreamStopInfo).stop_reason;
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
updateCanContinue(true, frozenSessionId);
}
}
// on initial message send, we insert a dummy system message
@@ -1237,6 +1277,7 @@
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
},
]);
}
@@ -1835,6 +1876,12 @@
}
>
<AIMessage
continueGenerating={
i == messageHistory.length - 1 &&
currentCanContinue()
? continueGenerating
: undefined
}
overriddenModel={message.overridden_model}
regenerate={createRegenerator({
messageId: message.messageId,

View File

@@ -2,6 +2,7 @@ import {
DanswerDocument,
Filters,
SearchDanswerDocument,
StreamStopReason,
} from "@/lib/search/interfaces";
export enum RetrievalType {
@@ -89,6 +90,7 @@ export interface Message {
alternateAssistantID?: number | null;
stackTrace?: string | null;
overridden_model?: string;
stopReason?: StreamStopReason | null;
}
export interface BackendChatSession {

View File

@@ -2,6 +2,7 @@ import {
AnswerPiecePacket,
DanswerDocument,
Filters,
StreamStopInfo,
} from "@/lib/search/interfaces";
import { handleSSEStream, handleStream } from "@/lib/search/streamingUtils";
import { ChatState, FeedbackType } from "./types";
@@ -111,7 +112,8 @@ export type PacketType =
| DocumentsResponse
| ImageGenerationDisplay
| StreamingError
| MessageResponseIDInfo;
| MessageResponseIDInfo
| StreamStopInfo;
export async function* sendMessage({
regenerate,

View File

@@ -0,0 +1,37 @@
import { EmphasizedClickable } from "@/components/BasicClickable";
import { useEffect, useState } from "react";
import { FiBook, FiPlayCircle } from "react-icons/fi";
export function ContinueGenerating({
handleContinueGenerating,
}: {
handleContinueGenerating: () => void;
}) {
const [showExplanation, setShowExplanation] = useState(false);
useEffect(() => {
const timer = setTimeout(() => {
setShowExplanation(true);
}, 1000);
return () => clearTimeout(timer);
}, []);
return (
<div className="flex justify-center w-full">
<div className="relative group">
<EmphasizedClickable onClick={handleContinueGenerating}>
<>
<FiPlayCircle className="mr-2" />
Continue Generation
</>
</EmphasizedClickable>
{showExplanation && (
<div className="absolute bottom-full left-1/2 transform -translate-x-1/2 mb-2 px-3 py-1 bg-gray-800 text-white text-xs rounded-lg opacity-0 group-hover:opacity-100 transition-opacity duration-300 whitespace-nowrap">
LLM reached its token limit. Click to continue.
</div>
)}
</div>
</div>
);
}

View File

@@ -65,6 +65,8 @@ import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { EmphasizedClickable } from "@/components/BasicClickable";
import { ContinueGenerating } from "./ContinueMessage";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -123,6 +125,7 @@ function FileDisplay({
export const AIMessage = ({
regenerate,
overriddenModel,
continueGenerating,
shared,
isActive,
toggleDocumentSelection,
@@ -150,6 +153,7 @@ export const AIMessage = ({
}: {
shared?: boolean;
isActive?: boolean;
continueGenerating?: () => void;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (messageId: number) => void;
selectedDocuments?: DanswerDocument[] | null;
@@ -283,11 +287,12 @@ export const AIMessage = ({
size="small"
assistant={alternativeAssistant || currentPersona}
/>
<div className="w-full">
<div className="max-w-message-max break-words">
<div className="w-full ml-4">
<div className="max-w-message-max break-words">
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
{!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
@@ -315,7 +320,8 @@
</div>
)}
</>
)}
) : null}
{toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(
toolCall.tool_name
@@ -633,6 +639,11 @@ export const AIMessage = ({
</div>
</div>
</div>
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&
!query &&
continueGenerating && (
<ContinueGenerating handleContinueGenerating={continueGenerating} />
)}
</div>
</div>
);

View File

@@ -27,7 +27,7 @@ export function SkippedSearch({
handleForceSearch: () => void;
}) {
return (
<div className="flex text-sm !pt-0 p-1">
<div className="flex text-sm !pt-0 p-1">
<div className="flex mb-auto">
<FiBook className="my-auto flex-none mr-2" size={14} />
<div className="my-auto cursor-default">

View File

@@ -19,6 +19,15 @@ export interface AnswerPiecePacket {
answer_piece: string;
}
export enum StreamStopReason {
CONTEXT_LENGTH = "CONTEXT_LENGTH",
CANCELLED = "CANCELLED",
}
export interface StreamStopInfo {
stop_reason: StreamStopReason;
}
export interface ErrorMessagePacket {
error: string;
}