mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-29 11:12:02 +01:00
add sequential tool calls
This commit is contained in:
parent
285bdbbaf9
commit
49397e8a86
@ -0,0 +1,65 @@
|
||||
"""single tool call per message
|
||||
|
||||
|
||||
Revision ID: 4e8e7ae58189
|
||||
Revises: f7e58d357687
|
||||
Create Date: 2024-09-09 10:07:58.008838
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4e8e7ae58189"
|
||||
down_revision = "f7e58d357687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the new column
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_tool_call",
|
||||
"chat_message",
|
||||
"tool_call",
|
||||
["tool_call_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Migrate existing data
|
||||
op.execute(
|
||||
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
|
||||
)
|
||||
|
||||
# Drop the old relationship
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "message_id")
|
||||
|
||||
# Add a unique constraint to ensure one-to-one relationship
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old column
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
|
||||
)
|
||||
|
||||
# Migrate data back
|
||||
op.execute(
|
||||
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
|
||||
)
|
||||
|
||||
# Drop the new column
|
||||
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "tool_call_id")
|
@ -48,6 +48,8 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
|
@ -18,6 +18,8 @@ from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@ -617,6 +619,11 @@ def stream_chat_message_objects(
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@ -738,10 +745,80 @@ def stream_chat_message_objects(
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
|
||||
db_citations = None
|
||||
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
),
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=gen_ai_response_message.id,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
elif isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@ -767,12 +844,6 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
@ -786,16 +857,14 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else [],
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
@ -178,8 +178,14 @@ def delete_search_doc_message_relationship(
|
||||
|
||||
|
||||
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
|
||||
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
|
||||
)
|
||||
if chat_message and chat_message.tool_call_id:
|
||||
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message.tool_call_id = None
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@ -388,7 +394,7 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@ -474,7 +480,7 @@ def create_new_chat_message(
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
tool_call: ToolCall | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
@ -482,6 +488,7 @@ def create_new_chat_message(
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
existing_message = db_session.query(ChatMessage).get(reserved_message_id)
|
||||
print(f"creating with reserved id {reserved_message_id}")
|
||||
if existing_message is None:
|
||||
raise ValueError(f"No message found with id {reserved_message_id}")
|
||||
|
||||
@ -494,7 +501,7 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_calls = tool_calls if tool_calls else []
|
||||
existing_message.tool_call = tool_call if tool_call else None
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
@ -513,7 +520,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
tool_call=tool_call if tool_call else None,
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
@ -747,14 +754,13 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
@ -854,10 +854,8 @@ class ToolCall(Base):
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
"ChatMessage", back_populates="tool_call"
|
||||
)
|
||||
|
||||
|
||||
@ -984,9 +982,14 @@ class ChatMessage(Base):
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
|
||||
tool_call_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("tool_call.id"), nullable=True
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
|
||||
)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
|
@ -16,6 +16,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
@ -68,6 +69,7 @@ from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MAX_TOOL_CALLS
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -161,6 +163,10 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.final_context_docs: list = []
|
||||
self.current_streamed_output: list = []
|
||||
self.processing_stream: list = []
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
@ -196,128 +202,166 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
for i in range(MAX_TOOL_CALLS):
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if not tool_call_chunk:
|
||||
return # no tool call needed
|
||||
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
|
||||
else:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get(
|
||||
"stop"
|
||||
)
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
return
|
||||
if not tool_call_chunk:
|
||||
logger.info("Skipped tool call but generated message")
|
||||
return
|
||||
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
tool_responses = list(tool_runner.tool_responses())
|
||||
yield from tool_responses
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
# Update message history with tool call and response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=str(tool_call_request),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=10, # You may want to implement a token counting method
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message="\n".join(str(response) for response in tool_responses),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=10,
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
response_content = ""
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=10,
|
||||
tool_call=None,
|
||||
files=[], # You may want to implement a token counting method
|
||||
)
|
||||
)
|
||||
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
@ -494,6 +538,7 @@ class Answer:
|
||||
and not self.skip_explicit_tool_calling
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
)
|
||||
self.processing_stream = []
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
@ -535,56 +580,69 @@ class Answer:
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
stream_stop_info = None
|
||||
new_kickoff = None
|
||||
|
||||
stream_stop_info = None
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
nonlocal new_kickoff
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
yield cast(str, item)
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
if isinstance(item, ToolCallKickoff):
|
||||
new_kickoff = item
|
||||
stream_stop_info = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.NEW_RESPONSE
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield cast(str, item)
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
# if new_kickoff: handle new tool call (continuation of message)
|
||||
if new_kickoff:
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self.processing_stream = []
|
||||
|
||||
yield new_kickoff
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
processed_stream.append(processed_packet)
|
||||
self.processing_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
self._processed_stream = self.processing_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if not self._processed_stream and not self.current_streamed_output:
|
||||
return ""
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
for packet in self.current_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
|
@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
@ -99,6 +99,29 @@ def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
|
||||
return error_msg
|
||||
|
||||
|
||||
# def translate_danswer_msg_to_langchain(
|
||||
# msg: Union[ChatMessage, "PreviousMessage"],
|
||||
# ) -> BaseMessage:
|
||||
# files: list[InMemoryChatFile] = []
|
||||
|
||||
# # If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# # attached. Just ignore them for now. Also, OpenAI doesn't allow files to
|
||||
# # be attached to AI messages, so we must remove them
|
||||
# if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
|
||||
# files = msg.files
|
||||
# content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
# if msg.message_type == MessageType.SYSTEM:
|
||||
# raise ValueError("System messages are not currently part of history")
|
||||
# if msg.message_type == MessageType.ASSISTANT:
|
||||
# return AIMessage(content=content)
|
||||
# if msg.message_type == MessageType.USER:
|
||||
# return HumanMessage(content=content)
|
||||
|
||||
# raise ValueError(f"New message type {msg.message_type} not handled")
|
||||
|
||||
|
||||
# TODO This is quite janky
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: Union[ChatMessage, "PreviousMessage"],
|
||||
) -> BaseMessage:
|
||||
@ -112,14 +135,36 @@ def translate_danswer_msg_to_langchain(
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
return SystemMessage(content=content)
|
||||
wrapped_content = ""
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
if (
|
||||
"name" in parsed_content
|
||||
and parsed_content["name"] == "run_image_generation"
|
||||
):
|
||||
wrapped_content += f"I, the AI, am now generating an \
|
||||
image based on the prompt: '{parsed_content['args']['prompt']}'\n"
|
||||
wrapped_content += "[/AI IMAGE GENERATION REQUEST]"
|
||||
elif (
|
||||
"id" in parsed_content
|
||||
and parsed_content["id"] == "image_generation_response"
|
||||
):
|
||||
wrapped_content += "I, the AI, have generated the following image(s) based on the previous request:\n"
|
||||
for img in parsed_content["response"]:
|
||||
wrapped_content += f"- Description: {img['revised_prompt']}\n"
|
||||
wrapped_content += f" Image URL: {img['url']}\n\n"
|
||||
wrapped_content += "[/AI IMAGE GENERATION RESPONSE]"
|
||||
else:
|
||||
wrapped_content = content
|
||||
except json.JSONDecodeError:
|
||||
wrapped_content = content
|
||||
return AIMessage(content=wrapped_content)
|
||||
|
||||
if msg.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
|
||||
raise ValueError(f"New message type {msg.message_type} not handled")
|
||||
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
|
@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
|
||||
chat_session_id: int | None = None
|
||||
citations: dict[int, int] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
|
@ -21,6 +21,8 @@ CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
|
||||
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
|
||||
INTENT_MODEL_TAG = "v1.0.3"
|
||||
|
||||
# TOoc all configs
|
||||
MAX_TOOL_CALLS = 2
|
||||
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
|
36
web/src/app/chat/AIMessageSequenceUtils.ts
Normal file
36
web/src/app/chat/AIMessageSequenceUtils.ts
Normal file
@ -0,0 +1,36 @@
|
||||
// For handling AI message `sequences` (ie. ai messages which are streamed in sequence as separate messags but are in reality one message)
|
||||
|
||||
import { Message } from "@/app/chat/interfaces";
|
||||
import { DanswerDocument } from "@/lib/search/interfaces";
|
||||
|
||||
export function getConsecutiveAIMessagesAtEnd(
|
||||
messageHistory: Message[]
|
||||
): Message[] {
|
||||
const aiMessages = [];
|
||||
for (let i = messageHistory.length - 1; i >= 0; i--) {
|
||||
if (messageHistory[i]?.type === "assistant") {
|
||||
aiMessages.unshift(messageHistory[i]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return aiMessages;
|
||||
}
|
||||
export function getUniqueDocumentsFromAIMessages(
|
||||
messages: Message[]
|
||||
): DanswerDocument[] {
|
||||
const uniqueDocumentsMap = new Map<string, DanswerDocument>();
|
||||
|
||||
messages.forEach((message) => {
|
||||
if (message.documents) {
|
||||
message.documents.forEach((doc) => {
|
||||
const uniqueKey = `${doc.document_id}-${doc.chunk_ind}`;
|
||||
if (!uniqueDocumentsMap.has(uniqueKey)) {
|
||||
uniqueDocumentsMap.set(uniqueKey, doc);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return Array.from(uniqueDocumentsMap.values());
|
||||
}
|
@ -13,6 +13,7 @@ import {
|
||||
ImageGenerationDisplay,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
PreviousAIMessage,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
ToolCallMetadata,
|
||||
@ -101,6 +102,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_NAME } from "./tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
|
||||
import { Button } from "@tremor/react";
|
||||
import dynamic from "next/dynamic";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@ -133,7 +136,6 @@ export function ChatPage({
|
||||
} = useChatContext();
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
|
||||
|
||||
const { user, refreshUser, isLoadingUser } = useUser();
|
||||
|
||||
// chat session
|
||||
@ -248,13 +250,13 @@ export function ChatPage({
|
||||
if (
|
||||
lastMessage &&
|
||||
lastMessage.type === "assistant" &&
|
||||
lastMessage.toolCalls[0] &&
|
||||
lastMessage.toolCalls[0].tool_result === undefined
|
||||
lastMessage.toolCall &&
|
||||
lastMessage.toolCall.tool_result === undefined
|
||||
) {
|
||||
const newCompleteMessageMap = new Map(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
const updatedMessage = { ...lastMessage, toolCalls: [] };
|
||||
const updatedMessage = { ...lastMessage, toolCall: null };
|
||||
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
|
||||
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
|
||||
}
|
||||
@ -483,7 +485,7 @@ export function ChatPage({
|
||||
message: "",
|
||||
type: "system",
|
||||
files: [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: null,
|
||||
childrenMessageIds: [firstMessageId],
|
||||
latestChildMessageId: firstMessageId,
|
||||
@ -510,6 +512,7 @@ export function ChatPage({
|
||||
}
|
||||
newCompleteMessageMap.set(message.messageId, message);
|
||||
});
|
||||
|
||||
// if specified, make these new message the latest of the current message chain
|
||||
if (makeLatestChildMessage) {
|
||||
const currentMessageChain = buildLatestMessageChain(
|
||||
@ -1044,8 +1047,6 @@ export function ChatPage({
|
||||
resetInputBar();
|
||||
let messageUpdates: Message[] | null = null;
|
||||
|
||||
let answer = "";
|
||||
|
||||
let stopReason: StreamStopReason | null = null;
|
||||
let query: string | null = null;
|
||||
let retrievalType: RetrievalType =
|
||||
@ -1058,12 +1059,14 @@ export function ChatPage({
|
||||
let stackTrace: string | null = null;
|
||||
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCalls: ToolCallMetadata[] = [];
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
assistant_message_id: number;
|
||||
frozenMessageMap: Map<number, Message>;
|
||||
initialDynamicParentMessage: Message;
|
||||
initialDynamicAssistantMessage: Message;
|
||||
} = null;
|
||||
|
||||
try {
|
||||
@ -1122,7 +1125,16 @@ export function ChatPage({
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
};
|
||||
|
||||
let updateFn = (messages: Message[]) => {
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
chatSessionId: currChatSessionId,
|
||||
});
|
||||
};
|
||||
|
||||
await delay(50);
|
||||
let dynamicParentMessage: Message | null = null;
|
||||
let dynamicAssistantMessage: Message | null = null;
|
||||
while (!stack.isComplete || !stack.isEmpty()) {
|
||||
await delay(0.5);
|
||||
|
||||
@ -1161,7 +1173,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
];
|
||||
@ -1176,22 +1188,122 @@ export function ChatPage({
|
||||
});
|
||||
}
|
||||
|
||||
const { messageMap: currentFrozenMessageMap } =
|
||||
let { messageMap: currentFrozenMessageMap } =
|
||||
upsertToCompleteMessageMap({
|
||||
messages: messageUpdates,
|
||||
chatSessionId: currChatSessionId,
|
||||
});
|
||||
|
||||
const frozenMessageMap = currentFrozenMessageMap;
|
||||
let frozenMessageMap = currentFrozenMessageMap;
|
||||
|
||||
let initialDynamicParentMessage: Message = {
|
||||
messageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: user_message_id!,
|
||||
message: "",
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCall: null,
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [
|
||||
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
|
||||
[]),
|
||||
-100,
|
||||
],
|
||||
latestChildMessageId: -100,
|
||||
};
|
||||
|
||||
let initialDynamicAssistantMessage: Message = {
|
||||
messageId: assistant_message_id!,
|
||||
message: "",
|
||||
type: "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: user_message_id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
};
|
||||
|
||||
initialFetchDetails = {
|
||||
frozenMessageMap,
|
||||
assistant_message_id,
|
||||
user_message_id,
|
||||
initialDynamicParentMessage,
|
||||
initialDynamicAssistantMessage,
|
||||
};
|
||||
|
||||
resetRegenerationState();
|
||||
} else {
|
||||
const { user_message_id, frozenMessageMap } = initialFetchDetails;
|
||||
let {
|
||||
initialDynamicParentMessage,
|
||||
initialDynamicAssistantMessage,
|
||||
user_message_id,
|
||||
frozenMessageMap,
|
||||
} = initialFetchDetails;
|
||||
|
||||
if (
|
||||
dynamicParentMessage === null &&
|
||||
dynamicAssistantMessage === null
|
||||
) {
|
||||
console.log("INITIALizing");
|
||||
dynamicParentMessage = initialDynamicParentMessage;
|
||||
dynamicAssistantMessage = initialDynamicAssistantMessage;
|
||||
dynamicParentMessage.childrenMessageIds = [
|
||||
initialFetchDetails.assistant_message_id,
|
||||
];
|
||||
|
||||
dynamicParentMessage.latestChildMessageId =
|
||||
initialFetchDetails.assistant_message_id;
|
||||
|
||||
dynamicParentMessage.messageId =
|
||||
initialFetchDetails.user_message_id;
|
||||
dynamicParentMessage.message = currMessage;
|
||||
}
|
||||
|
||||
if (!dynamicAssistantMessage || !dynamicParentMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "user_message_id")) {
|
||||
let newParentMessageId = dynamicParentMessage.messageId;
|
||||
const messageResponseIDInfo = packet as MessageResponseIDInfo;
|
||||
|
||||
for (const key in dynamicAssistantMessage) {
|
||||
(dynamicParentMessage as Record<string, any>)[key] = (
|
||||
dynamicAssistantMessage as Record<string, any>
|
||||
)[key];
|
||||
}
|
||||
|
||||
dynamicParentMessage.parentMessageId = newParentMessageId;
|
||||
dynamicParentMessage.latestChildMessageId =
|
||||
messageResponseIDInfo.reserved_assistant_message_id;
|
||||
dynamicParentMessage.childrenMessageIds = [
|
||||
messageResponseIDInfo.reserved_assistant_message_id,
|
||||
];
|
||||
|
||||
dynamicParentMessage.messageId =
|
||||
messageResponseIDInfo.user_message_id!;
|
||||
dynamicAssistantMessage = {
|
||||
messageId: messageResponseIDInfo.reserved_assistant_message_id,
|
||||
type: "assistant",
|
||||
message: "",
|
||||
documents: [],
|
||||
retrievalType: undefined,
|
||||
toolCall: null,
|
||||
files: [],
|
||||
parentMessageId: dynamicParentMessage.messageId,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: null,
|
||||
};
|
||||
}
|
||||
|
||||
setChatState((prevState) => {
|
||||
if (prevState.get(chatSessionIdRef.current!) === "loading") {
|
||||
@ -1204,37 +1316,31 @@ export function ChatPage({
|
||||
});
|
||||
|
||||
if (Object.hasOwn(packet, "answer_piece")) {
|
||||
answer += (packet as AnswerPiecePacket).answer_piece;
|
||||
dynamicAssistantMessage.message += (
|
||||
packet as AnswerPiecePacket
|
||||
).answer_piece;
|
||||
} else if (Object.hasOwn(packet, "top_documents")) {
|
||||
documents = (packet as DocumentsResponse).top_documents;
|
||||
dynamicAssistantMessage.documents = (
|
||||
packet as DocumentsResponse
|
||||
).top_documents;
|
||||
dynamicAssistantMessage.retrievalType = RetrievalType.Search;
|
||||
retrievalType = RetrievalType.Search;
|
||||
if (documents && documents.length > 0) {
|
||||
// point to the latest message (we don't know the messageId yet, which is why
|
||||
// we have to use -1)
|
||||
setSelectedMessageForDocDisplay(user_message_id);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "tool_name")) {
|
||||
toolCalls = [
|
||||
{
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
},
|
||||
];
|
||||
dynamicAssistantMessage.toolCall = {
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
};
|
||||
|
||||
if (
|
||||
!toolCalls[0].tool_result ||
|
||||
toolCalls[0].tool_result == undefined
|
||||
!dynamicAssistantMessage.toolCall ||
|
||||
!dynamicAssistantMessage.toolCall.tool_result ||
|
||||
dynamicAssistantMessage.toolCall.tool_result == undefined
|
||||
) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
} else {
|
||||
updateChatState("streaming", frozenSessionId);
|
||||
}
|
||||
|
||||
// This will be consolidated in upcoming tool calls udpate,
|
||||
// but for now, we need to set query as early as possible
|
||||
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
|
||||
query = toolCalls[0].tool_args["query"];
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "file_ids")) {
|
||||
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
|
||||
(fileId) => {
|
||||
@ -1244,82 +1350,62 @@ export function ChatPage({
|
||||
};
|
||||
}
|
||||
);
|
||||
dynamicAssistantMessage.files = aiMessageImages;
|
||||
} else if (Object.hasOwn(packet, "error")) {
|
||||
error = (packet as StreamingError).error;
|
||||
stackTrace = (packet as StreamingError).stack_trace;
|
||||
dynamicAssistantMessage.stackTrace = (
|
||||
packet as StreamingError
|
||||
).stack_trace;
|
||||
} else if (Object.hasOwn(packet, "message_id")) {
|
||||
finalMessage = packet as BackendMessage;
|
||||
dynamicAssistantMessage = {
|
||||
...dynamicAssistantMessage,
|
||||
...finalMessage,
|
||||
};
|
||||
} else if (Object.hasOwn(packet, "stop_reason")) {
|
||||
const stop_reason = (packet as StreamStopInfo).stop_reason;
|
||||
|
||||
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
|
||||
updateCanContinue(true, frozenSessionId);
|
||||
}
|
||||
}
|
||||
if (!Object.hasOwn(packet, "stop_reason")) {
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
// parentMessage =
|
||||
// parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
|
||||
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
|
||||
// Should update message map with new message based on
|
||||
// Previous message, parent message, and new message being generated.
|
||||
// All should have proper IDs.
|
||||
updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
? new Map([
|
||||
[
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
],
|
||||
[
|
||||
dynamicParentMessage?.messageId,
|
||||
dynamicAssistantMessage?.messageId,
|
||||
],
|
||||
] as [number, number][])
|
||||
: null;
|
||||
|
||||
const updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
? new Map([
|
||||
[
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
],
|
||||
[
|
||||
regenerationRequest?.messageId,
|
||||
initialFetchDetails?.assistant_message_id,
|
||||
],
|
||||
] as [number, number][])
|
||||
: null;
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
replacementsMap: replacementsMap,
|
||||
// completeMessageMapOverride: frozenMessageMap,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
};
|
||||
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
replacementsMap: replacementsMap,
|
||||
completeMessageMapOverride: frozenMessageMap,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
};
|
||||
|
||||
updateFn([
|
||||
{
|
||||
messageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id!,
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [
|
||||
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
|
||||
[]),
|
||||
initialFetchDetails.assistant_message_id!,
|
||||
],
|
||||
latestChildMessageId: initialFetchDetails.assistant_message_id,
|
||||
},
|
||||
{
|
||||
messageId: initialFetchDetails.assistant_message_id!,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents:
|
||||
finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCalls: finalMessage?.tool_calls || toolCalls,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
},
|
||||
]);
|
||||
let { messageMap } = updateFn([
|
||||
dynamicParentMessage,
|
||||
dynamicAssistantMessage,
|
||||
]);
|
||||
frozenMessageMap = messageMap;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1333,7 +1419,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
{
|
||||
@ -1343,7 +1429,7 @@ export function ChatPage({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
files: aiMessageImages || [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId:
|
||||
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
|
||||
},
|
||||
@ -2055,6 +2141,21 @@ export function ChatPage({
|
||||
) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
const hasChildMessage =
|
||||
message.latestChildMessageId !== null &&
|
||||
message.latestChildMessageId !== undefined;
|
||||
const childMessage = hasChildMessage
|
||||
? messageMap.get(
|
||||
message.latestChildMessageId!
|
||||
)
|
||||
: null;
|
||||
|
||||
const hasParentAI =
|
||||
parentMessage?.type == "assistant";
|
||||
const hasChildAI =
|
||||
childMessage?.type == "assistant";
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`message-${message.messageId}`}
|
||||
@ -2066,6 +2167,8 @@ export function ChatPage({
|
||||
}
|
||||
>
|
||||
<AIMessage
|
||||
hasChildAI={hasChildAI}
|
||||
hasParentAI={hasParentAI}
|
||||
continueGenerating={
|
||||
i == messageHistory.length - 1 &&
|
||||
currentCanContinue()
|
||||
@ -2112,7 +2215,6 @@ export function ChatPage({
|
||||
}
|
||||
messageId={message.messageId}
|
||||
content={message.message}
|
||||
// content={message.message}
|
||||
files={message.files}
|
||||
query={
|
||||
messageHistory[i]?.query || undefined
|
||||
@ -2122,8 +2224,7 @@ export function ChatPage({
|
||||
message
|
||||
)}
|
||||
toolCall={
|
||||
message.toolCalls &&
|
||||
message.toolCalls[0]
|
||||
message.toolCall && message.toolCall
|
||||
}
|
||||
isComplete={
|
||||
i !== messageHistory.length - 1 ||
|
||||
@ -2368,6 +2469,14 @@ export function ChatPage({
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
<Button
|
||||
onClick={() => {
|
||||
console.log(completeMessageDetail);
|
||||
console.log(messageHistory);
|
||||
}}
|
||||
>
|
||||
CLICK EM
|
||||
</Button>
|
||||
<ChatInputBar
|
||||
showConfigureAPIKey={() =>
|
||||
setShowApiKeyModal(true)
|
||||
|
@ -76,6 +76,24 @@ export interface SearchSession {
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface PreviousAIMessage {
|
||||
messageId?: number;
|
||||
message?: string;
|
||||
type?: "assistant";
|
||||
retrievalType?: RetrievalType;
|
||||
query?: string | null;
|
||||
documents?: DanswerDocument[] | null;
|
||||
citations?: CitationMap;
|
||||
files?: FileDescriptor[];
|
||||
toolCall?: ToolCallMetadata | null;
|
||||
|
||||
// for rebuilding the message tree
|
||||
parentMessageId?: number | null;
|
||||
childrenMessageIds?: number[];
|
||||
latestChildMessageId?: number | null;
|
||||
alternateAssistantID?: number | null;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
messageId: number;
|
||||
message: string;
|
||||
@ -85,7 +103,7 @@ export interface Message {
|
||||
documents?: DanswerDocument[] | null;
|
||||
citations?: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
toolCalls: ToolCallMetadata[];
|
||||
toolCall: ToolCallMetadata | null;
|
||||
// for rebuilding the message tree
|
||||
parentMessageId: number | null;
|
||||
childrenMessageIds?: number[];
|
||||
@ -120,7 +138,7 @@ export interface BackendMessage {
|
||||
time_sent: string;
|
||||
citations: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
tool_calls: ToolCallFinalResult[];
|
||||
tool_call: ToolCallFinalResult | null;
|
||||
alternate_assistant_id?: number | null;
|
||||
overridden_model?: string;
|
||||
}
|
||||
|
@ -435,7 +435,7 @@ export function processRawChatHistory(
|
||||
citations: messageInfo?.citations || {},
|
||||
}
|
||||
: {}),
|
||||
toolCalls: messageInfo.tool_calls,
|
||||
toolCall: messageInfo.tool_call,
|
||||
parentMessageId: messageInfo.parent_message,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: messageInfo.latest_child_message,
|
||||
@ -479,6 +479,7 @@ export function buildLatestMessageChain(
|
||||
let currMessage: Message | null = rootMessage;
|
||||
while (currMessage) {
|
||||
finalMessageList.push(currMessage);
|
||||
|
||||
const childMessageNumber = currMessage.latestChildMessageId;
|
||||
if (childMessageNumber && messageMap.has(childMessageNumber)) {
|
||||
currMessage = messageMap.get(childMessageNumber) as Message;
|
||||
|
@ -46,12 +46,7 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
|
||||
import {
|
||||
ThumbsUpIcon,
|
||||
ThumbsDownIcon,
|
||||
LikeFeedback,
|
||||
DislikeFeedback,
|
||||
} from "@/components/icons/icons";
|
||||
import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons";
|
||||
import {
|
||||
CustomTooltip,
|
||||
TooltipGroup,
|
||||
@ -121,6 +116,8 @@ function FileDisplay({
|
||||
}
|
||||
|
||||
export const AIMessage = ({
|
||||
hasChildAI,
|
||||
hasParentAI,
|
||||
regenerate,
|
||||
overriddenModel,
|
||||
continueGenerating,
|
||||
@ -149,6 +146,8 @@ export const AIMessage = ({
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
}: {
|
||||
hasChildAI?: boolean;
|
||||
hasParentAI?: boolean;
|
||||
shared?: boolean;
|
||||
isActive?: boolean;
|
||||
continueGenerating?: () => void;
|
||||
@ -165,7 +164,7 @@ export const AIMessage = ({
|
||||
query?: string;
|
||||
personaName?: string;
|
||||
citedDocuments?: [string, DanswerDocument][] | null;
|
||||
toolCall?: ToolCallMetadata;
|
||||
toolCall?: ToolCallMetadata | null;
|
||||
isComplete?: boolean;
|
||||
hasDocs?: boolean;
|
||||
handleFeedback?: (feedbackType: FeedbackType) => void;
|
||||
@ -274,18 +273,21 @@ export const AIMessage = ({
|
||||
<div
|
||||
id="danswer-ai-message"
|
||||
ref={trackedElementRef}
|
||||
className={"py-5 ml-4 px-5 relative flex "}
|
||||
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
|
||||
>
|
||||
<div
|
||||
className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
|
||||
>
|
||||
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
|
||||
<div className="flex">
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
|
||||
{!hasParentAI ? (
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
) : (
|
||||
<div className="w-6" />
|
||||
)}
|
||||
<div className="w-full">
|
||||
<div className="max-w-message-max break-words">
|
||||
<div className="w-full ml-4">
|
||||
@ -503,7 +505,8 @@ export const AIMessage = ({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{handleFeedback &&
|
||||
{!hasChildAI &&
|
||||
handleFeedback &&
|
||||
(isActive ? (
|
||||
<div
|
||||
className={`
|
||||
@ -774,7 +777,7 @@ export const HumanMessage = ({
|
||||
outline-none
|
||||
placeholder-gray-400
|
||||
resize-none
|
||||
pl-4
|
||||
pl-4crea
|
||||
overflow-y-auto
|
||||
pr-12
|
||||
py-4`}
|
||||
|
83
web/src/app/chat/tools/ImagePromptCitaiton.tsx
Normal file
83
web/src/app/chat/tools/ImagePromptCitaiton.tsx
Normal file
@ -0,0 +1,83 @@
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { CopyIcon } from "@/components/icons/icons";
|
||||
import { Divider } from "@tremor/react";
|
||||
import React, { forwardRef, useState } from "react";
|
||||
import { FiCheck } from "react-icons/fi";
|
||||
|
||||
interface PromptDisplayProps {
|
||||
prompt1: string;
|
||||
prompt2?: string;
|
||||
arg: string;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
}
|
||||
|
||||
const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
|
||||
({ prompt1, prompt2, setPopup, arg }, ref) => {
|
||||
const [copied, setCopied] = useState<number | null>(null);
|
||||
|
||||
const copyToClipboard = (text: string, index: number) => {
|
||||
navigator.clipboard
|
||||
.writeText(text)
|
||||
.then(() => {
|
||||
setPopup({ message: "Copied to clipboard", type: "success" });
|
||||
setCopied(index);
|
||||
setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
|
||||
})
|
||||
.catch((err) => {
|
||||
setPopup({ message: "Failed to copy", type: "error" });
|
||||
});
|
||||
};
|
||||
|
||||
const PromptSection = ({
|
||||
copied,
|
||||
prompt,
|
||||
index,
|
||||
}: {
|
||||
copied: number | null;
|
||||
prompt: string;
|
||||
index: number;
|
||||
}) => (
|
||||
<div className="w-full p-2 rounded-lg">
|
||||
<h2 className="text-lg font-semibold mb-2">
|
||||
{arg} {index + 1}
|
||||
</h2>
|
||||
|
||||
<p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
|
||||
|
||||
<button
|
||||
onMouseDown={() => copyToClipboard(prompt, index)}
|
||||
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
|
||||
>
|
||||
{copied != null ? (
|
||||
<>
|
||||
<FiCheck className="mr-2" size={16} />
|
||||
Copied!
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<CopyIcon className="mr-2" size={16} />
|
||||
Copy
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg">
|
||||
<div className="flex flex-col gap-x-4">
|
||||
<PromptSection copied={copied} prompt={prompt1} index={0} />
|
||||
{prompt2 && (
|
||||
<>
|
||||
<Divider />
|
||||
<PromptSection copied={copied} prompt={prompt2} index={1} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
DualPromptDisplay.displayName = "DualPromptDisplay";
|
||||
export default DualPromptDisplay;
|
@ -145,6 +145,19 @@ export function ClientLayout({
|
||||
),
|
||||
link: "/admin/tools",
|
||||
},
|
||||
...(enableEnterprise
|
||||
? [
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<ClipboardIcon size={18} />
|
||||
<div className="ml-1">Standard Answers</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/standard-answer",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
|
@ -2811,3 +2811,40 @@ export const WindowsIcon = ({
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const ToolCallIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 19 15"
|
||||
fill="none"
|
||||
>
|
||||
<path
|
||||
d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
|
||||
fill="currentColor"
|
||||
stroke="currentColor"
|
||||
strokeWidth="0.2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
|
||||
fill="currentColor"
|
||||
stroke="currentColor"
|
||||
strokeWidth="0.1"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
@ -22,6 +22,8 @@ export interface AnswerPiecePacket {
|
||||
export enum StreamStopReason {
|
||||
CONTEXT_LENGTH = "CONTEXT_LENGTH",
|
||||
CANCELLED = "CANCELLED",
|
||||
FINISHED = "FINISHED",
|
||||
NEW_RESPONSE = "NEW_RESPONSE",
|
||||
}
|
||||
|
||||
export interface StreamStopInfo {
|
||||
|
Loading…
x
Reference in New Issue
Block a user