persisting refined answer improvement

This commit is contained in:
Evan Lohn
2025-01-24 14:29:05 -08:00
parent 2bbe20edc3
commit 5e9b2e41ae
4 changed files with 43 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
"""refined answer improvement
Revision ID: 211b14ab5a91
Revises: 925b58bd75b6
Create Date: 2025-01-24 14:05:03.334309
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "211b14ab5a91"
down_revision = "925b58bd75b6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
server_default=sa.true(),
nullable=False,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")

View File

@@ -30,6 +30,7 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason from onyx.chat.models import StreamStopReason
@@ -829,6 +830,7 @@ def stream_chat_message_objects(
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict( info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[]) lambda: AnswerPostInfo(ai_message_files=[])
) )
refined_answer_improvement = True
for packet in answer.processed_streamed_output: for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse): if isinstance(packet, ToolResponse):
level, level_question_nr = ( level, level_question_nr = (
@@ -954,6 +956,9 @@ def stream_chat_message_objects(
elif isinstance(packet, StreamStopInfo): elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED: if packet.stop_reason == StreamStopReason.FINISHED:
yield packet yield packet
elif isinstance(packet, RefinedAnswerImprovement):
refined_answer_improvement = packet.refined_answer_improvement
yield packet
else: else:
if isinstance(packet, ToolCallFinalResult): if isinstance(packet, ToolCallFinalResult):
level, level_question_nr = ( level, level_question_nr = (
@@ -1067,6 +1072,7 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map citations=info.message_specific_citations.citation_map
if info.message_specific_citations if info.message_specific_citations
else None, else None,
refined_answer_improvement=refined_answer_improvement,
) )
next_level += 1 next_level += 1
prev_message = next_answer_message prev_message = next_answer_message

View File

@@ -616,6 +616,7 @@ def create_new_chat_message(
commit: bool = True, commit: bool = True,
reserved_message_id: int | None = None, reserved_message_id: int | None = None,
overridden_model: str | None = None, overridden_model: str | None = None,
refined_answer_improvement: bool = True,
) -> ChatMessage: ) -> ChatMessage:
if reserved_message_id is not None: if reserved_message_id is not None:
# Edit existing message # Edit existing message
@@ -636,6 +637,7 @@ def create_new_chat_message(
existing_message.error = error existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model existing_message.overridden_model = overridden_model
existing_message.refined_answer_improvement = refined_answer_improvement
new_chat_message = existing_message new_chat_message = existing_message
else: else:
@@ -655,6 +657,7 @@ def create_new_chat_message(
error=error, error=error,
alternate_assistant_id=alternate_assistant_id, alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model, overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement,
) )
db_session.add(new_chat_message) db_session.add(new_chat_message)

View File

@@ -1204,6 +1204,8 @@ class ChatMessage(Base):
DateTime(timezone=True), server_default=func.now() DateTime(timezone=True), server_default=func.now()
) )
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, default=True)
chat_session: Mapped[ChatSession] = relationship("ChatSession") chat_session: Mapped[ChatSession] = relationship("ChatSession")
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt") prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")