new is_agentic flag for chatmessages (#4026)

* new is_agentic flag for chatmessages

* added cancelled error to db

* added cancelled error to returned message
This commit is contained in:
evan-danswer
2025-02-17 20:20:33 -08:00
committed by GitHub
parent 045a41d929
commit 2b2ba5478c
5 changed files with 54 additions and 2 deletions

View File

@@ -0,0 +1,43 @@
"""chat_message_agentic
Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None
def upgrade() -> None:
# First add the column as nullable
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
# Update existing rows based on presence of SubQuestions
op.execute(
"""
UPDATE chat_message
SET is_agentic = EXISTS (
SELECT 1
FROM agent__sub_question
WHERE agent__sub_question.primary_question_id = chat_message.id
)
WHERE is_agentic IS NULL
"""
)
# Make the column non-nullable with a default value of False
op.alter_column(
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
)
def downgrade() -> None:
op.drop_column("chat_message", "is_agentic")

View File

@@ -146,6 +146,7 @@ from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger() logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
def _translate_citations( def _translate_citations(
@@ -631,6 +632,7 @@ def stream_chat_message_objects(
db_session=db_session, db_session=db_session,
commit=False, commit=False,
reserved_message_id=reserved_message_id, reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
) )
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@@ -1015,7 +1017,7 @@ def stream_chat_message_objects(
if info.message_specific_citations if info.message_specific_citations
else None else None
), ),
error=None, error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
tool_call=( tool_call=(
ToolCall( ToolCall(
tool_id=tool_name_to_tool_id[info.tool_result.tool_name], tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@@ -1053,7 +1055,9 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map citations=info.message_specific_citations.citation_map
if info.message_specific_citations if info.message_specific_citations
else None, else None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
refined_answer_improvement=refined_answer_improvement, refined_answer_improvement=refined_answer_improvement,
is_agentic=True,
) )
next_level += 1 next_level += 1
prev_message = next_answer_message prev_message = next_answer_message

View File

@@ -629,6 +629,7 @@ def create_new_chat_message(
reserved_message_id: int | None = None, reserved_message_id: int | None = None,
overridden_model: str | None = None, overridden_model: str | None = None,
refined_answer_improvement: bool | None = None, refined_answer_improvement: bool | None = None,
is_agentic: bool = False,
) -> ChatMessage: ) -> ChatMessage:
if reserved_message_id is not None: if reserved_message_id is not None:
# Edit existing message # Edit existing message
@@ -650,7 +651,7 @@ def create_new_chat_message(
existing_message.alternate_assistant_id = alternate_assistant_id existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model existing_message.overridden_model = overridden_model
existing_message.refined_answer_improvement = refined_answer_improvement existing_message.refined_answer_improvement = refined_answer_improvement
existing_message.is_agentic = is_agentic
new_chat_message = existing_message new_chat_message = existing_message
else: else:
# Create new message # Create new message
@@ -670,6 +671,7 @@ def create_new_chat_message(
alternate_assistant_id=alternate_assistant_id, alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model, overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement, refined_answer_improvement=refined_answer_improvement,
is_agentic=is_agentic,
) )
db_session.add(new_chat_message) db_session.add(new_chat_message)
@@ -960,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions chat_message.sub_questions
), ),
refined_answer_improvement=chat_message.refined_answer_improvement, refined_answer_improvement=chat_message.refined_answer_improvement,
error=chat_message.error,
) )
return chat_msg_detail return chat_msg_detail

View File

@@ -1221,6 +1221,7 @@ class ChatMessage(Base):
DateTime(timezone=True), server_default=func.now() DateTime(timezone=True), server_default=func.now()
) )
is_agentic: Mapped[bool] = mapped_column(Boolean, default=False)
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, nullable=True) refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, nullable=True)
chat_session: Mapped[ChatSession] = relationship("ChatSession") chat_session: Mapped[ChatSession] = relationship("ChatSession")

View File

@@ -240,6 +240,7 @@ class ChatMessageDetail(BaseModel):
files: list[FileDescriptor] files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None = None refined_answer_improvement: bool | None = None
error: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore