Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-04 09:58:32 +02:00)

Add stop generating functionality (#2100)

* functional types + sidebar
* remove commits
* remove logs
* functional rework of temporary user/assistant ID
* robustify switching
* remove logs
* typing
* robustify frontend handling
* cleaner loop + data persistence
* migrate to streaming response
* formatting
* add new loading state to prevent collisions
* add `ChatState` for more robust handling
* remove logs
* robustify typing
* unnecessary list removed
* robustify
* remove log
* remove false comment
* slightly more robust chat state
* update utility + copy
* improve clarity + new SSE handling utility function
* remove comments
* clearer
* add back stack trace detail
* cleaner messages
* clean final message handling
* tiny formatting (remove newline)
* add synchronous wrapper to avoid hampering main event loop
* update typing
* include logs
* slightly more specific logs
* add `critical` error just in case

This commit is contained in:
parent 8a7bc4e411
commit 12fccfeffd
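Taken together, the backend now reserves the assistant message ID up front, streams newline-delimited JSON packets, and watches for client disconnects, while the frontend drives a `ChatState` machine and an `AbortController` behind the stop button. The sketch below is a minimal, simplified stand-in for that client-side flow — not the actual `ChatPage`/`sendMessage` code shown in the diff further down, and the request body is heavily abbreviated:

```typescript
// Hypothetical, simplified sketch of the stop-generating flow this PR introduces;
// the real implementation lives in ChatPage.tsx / lib.tsx below.
type ChatState = "input" | "loading" | "streaming" | "toolBuilding";

async function streamChat(
  message: string,
  controller: AbortController, // "Stop generating" calls controller.abort()
  onPacket: (packet: Record<string, unknown>) => void,
  setChatState: (state: ChatState) => void
): Promise<void> {
  setChatState("loading");
  try {
    const response = await fetch("/api/chat/send-message", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ message }), // the real request carries many more fields
      signal: controller.signal,
    });
    const reader = response.body!.getReader();
    const decoder = new TextDecoder();
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      setChatState("streaming");
      // the backend now emits one JSON object per line on a text/event-stream response
      for (const line of decoder.decode(value).split("\n")) {
        if (line.trim()) onPacket(JSON.parse(line));
      }
    }
  } finally {
    setChatState("input"); // abort and normal completion both land here
  }
}
```

Aborting the fetch closes the connection; on the server, the new `is_disconnected` dependency surfaces that to the answer stream, which then stops emitting tokens.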
@@ -39,6 +39,7 @@ def create_chat_chain(
) -> tuple[ChatMessage, list[ChatMessage]]:
    """Build the linear chain of messages without including the root message"""
    mainline_messages: list[ChatMessage] = []

    all_chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session_id,
        user_id=None,
@@ -76,6 +76,11 @@ class CitationInfo(BaseModel):
    document_id: str


class MessageResponseIDInfo(BaseModel):
    user_message_id: int | None
    reserved_assistant_message_id: int


class StreamingError(BaseModel):
    error: str
    stack_trace: str | None = None
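`MessageResponseIDInfo` is the new first packet of the chat stream: the backend reserves a database row for the assistant message and reports both IDs before any tokens arrive, so the client can key its message map on real IDs rather than temporary placeholders. A small TypeScript sketch of how a consumer can recognize it (the interface simply mirrors the Pydantic model above; the type guard itself is illustrative):

```typescript
// Mirrors the MessageResponseIDInfo model added above (field names from this diff).
interface MessageResponseIDInfo {
  user_message_id: number | null;
  reserved_assistant_message_id: number;
}

// Illustrative guard; ChatPage below keys off "user_message_id", but either field
// uniquely identifies this packet in the stream.
function isMessageResponseIDInfo(packet: object): packet is MessageResponseIDInfo {
  return Object.hasOwn(packet, "reserved_assistant_message_id");
}
```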
@@ -11,6 +11,7 @@ from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY

@@ -27,6 +28,7 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model

@@ -241,6 +243,7 @@ ChatPacket = (
    | CitationInfo
    | ImageGenerationDisplay
    | CustomToolResponse
    | MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]

@@ -256,9 +259,9 @@ def stream_chat_message_objects(
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
    # if specified, uses the last user message and does not create a new user message based
    # on the `new_msg_req.message`. Currently, requires a state where the last message is a
    # user message (e.g. this can only be used for the chat-seeding flow).
    use_existing_user_message: bool = False,
    litellm_additional_headers: dict[str, str] | None = None,
    is_connected: Callable[[], bool] | None = None,
) -> ChatPacketStream:
    """Streams in order:
    1. [conditional] Retrieved documents if a search needs to be run
@@ -449,7 +452,18 @@ def stream_chat_message_objects(
            ),
            max_window_percentage=max_document_percentage,
        )

        reserved_message_id = reserve_message_id(
            db_session=db_session,
            chat_session_id=chat_session_id,
            parent_message=user_message.id
            if user_message is not None
            else parent_message.id,
            message_type=MessageType.ASSISTANT,
        )
        yield MessageResponseIDInfo(
            user_message_id=user_message.id if user_message else None,
            reserved_assistant_message_id=reserved_message_id,
        )
        # Cannot determine these without the LLM step or breaking out early
        partial_response = partial(
            create_new_chat_message,

@@ -582,6 +596,7 @@ def stream_chat_message_objects(

        # LLM prompt building, response capturing, etc.
        answer = Answer(
            is_connected=is_connected,
            question=final_msg.message,
            latest_query_files=latest_query_files,
            answer_style_config=AnswerStyleConfig(

@@ -615,6 +630,7 @@ def stream_chat_message_objects(
        ai_message_files = None  # any files to associate with the AI message e.g. dall-e generated images
        dropped_indices = None
        tool_result = None

        for packet in answer.processed_streamed_output:
            if isinstance(packet, ToolResponse):
                if packet.id == SEARCH_RESPONSE_SUMMARY_ID:

@@ -690,6 +706,7 @@ def stream_chat_message_objects(
            if isinstance(packet, ToolCallFinalResult):
                tool_result = packet
            yield cast(ChatPacket, packet)
        logger.debug("Reached end of stream")
    except Exception as e:
        error_msg = str(e)
        logger.exception(f"Failed to process chat message: {error_msg}")

@@ -717,6 +734,7 @@ def stream_chat_message_objects(
            tool_name_to_tool_id[tool.name] = tool_id

        gen_ai_response_message = partial_response(
            reserved_message_id=reserved_message_id,
            message=answer.llm_answer,
            rephrased_query=(
                qa_docs_response.rephrased_query if qa_docs_response else None

@@ -737,6 +755,8 @@ def stream_chat_message_objects(
            if tool_result
            else [],
        )

        logger.debug("Committing messages")
        db_session.commit()  # actually save user / assistant message

        msg_detail_response = translate_db_message_to_chat_message_detail(

@@ -745,7 +765,8 @@ def stream_chat_message_objects(

        yield msg_detail_response
    except Exception as e:
        logger.exception(e)
        error_msg = str(e)
        logger.exception(error_msg)

        # Frontend will erase whatever answer and show this instead
        yield StreamingError(error="Failed to parse LLM output")

@@ -757,6 +778,7 @@ def stream_chat_message(
    user: User | None,
    use_existing_user_message: bool = False,
    litellm_additional_headers: dict[str, str] | None = None,
    is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
    with get_session_context_manager() as db_session:
        objects = stream_chat_message_objects(

@@ -765,6 +787,7 @@ def stream_chat_message(
            db_session=db_session,
            use_existing_user_message=use_existing_user_message,
            litellm_additional_headers=litellm_additional_headers,
            is_connected=is_connected,
        )
        for obj in objects:
            yield get_json_line(obj.dict())
@@ -393,6 +393,34 @@ def get_or_create_root_message(
    return new_root_message


def reserve_message_id(
    db_session: Session,
    chat_session_id: int,
    parent_message: int,
    message_type: MessageType,
) -> int:
    # Create an empty chat message
    empty_message = ChatMessage(
        chat_session_id=chat_session_id,
        parent_message=parent_message,
        latest_child_message=None,
        message="",
        token_count=0,
        message_type=message_type,
    )

    # Add the empty message to the session
    db_session.add(empty_message)

    # Flush the session to get an ID for the new chat message
    db_session.flush()

    # Get the ID of the newly created message
    new_id = empty_message.id

    return new_id


def create_new_chat_message(
    chat_session_id: int,
    parent_message: ChatMessage,

@@ -410,29 +438,51 @@ def create_new_chat_message(
    citations: dict[int, int] | None = None,
    tool_calls: list[ToolCall] | None = None,
    commit: bool = True,
    reserved_message_id: int | None = None,
) -> ChatMessage:
    new_chat_message = ChatMessage(
        chat_session_id=chat_session_id,
        parent_message=parent_message.id,
        latest_child_message=None,
        message=message,
        rephrased_query=rephrased_query,
        prompt_id=prompt_id,
        token_count=token_count,
        message_type=message_type,
        citations=citations,
        files=files,
        tool_calls=tool_calls if tool_calls else [],
        error=error,
        alternate_assistant_id=alternate_assistant_id,
    )
    if reserved_message_id is not None:
        # Edit existing message
        existing_message = db_session.query(ChatMessage).get(reserved_message_id)
        if existing_message is None:
            raise ValueError(f"No message found with id {reserved_message_id}")

        existing_message.chat_session_id = chat_session_id
        existing_message.parent_message = parent_message.id
        existing_message.message = message
        existing_message.rephrased_query = rephrased_query
        existing_message.prompt_id = prompt_id
        existing_message.token_count = token_count
        existing_message.message_type = message_type
        existing_message.citations = citations
        existing_message.files = files
        existing_message.tool_calls = tool_calls if tool_calls else []
        existing_message.error = error
        existing_message.alternate_assistant_id = alternate_assistant_id

        new_chat_message = existing_message
    else:
        # Create new message
        new_chat_message = ChatMessage(
            chat_session_id=chat_session_id,
            parent_message=parent_message.id,
            latest_child_message=None,
            message=message,
            rephrased_query=rephrased_query,
            prompt_id=prompt_id,
            token_count=token_count,
            message_type=message_type,
            citations=citations,
            files=files,
            tool_calls=tool_calls if tool_calls else [],
            error=error,
            alternate_assistant_id=alternate_assistant_id,
        )
        db_session.add(new_chat_message)

    # SQL Alchemy will propagate this to update the reference_docs' foreign keys
    if reference_docs:
        new_chat_message.search_docs = reference_docs

    db_session.add(new_chat_message)

    # Flush the session to get an ID for the new chat message
    db_session.flush()
@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from uuid import uuid4

@@ -115,6 +116,7 @@ class Answer:
        # Returns the full document sections text from the search tool
        return_contexts: bool = False,
        skip_gen_ai_answer_generation: bool = False,
        is_connected: Callable[[], bool] | None = None,
    ) -> None:
        if single_message_history and message_history:
            raise ValueError(

@@ -122,6 +124,7 @@ class Answer:
            )

        self.question = question
        self.is_connected: Callable[[], bool] | None = is_connected

        self.latest_query_files = latest_query_files or []
        self.file_id_to_file = {file.file_id: file for file in (files or [])}

@@ -153,6 +156,7 @@ class Answer:

        self._return_contexts = return_contexts
        self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
        self._is_cancelled = False

    def _update_prompt_builder_for_search_tool(
        self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]

@@ -235,6 +239,8 @@ class Answer:
                    tool_call_chunk += message  # type: ignore
            else:
                if message.content:
                    if self.is_cancelled:
                        return
                    yield cast(str, message.content)

        if not tool_call_chunk:

@@ -292,12 +298,15 @@ class Answer:
            yield tool_runner.tool_final_result()

            prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
            yield from message_generator_to_string_generator(
            for token in message_generator_to_string_generator(
                self.llm.stream(
                    prompt=prompt,
                    tools=[tool.tool_definition() for tool in self.tools],
                )
            )
            ):
                if self.is_cancelled:
                    return
                yield token

            return

@@ -378,9 +387,13 @@ class Answer:
                )
            )
            prompt = prompt_builder.build()
            yield from message_generator_to_string_generator(
            for token in message_generator_to_string_generator(
                self.llm.stream(prompt=prompt)
            )
            ):
                if self.is_cancelled:
                    return
                yield token

            return

        tool, tool_args = chosen_tool_and_args

@@ -434,7 +447,12 @@ class Answer:
        yield final

        prompt = prompt_builder.build()
        yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
        for token in message_generator_to_string_generator(
            self.llm.stream(prompt=prompt)
        ):
            if self.is_cancelled:
                return
            yield token

    @property
    def processed_streamed_output(self) -> AnswerStream:

@@ -537,3 +555,15 @@ class Answer:
                citations.append(packet)

        return citations

    @property
    def is_cancelled(self) -> bool:
        if self._is_cancelled:
            return True

        if self.is_connected is not None:
            if not self.is_connected():
                logger.debug("Answer stream has been cancelled")

            self._is_cancelled = not self.is_connected()

        return self._is_cancelled
@@ -1,5 +1,8 @@
import asyncio
import io
import uuid
from collections.abc import Callable
from collections.abc import Generator

from fastapi import APIRouter
from fastapi import Depends

@@ -207,8 +210,6 @@ def rename_chat_session(
    chat_session_id = rename_req.chat_session_id
    user_id = user.id if user is not None else None

    logger.info(f"Received rename request for chat session: {chat_session_id}")

    if name:
        update_chat_session(
            db_session=db_session,

@@ -271,19 +272,39 @@ def delete_chat_session_by_id(
    delete_chat_session(user_id, session_id, db_session)


async def is_disconnected(request: Request) -> Callable[[], bool]:
    main_loop = asyncio.get_event_loop()

    def is_disconnected_sync() -> bool:
        future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
        try:
            return not future.result(timeout=0.01)
        except asyncio.TimeoutError:
            logger.error("Asyncio timed out")
            return True
        except Exception as e:
            error_msg = str(e)
            logger.critical(
                f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
            )
            return True

    return is_disconnected_sync


@router.post("/send-message")
def handle_new_chat_message(
    chat_message_req: CreateChatMessageRequest,
    request: Request,
    user: User | None = Depends(current_user),
    _: None = Depends(check_token_rate_limits),
    is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
    """This endpoint is both used for all the following purposes:
    - Sending a new message in the session
    - Regenerating a message in the session (just send the same one again)
    - Editing a message (similar to regenerating but sending a different message)
    - Kicking off a seeded chat session (set `use_existing_user_message`)

    To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
    have already been set as latest"""
    logger.debug(f"Received new chat message: {chat_message_req.message}")

@@ -295,15 +316,26 @@ def handle_new_chat_message(
    ):
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    packets = stream_chat_message(
        new_msg_req=chat_message_req,
        user=user,
        use_existing_user_message=chat_message_req.use_existing_user_message,
        litellm_additional_headers=get_litellm_additional_request_headers(
            request.headers
        ),
    )
    return StreamingResponse(packets, media_type="application/json")
    import json

    def stream_generator() -> Generator[str, None, None]:
        try:
            for packet in stream_chat_message(
                new_msg_req=chat_message_req,
                user=user,
                use_existing_user_message=chat_message_req.use_existing_user_message,
                litellm_additional_headers=get_litellm_additional_request_headers(
                    request.headers
                ),
                is_connected=is_disconnected_func,
            ):
                yield json.dumps(packet) if isinstance(packet, dict) else packet

        except Exception as e:
            logger.exception(f"Error in chat message streaming: {e}")
            yield json.dumps({"error": str(e)})

    return StreamingResponse(stream_generator(), media_type="text/event-stream")


@router.put("/set-message-as-latest")
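As the docstring above notes, the single `/send-message` endpoint handles new messages, regenerations, edits, and seeded sessions. Below is a sketch of two request payloads using the field names from the frontend's `sendMessage` later in this diff; the IDs and values are illustrative only:

```typescript
// Illustrative payloads for the one /api/chat/send-message endpoint.
// Field names match the sendMessage() body shown further down in this diff.
const newMessageRequest = {
  chat_session_id: 42, // illustrative IDs
  parent_message_id: 7, // latest message on the current branch
  message: "What changed in the last release?",
  prompt_id: null,
  search_doc_ids: null,
  file_descriptors: [],
  use_existing_user_message: false,
};

const seededSessionKickoff = {
  ...newMessageRequest,
  // the backend ignores `message` and reuses the pre-seeded user message
  // when use_existing_user_message is set
  use_existing_user_message: true,
};
```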
@@ -12,6 +12,7 @@ import {
  FileDescriptor,
  ImageGenerationDisplay,
  Message,
  MessageResponseIDInfo,
  RetrievalType,
  StreamingError,
  ToolCallMetadata,

@@ -50,7 +51,7 @@ import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
import { useDocumentSelection } from "./useDocumentSelection";
import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
import { computeAvailableFilters } from "@/lib/filters";
import { FeedbackType } from "./types";
import { ChatState, FeedbackType } from "./types";
import { DocumentSidebar } from "./documentSidebar/DocumentSidebar";
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
import { FeedbackModal } from "./modal/FeedbackModal";

@@ -211,6 +212,27 @@ export function ChatPage({
    }
  }, [liveAssistant]);

  const stopGeneration = () => {
    if (abortController) {
      abortController.abort();
    }
    const lastMessage = messageHistory[messageHistory.length - 1];
    if (
      lastMessage &&
      lastMessage.type === "assistant" &&
      lastMessage.toolCalls[0] &&
      lastMessage.toolCalls[0].tool_result === undefined
    ) {
      const newCompleteMessageMap = new Map(completeMessageDetail.messageMap);
      const updatedMessage = { ...lastMessage, toolCalls: [] };
      newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
      setCompleteMessageDetail({
        sessionId: completeMessageDetail.sessionId,
        messageMap: newCompleteMessageMap,
      });
    }
  };

  // this is for "@"ing assistants

  // this is used to track which assistant is being used to generate the current message

@@ -413,6 +435,7 @@ export function ChatPage({
      );
      messages[0].parentMessageId = systemMessageId;
    }

    messages.forEach((message) => {
      const idToReplace = replacementsMap?.get(message.messageId);
      if (idToReplace) {

@@ -428,7 +451,6 @@ export function ChatPage({
      }
      newCompleteMessageMap.set(message.messageId, message);
    });

    // if specified, make these new message the latest of the current message chain
    if (makeLatestChildMessage) {
      const currentMessageChain = buildLatestMessageChain(

@@ -452,7 +474,8 @@ export function ChatPage({
  const messageHistory = buildLatestMessageChain(
    completeMessageDetail.messageMap
  );
  const [isStreaming, setIsStreaming] = useState(false);
  const [submittedMessage, setSubmittedMessage] = useState("");
  const [chatState, setChatState] = useState<ChatState>("input");
  const [abortController, setAbortController] =
    useState<AbortController | null>(null);
@@ -663,13 +686,11 @@ export function ChatPage({
    params: any
  ) {
    try {
      for await (const packetBunch of sendMessage(params)) {
      for await (const packet of sendMessage(params)) {
        if (params.signal?.aborted) {
          throw new Error("AbortError");
        }
        for (const packet of packetBunch) {
          stack.push(packet);
        }
        stack.push(packet);
      }
    } catch (error: unknown) {
      if (error instanceof Error) {

@@ -709,7 +730,7 @@ export function ChatPage({
    isSeededChat?: boolean;
    alternativeAssistantOverride?: Persona | null;
  } = {}) => {
    if (isStreaming) {
    if (chatState != "input") {
      setPopup({
        message: "Please wait for the response to complete",
        type: "error",

@@ -718,6 +739,7 @@ export function ChatPage({
      return;
    }

    setChatState("loading");
    const controller = new AbortController();
    setAbortController(controller);

@@ -757,13 +779,15 @@ export function ChatPage({
          "Failed to re-send message - please refresh the page and try again.",
        type: "error",
      });
      setChatState("input");
      return;
    }

    let currMessage = messageToResend ? messageToResend.message : message;
    if (messageOverride) {
      currMessage = messageOverride;
    }

    setSubmittedMessage(currMessage);
    const currMessageHistory =
      messageToResendIndex !== null
        ? messageHistory.slice(0, messageToResendIndex)

@@ -775,39 +799,6 @@ export function ChatPage({
          : null) ||
      (messageMap.size === 1 ? Array.from(messageMap.values())[0] : null);

    // if we're resending, set the parent's child to null
    // we will use tempMessages until the regenerated message is complete
    const messageUpdates: Message[] = [
      {
        messageId: TEMP_USER_MESSAGE_ID,
        message: currMessage,
        type: "user",
        files: currentMessageFiles,
        toolCalls: [],
        parentMessageId: parentMessage?.messageId || null,
      },
    ];
    if (parentMessage) {
      messageUpdates.push({
        ...parentMessage,
        childrenMessageIds: (parentMessage.childrenMessageIds || []).concat([
          TEMP_USER_MESSAGE_ID,
        ]),
        latestChildMessageId: TEMP_USER_MESSAGE_ID,
      });
    }
    const { messageMap: frozenMessageMap, sessionId: frozenSessionId } =
      upsertToCompleteMessageMap({
        messages: messageUpdates,
        chatSessionId: currChatSessionId,
      });

    // on initial message send, we insert a dummy system message
    // set this as the parent here if no parent is set
    if (!parentMessage && frozenMessageMap.size === 2) {
      parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
    }

    const currentAssistantId = alternativeAssistantOverride
      ? alternativeAssistantOverride.id
      : alternativeAssistant

@@ -815,8 +806,8 @@ export function ChatPage({
        : liveAssistant.id;

    resetInputBar();
    let messageUpdates: Message[] | null = null;

    setIsStreaming(true);
    let answer = "";
    let query: string | null = null;
    let retrievalType: RetrievalType =

@@ -831,6 +822,13 @@ export function ChatPage({
    let finalMessage: BackendMessage | null = null;
    let toolCalls: ToolCallMetadata[] = [];

    let initialFetchDetails: null | {
      user_message_id: number;
      assistant_message_id: number;
      frozenMessageMap: Map<number, Message>;
      frozenSessionId: number | null;
    } = null;

    try {
      const lastSuccessfulMessageId =
        getLastSuccessfulMessageId(currMessageHistory);

@@ -838,7 +836,6 @@ export function ChatPage({
      const stack = new CurrentMessageFIFO();
      updateCurrentMessageFIFO(stack, {
        signal: controller.signal, // Add this line

        message: currMessage,
        alternateAssistantId: currentAssistantId,
        fileDescriptors: currentMessageFiles,

@@ -875,20 +872,6 @@ export function ChatPage({
        useExistingUserMessage: isSeededChat,
      });

      const updateFn = (messages: Message[]) => {
        const replacementsMap = finalMessage
          ? new Map([
              [messages[0].messageId, TEMP_USER_MESSAGE_ID],
              [messages[1].messageId, TEMP_ASSISTANT_MESSAGE_ID],
            ] as [number, number][])
          : null;
        upsertToCompleteMessageMap({
          messages: messages,
          replacementsMap: replacementsMap,
          completeMessageMapOverride: frozenMessageMap,
          chatSessionId: frozenSessionId!,
        });
      };
      const delay = (ms: number) => {
        return new Promise((resolve) => setTimeout(resolve, ms));
      };
@@ -899,8 +882,71 @@ export function ChatPage({

        if (!stack.isEmpty()) {
          const packet = stack.nextPacket();
          console.log(packet);
          if (packet) {
          if (!packet) {
            continue;
          }

          if (!initialFetchDetails) {
            if (!Object.hasOwn(packet, "user_message_id")) {
              console.error(
                "First packet should contain message response info "
              );
              continue;
            }

            const messageResponseIDInfo = packet as MessageResponseIDInfo;

            const user_message_id = messageResponseIDInfo.user_message_id!;
            const assistant_message_id =
              messageResponseIDInfo.reserved_assistant_message_id;

            // we will use tempMessages until the regenerated message is complete
            messageUpdates = [
              {
                messageId: user_message_id,
                message: currMessage,
                type: "user",
                files: currentMessageFiles,
                toolCalls: [],
                parentMessageId: parentMessage?.messageId || null,
              },
            ];
            if (parentMessage) {
              messageUpdates.push({
                ...parentMessage,
                childrenMessageIds: (
                  parentMessage.childrenMessageIds || []
                ).concat([user_message_id]),
                latestChildMessageId: user_message_id,
              });
            }

            const {
              messageMap: currentFrozenMessageMap,
              sessionId: currentFrozenSessionId,
            } = upsertToCompleteMessageMap({
              messages: messageUpdates,
              chatSessionId: currChatSessionId,
            });

            const frozenMessageMap = currentFrozenMessageMap;
            const frozenSessionId = currentFrozenSessionId;
            initialFetchDetails = {
              frozenMessageMap,
              frozenSessionId,
              assistant_message_id,
              user_message_id,
            };
          } else {
            const { user_message_id, frozenMessageMap, frozenSessionId } =
              initialFetchDetails;
            setChatState((chatState) => {
              if (chatState == "loading") {
                return "streaming";
              }
              return chatState;
            });

            if (Object.hasOwn(packet, "answer_piece")) {
              answer += (packet as AnswerPiecePacket).answer_piece;
            } else if (Object.hasOwn(packet, "top_documents")) {

@@ -910,7 +956,7 @@ export function ChatPage({
              if (documents && documents.length > 0) {
                // point to the latest message (we don't know the messageId yet, which is why
                // we have to use -1)
                setSelectedMessageForDocDisplay(TEMP_USER_MESSAGE_ID);
                setSelectedMessageForDocDisplay(user_message_id);
              }
            } else if (Object.hasOwn(packet, "tool_name")) {
              toolCalls = [

@@ -920,6 +966,14 @@ export function ChatPage({
                  tool_result: (packet as ToolCallMetadata).tool_result,
                },
              ];
              if (
                !toolCalls[0].tool_result ||
                toolCalls[0].tool_result == undefined
              ) {
                setChatState("toolBuilding");
              } else {
                setChatState("streaming");
              }
            } else if (Object.hasOwn(packet, "file_ids")) {
              aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
                (fileId) => {

@@ -936,23 +990,34 @@ export function ChatPage({
              finalMessage = packet as BackendMessage;
            }

            const newUserMessageId =
              finalMessage?.parent_message || TEMP_USER_MESSAGE_ID;
            const newAssistantMessageId =
              finalMessage?.message_id || TEMP_ASSISTANT_MESSAGE_ID;
            // on initial message send, we insert a dummy system message
            // set this as the parent here if no parent is set
            parentMessage =
              parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;

            const updateFn = (messages: Message[]) => {
              const replacementsMap = null;
              upsertToCompleteMessageMap({
                messages: messages,
                replacementsMap: replacementsMap,
                completeMessageMapOverride: frozenMessageMap,
                chatSessionId: frozenSessionId!,
              });
            };

            updateFn([
              {
                messageId: newUserMessageId,
                messageId: initialFetchDetails.user_message_id!,
                message: currMessage,
                type: "user",
                files: currentMessageFiles,
                toolCalls: [],
                parentMessageId: parentMessage?.messageId || null,
                childrenMessageIds: [newAssistantMessageId],
                latestChildMessageId: newAssistantMessageId,
                parentMessageId: error ? null : lastSuccessfulMessageId,
                childrenMessageIds: [initialFetchDetails.assistant_message_id!],
                latestChildMessageId: initialFetchDetails.assistant_message_id,
              },
              {
                messageId: newAssistantMessageId,
                messageId: initialFetchDetails.assistant_message_id!,
                message: error || answer,
                type: error ? "error" : "assistant",
                retrievalType,

@@ -962,7 +1027,7 @@ export function ChatPage({
                citations: finalMessage?.citations || {},
                files: finalMessage?.files || aiMessageImages || [],
                toolCalls: finalMessage?.tool_calls || toolCalls,
                parentMessageId: newUserMessageId,
                parentMessageId: initialFetchDetails.user_message_id,
                alternateAssistantID: alternativeAssistant?.id,
                stackTrace: stackTrace,
              },

@@ -975,7 +1040,8 @@ export function ChatPage({
      upsertToCompleteMessageMap({
        messages: [
          {
            messageId: TEMP_USER_MESSAGE_ID,
            messageId:
              initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
            message: currMessage,
            type: "user",
            files: currentMessageFiles,

@@ -983,24 +1049,28 @@ export function ChatPage({
            parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
          },
          {
            messageId: TEMP_ASSISTANT_MESSAGE_ID,
            messageId:
              initialFetchDetails?.assistant_message_id ||
              TEMP_ASSISTANT_MESSAGE_ID,
            message: errorMsg,
            type: "error",
            files: aiMessageImages || [],
            toolCalls: [],
            parentMessageId: TEMP_USER_MESSAGE_ID,
            parentMessageId:
              initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
          },
        ],
        completeMessageMapOverride: frozenMessageMap,
        completeMessageMapOverride: completeMessageDetail.messageMap,
      });
    }

    setIsStreaming(false);
    setChatState("input");
    if (isNewSession) {
      if (finalMessage) {
        setSelectedMessageForDocDisplay(finalMessage.message_id);
      }

      if (!searchParamBasedChatSessionName) {
        await new Promise((resolve) => setTimeout(resolve, 200));
        await nameChatSession(currChatSessionId, currMessage);
      }
@@ -1060,8 +1130,8 @@ export function ChatPage({
  const onAssistantChange = (assistant: Persona | null) => {
    if (assistant && assistant.id !== liveAssistant.id) {
      // Abort the ongoing stream if it exists
      if (abortController && isStreaming) {
        abortController.abort();
      if (chatState != "input") {
        stopGeneration();
        resetInputBar();
      }

@@ -1163,7 +1233,7 @@ export function ChatPage({
  });

  useScrollonStream({
    isStreaming,
    chatState,
    scrollableDivRef,
    scrollDist,
    endDivRef,

@@ -1334,6 +1404,7 @@ export function ChatPage({
        >
          <div className="w-full relative">
            <HistorySidebar
              stopGenerating={stopGeneration}
              reset={() => setMessage("")}
              page="chat"
              ref={innerSidebarElementRef}

@@ -1407,7 +1478,7 @@ export function ChatPage({

              {messageHistory.length === 0 &&
                !isFetchingChatMessages &&
                !isStreaming && (
                chatState == "input" && (
                  <ChatIntro
                    availableSources={finalAvailableSources}
                    selectedPersona={liveAssistant}

@@ -1431,6 +1502,7 @@ export function ChatPage({
                  return (
                    <div key={messageReactComponentKey}>
                      <HumanMessage
                        stopGenerating={stopGeneration}
                        content={message.message}
                        files={message.files}
                        messageId={message.messageId}

@@ -1483,9 +1555,7 @@ export function ChatPage({
                        (selectedMessageForDocDisplay !== null &&
                          selectedMessageForDocDisplay ===
                            message.messageId) ||
                        (selectedMessageForDocDisplay ===
                          TEMP_USER_MESSAGE_ID &&
                          i === messageHistory.length - 1);
                          i === messageHistory.length - 1;
                      const previousMessage =
                        i !== 0 ? messageHistory[i - 1] : null;

@@ -1534,7 +1604,8 @@ export function ChatPage({
                          }
                          isComplete={
                            i !== messageHistory.length - 1 ||
                            !isStreaming
                            (chatState != "streaming" &&
                              chatState != "toolBuilding")
                          }
                          hasDocs={
                            (message.documents &&

@@ -1542,7 +1613,7 @@ export function ChatPage({
                          }
                          handleFeedback={
                            i === messageHistory.length - 1 &&
                            isStreaming
                            chatState != "input"
                              ? undefined
                              : (feedbackType) =>
                                  setCurrentFeedback([

@@ -1552,7 +1623,7 @@ export function ChatPage({
                          }
                          handleSearchQueryEdit={
                            i === messageHistory.length - 1 &&
                            !isStreaming
                            chatState == "input"
                              ? (newQuery) => {
                                  if (!previousMessage) {
                                    setPopup({

@@ -1659,34 +1730,39 @@ export function ChatPage({
                    );
                  }
                })}
                {isStreaming &&
                  messageHistory.length > 0 &&
                  messageHistory[messageHistory.length - 1].type ===
                {chatState == "loading" &&
                  messageHistory[messageHistory.length - 1]?.type !=
                    "user" && (
                    <div
                      key={`${messageHistory.length}-${chatSessionIdRef.current}`}
                    >
                      <AIMessage
                        currentPersona={liveAssistant}
                        alternativeAssistant={
                          alternativeGeneratingAssistant ??
                          alternativeAssistant
                        }
                        messageId={null}
                        personaName={liveAssistant.name}
                        content={
                          <div
                            key={"Generating"}
                            className="mr-auto relative inline-block"
                          >
                            <span className="text-sm loading-text">
                              Thinking...
                            </span>
                          </div>
                        }
                      />
                    </div>
                    <HumanMessage
                      messageId={-1}
                      content={submittedMessage}
                    />
                  )}
                {chatState == "loading" && (
                  <div
                    key={`${messageHistory.length}-${chatSessionIdRef.current}`}
                  >
                    <AIMessage
                      currentPersona={liveAssistant}
                      alternativeAssistant={
                        alternativeGeneratingAssistant ??
                        alternativeAssistant
                      }
                      messageId={null}
                      personaName={liveAssistant.name}
                      content={
                        <div
                          key={"Generating"}
                          className="mr-auto relative inline-block"
                        >
                          <span className="text-sm loading-text">
                            Thinking...
                          </span>
                        </div>
                      }
                    />
                  </div>
                )}

                {currentPersona &&
                  currentPersona.starter_messages &&

@@ -1748,6 +1824,8 @@ export function ChatPage({
              )}

              <ChatInputBar
                chatState={chatState}
                stopGenerating={stopGeneration}
                openModelSettings={() => setSettingsToggled(true)}
                inputPrompts={userInputPrompts}
                showDocs={() => setDocumentSelection(true)}

@@ -1762,7 +1840,6 @@ export function ChatPage({
                message={message}
                setMessage={setMessage}
                onSubmit={onSubmit}
                isStreaming={isStreaming}
                filterManager={filterManager}
                llmOverrideManager={llmOverrideManager}
                files={currentMessageFiles}
@@ -21,6 +21,7 @@ import {
  CpuIconSkeleton,
  FileIcon,
  SendIcon,
  StopGeneratingIcon,
} from "@/components/icons/icons";
import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";

@@ -31,6 +32,9 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { Hoverable } from "@/components/Hoverable";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { StopCircle } from "@phosphor-icons/react/dist/ssr";
import { Square } from "@phosphor-icons/react";
import { ChatState } from "../types";
const MAX_INPUT_HEIGHT = 200;

export function ChatInputBar({

@@ -39,10 +43,11 @@ export function ChatInputBar({
  selectedDocuments,
  message,
  setMessage,
  stopGenerating,
  onSubmit,
  isStreaming,
  filterManager,
  llmOverrideManager,
  chatState,

  // assistants
  selectedAssistant,

@@ -59,6 +64,8 @@ export function ChatInputBar({
  inputPrompts,
}: {
  openModelSettings: () => void;
  chatState: ChatState;
  stopGenerating: () => void;
  showDocs: () => void;
  selectedDocuments: DanswerDocument[];
  assistantOptions: Persona[];

@@ -68,7 +75,6 @@ export function ChatInputBar({
  message: string;
  setMessage: (message: string) => void;
  onSubmit: () => void;
  isStreaming: boolean;
  filterManager: FilterManager;
  llmOverrideManager: LlmOverrideManager;
  selectedAssistant: Persona;

@@ -597,24 +603,38 @@ export function ChatInputBar({
            }}
          />
        </div>

        <div className="absolute bottom-2.5 mobile:right-4 desktop:right-10">
          <div
            className="cursor-pointer"
            onClick={() => {
              if (message) {
                onSubmit();
              }
            }}
          >
            <SendIcon
              size={28}
              className={`text-emphasis text-white p-1 rounded-full ${
                message && !isStreaming
                  ? "bg-background-800"
                  : "bg-[#D7D7D7]"
              }`}
            />
          </div>
          {chatState == "streaming" ||
          chatState == "toolBuilding" ||
          chatState == "loading" ? (
            <button
              className={`cursor-pointer ${chatState != "streaming" ? "bg-background-400" : "bg-background-800"} h-[28px] w-[28px] rounded-full`}
              onClick={stopGenerating}
              disabled={chatState != "streaming"}
            >
              <StopGeneratingIcon
                size={10}
                className={`text-emphasis m-auto text-white flex-none
                }`}
              />
            </button>
          ) : (
            <button
              className="cursor-pointer"
              onClick={() => {
                if (message) {
                  onSubmit();
                }
              }}
              disabled={chatState != "input"}
            >
              <SendIcon
                size={28}
                className={`text-emphasis text-white p-1 rounded-full ${chatState == "input" && message ? "bg-background-800" : "bg-background-400"} `}
              />
            </button>
          )}
        </div>
      </div>
    </div>
@@ -118,6 +118,11 @@ export interface BackendMessage {
  alternate_assistant_id?: number | null;
}

export interface MessageResponseIDInfo {
  user_message_id: number | null;
  reserved_assistant_message_id: number;
}

export interface DocumentsResponse {
  top_documents: DanswerDocument[];
  rephrased_query: string | null;
@@ -3,8 +3,8 @@ import {
  DanswerDocument,
  Filters,
} from "@/lib/search/interfaces";
import { handleStream } from "@/lib/search/streamingUtils";
import { FeedbackType } from "./types";
import { handleSSEStream, handleStream } from "@/lib/search/streamingUtils";
import { ChatState, FeedbackType } from "./types";
import {
  Dispatch,
  MutableRefObject,

@@ -20,6 +20,7 @@ import {
  FileDescriptor,
  ImageGenerationDisplay,
  Message,
  MessageResponseIDInfo,
  RetrievalType,
  StreamingError,
  ToolCallMetadata,

@@ -109,7 +110,8 @@ export type PacketType =
  | AnswerPiecePacket
  | DocumentsResponse
  | ImageGenerationDisplay
  | StreamingError;
  | StreamingError
  | MessageResponseIDInfo;

export async function* sendMessage({
  message,

@@ -127,6 +129,7 @@ export async function* sendMessage({
  systemPromptOverride,
  useExistingUserMessage,
  alternateAssistantId,
  signal,
}: {
  message: string;
  fileDescriptors: FileDescriptor[];

@@ -137,70 +140,69 @@ export async function* sendMessage({
  selectedDocumentIds: number[] | null;
  queryOverride?: string;
  forceSearch?: boolean;
  // LLM overrides
  modelProvider?: string;
  modelVersion?: string;
  temperature?: number;
  // prompt overrides
  systemPromptOverride?: string;
  // if specified, will use the existing latest user message
  // and will ignore the specified `message`
  useExistingUserMessage?: boolean;
  alternateAssistantId?: number;
}) {
  signal?: AbortSignal;
}): AsyncGenerator<PacketType, void, unknown> {
  const documentsAreSelected =
    selectedDocumentIds && selectedDocumentIds.length > 0;

  const sendMessageResponse = await fetch("/api/chat/send-message", {
  const body = JSON.stringify({
    alternate_assistant_id: alternateAssistantId,
    chat_session_id: chatSessionId,
    parent_message_id: parentMessageId,
    message: message,
    prompt_id: promptId,
    search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
    file_descriptors: fileDescriptors,
    retrieval_options: !documentsAreSelected
      ? {
          run_search:
            promptId === null ||
            promptId === undefined ||
            queryOverride ||
            forceSearch
              ? "always"
              : "auto",
          real_time: true,
          filters: filters,
        }
      : null,
    query_override: queryOverride,
    prompt_override: systemPromptOverride
      ? {
          system_prompt: systemPromptOverride,
        }
      : null,
    llm_override:
      temperature || modelVersion
        ? {
            temperature,
            model_provider: modelProvider,
            model_version: modelVersion,
          }
        : null,
    use_existing_user_message: useExistingUserMessage,
  });

  const response = await fetch(`/api/chat/send-message`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      alternate_assistant_id: alternateAssistantId,
      chat_session_id: chatSessionId,
      parent_message_id: parentMessageId,
      message: message,
      prompt_id: promptId,
      search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
      file_descriptors: fileDescriptors,
      retrieval_options: !documentsAreSelected
        ? {
            run_search:
              promptId === null ||
              promptId === undefined ||
              queryOverride ||
              forceSearch
                ? "always"
                : "auto",
            real_time: true,
            filters: filters,
          }
        : null,
      query_override: queryOverride,
      prompt_override: systemPromptOverride
        ? {
            system_prompt: systemPromptOverride,
          }
        : null,
      llm_override:
        temperature || modelVersion
          ? {
              temperature,
              model_provider: modelProvider,
              model_version: modelVersion,
            }
          : null,
      use_existing_user_message: useExistingUserMessage,
    }),
    body,
    signal,
  });
  if (!sendMessageResponse.ok) {
    const errorJson = await sendMessageResponse.json();
    const errorMsg = errorJson.message || errorJson.detail || "";
    throw Error(`Failed to send message - ${errorMsg}`);

  if (!response.ok) {
    throw new Error(`HTTP error! status: ${response.status}`);
  }

  yield* handleStream<PacketType>(sendMessageResponse);
  yield* handleSSEStream<PacketType>(response);
}

export async function nameChatSession(chatSessionId: number, message: string) {
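`sendMessage` is now a typed async generator: it accepts an `AbortSignal`, posts the request, and yields one parsed packet at a time via `handleSSEStream`. A hedged usage sketch — the wrapper and its name are hypothetical; the parameter object is taken wholesale from the signature above:

```typescript
import { sendMessage } from "@/app/chat/lib";

// Hypothetical wrapper: `params` stands in for the full argument object that
// sendMessage expects (all required fields appear in the signature above).
async function streamWithStop(
  params: Parameters<typeof sendMessage>[0],
  controller: AbortController
): Promise<void> {
  try {
    for await (const packet of sendMessage({ ...params, signal: controller.signal })) {
      if (Object.hasOwn(packet, "reserved_assistant_message_id")) {
        // first packet: MessageResponseIDInfo carrying the reserved message IDs
      } else if (Object.hasOwn(packet, "answer_piece")) {
        // an answer token; append it to the rendered message
      }
    }
  } catch (e) {
    // controller.abort() surfaces here as an AbortError once the fetch is cancelled
  }
}
```

Calling `controller.abort()` from a stop button ends the loop; the server notices the closed connection through `is_disconnected` and stops streaming tokens.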
@@ -635,14 +637,14 @@ export async function uploadFilesForChat(
}

export async function useScrollonStream({
  isStreaming,
  chatState,
  scrollableDivRef,
  scrollDist,
  endDivRef,
  distance,
  debounce,
}: {
  isStreaming: boolean;
  chatState: ChatState;
  scrollableDivRef: RefObject<HTMLDivElement>;
  scrollDist: MutableRefObject<number>;
  endDivRef: RefObject<HTMLDivElement>;

@@ -656,7 +658,7 @@ export async function useScrollonStream({
  const previousScroll = useRef<number>(0);

  useEffect(() => {
    if (isStreaming && scrollableDivRef && scrollableDivRef.current) {
    if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) {
      let newHeight: number = scrollableDivRef.current?.scrollTop!;
      const heightDifference = newHeight - previousScroll.current;
      previousScroll.current = newHeight;

@@ -712,7 +714,7 @@ export async function useScrollonStream({

  // scroll on end of stream if within distance
  useEffect(() => {
    if (scrollableDivRef?.current && !isStreaming) {
    if (scrollableDivRef?.current && chatState == "input") {
      if (scrollDist.current < distance - 50) {
        scrollableDivRef?.current?.scrollBy({
          left: 0,

@@ -721,5 +723,5 @@ export async function useScrollonStream({
        });
      }
    }
  }, [isStreaming]);
  }, [chatState]);
}
@@ -255,7 +255,6 @@ export const AIMessage = ({
            size="small"
            assistant={alternativeAssistant || currentPersona}
          />

          <div className="w-full">
            <div className="max-w-message-max break-words">
              {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&

@@ -623,6 +622,7 @@ export const HumanMessage = ({
  onEdit,
  onMessageSelection,
  shared,
  stopGenerating = () => null,
}: {
  shared?: boolean;
  content: string;

@@ -631,6 +631,7 @@ export const HumanMessage = ({
  otherMessagesCanSwitchTo?: number[];
  onEdit?: (editedContent: string) => void;
  onMessageSelection?: (messageId: number) => void;
  stopGenerating?: () => void;
}) => {
  const textareaRef = useRef<HTMLTextAreaElement>(null);

@@ -677,7 +678,6 @@ export const HumanMessage = ({
    <div className="xl:ml-8">
      <div className="flex flex-col mr-4">
        <FileDisplay alignBubble files={files || []} />

        <div className="flex justify-end">
          <div className="w-full ml-8 flex w-full max-w-message-max break-words">
            {isEditing ? (

@@ -857,16 +857,18 @@ export const HumanMessage = ({
                  <MessageSwitcher
                    currentPage={currentMessageInd + 1}
                    totalPages={otherMessagesCanSwitchTo.length}
                    handlePrevious={() =>
                    handlePrevious={() => {
                      stopGenerating();
                      onMessageSelection(
                        otherMessagesCanSwitchTo[currentMessageInd - 1]
                      )
                    }
                    handleNext={() =>
                      );
                    }}
                    handleNext={() => {
                      stopGenerating();
                      onMessageSelection(
                        otherMessagesCanSwitchTo[currentMessageInd + 1]
                      )
                    }
                      );
                    }}
                  />
                </div>
              )}
@@ -33,6 +33,7 @@ export function ChatSessionDisplay({
  isSelected,
  skipGradient,
  closeSidebar,
  stopGenerating = () => null,
  showShareModal,
  showDeleteModal,
}: {

@@ -43,6 +44,7 @@ export function ChatSessionDisplay({
  // if not set, the gradient will still be applied and cause weirdness
  skipGradient?: boolean;
  closeSidebar?: () => void;
  stopGenerating?: () => void;
  showShareModal?: (chatSession: ChatSession) => void;
  showDeleteModal?: (chatSession: ChatSession) => void;
}) {

@@ -99,6 +101,7 @@ export function ChatSessionDisplay({
        className="flex my-1 group relative"
        key={chatSession.id}
        onClick={() => {
          stopGenerating();
          if (settings?.isMobile && closeSidebar) {
            closeSidebar();
          }
@@ -40,6 +40,7 @@ interface HistorySidebarProps {
  reset?: () => void;
  showShareModal?: (chatSession: ChatSession) => void;
  showDeleteModal?: (chatSession: ChatSession) => void;
  stopGenerating?: () => void;
}

export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(

@@ -54,6 +55,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
      openedFolders,
      toggleSidebar,
      removeToggle,
      stopGenerating = () => null,
      showShareModal,
      showDeleteModal,
    },

@@ -179,6 +181,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
          )}
          <div className="border-b border-border pb-4 mx-3" />
          <PagesTab
            stopGenerating={stopGenerating}
            newFolderId={newFolderId}
            showDeleteModal={showDeleteModal}
            showShareModal={showShareModal}
@@ -17,10 +17,12 @@ export function PagesTab({
  folders,
  openedFolders,
  closeSidebar,
  stopGenerating,
  newFolderId,
  showShareModal,
  showDeleteModal,
}: {
  stopGenerating: () => void;
  page: pageType;
  existingChats?: ChatSession[];
  currentChatId?: number;

@@ -124,6 +126,7 @@ export function PagesTab({
              return (
                <div key={`${chat.id}-${chat.name}`}>
                  <ChatSessionDisplay
                    stopGenerating={stopGenerating}
                    showDeleteModal={showDeleteModal}
                    showShareModal={showShareModal}
                    closeSidebar={closeSidebar}
|
@ -1 +1,2 @@
|
||||
export type FeedbackType = "like" | "dislike";
|
||||
export type ChatState = "input" | "loading" | "streaming" | "toolBuilding";
|
||||
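The four `ChatState` values replace the old `isStreaming` boolean and drive the input bar, the stop button, and the scrolling behavior. The sketch below summarizes the transitions as they appear in the `ChatPage` changes above; nothing enforces these edges at runtime, so treat it as documentation rather than part of the PR:

```typescript
import { ChatState } from "./types"; // path as in ChatPage's own import

// Transition summary inferred from how setChatState is used in ChatPage.tsx.
const chatStateTransitions: Record<ChatState, ChatState[]> = {
  input: ["loading"],                   // onSubmit kicks off the request
  loading: ["streaming", "input"],      // first packets arrive, or stop/error
  streaming: ["toolBuilding", "input"], // a tool call starts, or the stream ends/aborts
  toolBuilding: ["streaming", "input"], // tool result comes back, or stop
};
```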
|
@ -1763,6 +1763,29 @@ export const FilledLikeIcon = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const StopGeneratingIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="200"
|
||||
height="200"
|
||||
viewBox="0 0 14 14"
|
||||
>
|
||||
<path
|
||||
fill="currentColor"
|
||||
fill-rule="evenodd"
|
||||
d="M1.5 0A1.5 1.5 0 0 0 0 1.5v11A1.5 1.5 0 0 0 1.5 14h11a1.5 1.5 0 0 0 1.5-1.5v-11A1.5 1.5 0 0 0 12.5 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const LikeFeedbackIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
@@ -1,3 +1,5 @@
import { PacketType } from "@/app/chat/lib";

type NonEmptyObject = { [k: string]: any };

const processSingleChunk = <T extends NonEmptyObject>(

@@ -75,3 +77,33 @@ export async function* handleStream<T extends NonEmptyObject>(
    yield await Promise.resolve(completedChunks);
  }
}

export async function* handleSSEStream<T extends PacketType>(
  streamingResponse: Response
): AsyncGenerator<T, void, unknown> {
  const reader = streamingResponse.body?.getReader();
  const decoder = new TextDecoder();

  while (true) {
    const rawChunk = await reader?.read();
    if (!rawChunk) {
      throw new Error("Unable to process chunk");
    }
    const { done, value } = rawChunk;
    if (done) {
      break;
    }

    const chunk = decoder.decode(value);
    const lines = chunk.split("\n").filter((line) => line.trim() !== "");

    for (const line of lines) {
      try {
        const data = JSON.parse(line) as T;
        yield data;
      } catch (error) {
        console.error("Error parsing SSE data:", error);
      }
    }
  }
}
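`handleSSEStream` complements the existing array-based `handleStream`: it decodes each chunk, splits it on newlines, and yields one parsed packet at a time (it assumes each JSON object arrives on its own line). A brief usage sketch, assuming a `Response` from the `/api/chat/send-message` fetch shown earlier:

```typescript
// Illustrative consumer of the new utility; `consume` is a hypothetical helper.
import { handleSSEStream } from "@/lib/search/streamingUtils";
import { PacketType } from "@/app/chat/lib";

async function consume(response: Response): Promise<string> {
  let answer = "";
  for await (const packet of handleSSEStream<PacketType>(response)) {
    if (Object.hasOwn(packet, "answer_piece")) {
      // accumulate streamed answer tokens
      answer += (packet as { answer_piece: string }).answer_piece;
    }
  }
  return answer;
}
```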
@@ -90,6 +90,7 @@ module.exports = {
        "background-200": "#e5e5e5", // neutral-200
        "background-300": "#d4d4d4", // neutral-300
        "background-400": "#a3a3a3", // neutral-400
        "background-600": "#525252", // neutral-800
        "background-500": "#737373", // neutral-400
        "background-600": "#525252", // neutral-400
        "background-700": "#404040", // neutral-400