diff --git a/backend/alembic/versions/211b14ab5a91_refined_answer_improvement.py b/backend/alembic/versions/211b14ab5a91_refined_answer_improvement.py new file mode 100644 index 00000000000..03aaddacc02 --- /dev/null +++ b/backend/alembic/versions/211b14ab5a91_refined_answer_improvement.py @@ -0,0 +1,32 @@ +"""refined answer improvement + +Revision ID: 211b14ab5a91 +Revises: 925b58bd75b6 +Create Date: 2025-01-24 14:05:03.334309 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "211b14ab5a91" +down_revision = "925b58bd75b6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_message", + sa.Column( + "refined_answer_improvement", + sa.Boolean(), + server_default=sa.true(), + nullable=False, + ), + ) + + +def downgrade() -> None: + op.drop_column("chat_message", "refined_answer_improvement") diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 2d64236978a..7f55275fcb5 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -30,6 +30,7 @@ from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig from onyx.chat.models import QADocsResponse +from onyx.chat.models import RefinedAnswerImprovement from onyx.chat.models import StreamingError from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopReason @@ -829,6 +830,7 @@ def stream_chat_message_objects( info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict( lambda: AnswerPostInfo(ai_message_files=[]) ) + refined_answer_improvement = True for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): level, level_question_nr = ( @@ -954,6 +956,9 @@ def stream_chat_message_objects( elif isinstance(packet, StreamStopInfo): if packet.stop_reason == StreamStopReason.FINISHED: yield packet + elif isinstance(packet, RefinedAnswerImprovement): + refined_answer_improvement = packet.refined_answer_improvement + yield packet else: if isinstance(packet, ToolCallFinalResult): level, level_question_nr = ( @@ -1067,6 +1072,7 @@ def stream_chat_message_objects( citations=info.message_specific_citations.citation_map if info.message_specific_citations else None, + refined_answer_improvement=refined_answer_improvement, ) next_level += 1 prev_message = next_answer_message diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index abb00171641..ebd7de5303e 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -616,6 +616,7 @@ def create_new_chat_message( commit: bool = True, reserved_message_id: int | None = None, overridden_model: str | None = None, + refined_answer_improvement: bool = True, ) -> ChatMessage: if reserved_message_id is not None: # Edit existing message @@ -636,6 +637,7 @@ def create_new_chat_message( existing_message.error = error existing_message.alternate_assistant_id = alternate_assistant_id existing_message.overridden_model = overridden_model + existing_message.refined_answer_improvement = refined_answer_improvement new_chat_message = existing_message else: @@ -655,6 +657,7 @@ def create_new_chat_message( error=error, alternate_assistant_id=alternate_assistant_id, overridden_model=overridden_model, + refined_answer_improvement=refined_answer_improvement, ) db_session.add(new_chat_message) diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index f6a94bfc4ea..51694d06b61 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1204,6 +1204,8 @@ class ChatMessage(Base): DateTime(timezone=True), server_default=func.now() ) + refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, default=True) + chat_session: Mapped[ChatSession] = relationship("ChatSession") prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")