new is_agentic flag for chatmessages (#4026)

* new is_agentic flag for chatmessages

* added cancelled error to db

* added cancelled error to returned message
This commit is contained in:
evan-danswer 2025-02-17 20:20:33 -08:00 committed by GitHub
parent 045a41d929
commit 2b2ba5478c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 54 additions and 2 deletions

View File

@ -0,0 +1,43 @@
"""chat_message_agentic
Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None
def upgrade() -> None:
# First add the column as nullable
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
# Update existing rows based on presence of SubQuestions
op.execute(
"""
UPDATE chat_message
SET is_agentic = EXISTS (
SELECT 1
FROM agent__sub_question
WHERE agent__sub_question.primary_question_id = chat_message.id
)
WHERE is_agentic IS NULL
"""
)
# Make the column non-nullable with a default value of False
op.alter_column(
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
)
def downgrade() -> None:
op.drop_column("chat_message", "is_agentic")

View File

@ -146,6 +146,7 @@ from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
def _translate_citations(
@ -631,6 +632,7 @@ def stream_chat_message_objects(
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@ -1015,7 +1017,7 @@ def stream_chat_message_objects(
if info.message_specific_citations
else None
),
error=None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@ -1053,7 +1055,9 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
refined_answer_improvement=refined_answer_improvement,
is_agentic=True,
)
next_level += 1
prev_message = next_answer_message

View File

@ -629,6 +629,7 @@ def create_new_chat_message(
reserved_message_id: int | None = None,
overridden_model: str | None = None,
refined_answer_improvement: bool | None = None,
is_agentic: bool = False,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message
@ -650,7 +651,7 @@ def create_new_chat_message(
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
existing_message.refined_answer_improvement = refined_answer_improvement
existing_message.is_agentic = is_agentic
new_chat_message = existing_message
else:
# Create new message
@ -670,6 +671,7 @@ def create_new_chat_message(
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement,
is_agentic=is_agentic,
)
db_session.add(new_chat_message)
@ -960,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
error=chat_message.error,
)
return chat_msg_detail

View File

@ -1221,6 +1221,7 @@ class ChatMessage(Base):
DateTime(timezone=True), server_default=func.now()
)
is_agentic: Mapped[bool] = mapped_column(Boolean, default=False)
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, nullable=True)
chat_session: Mapped[ChatSession] = relationship("ChatSession")

View File

@ -240,6 +240,7 @@ class ChatMessageDetail(BaseModel):
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None = None
error: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore