persisting refined answer improvement

This commit is contained in:
Evan Lohn
2025-01-24 14:29:05 -08:00
parent 2bbe20edc3
commit 5e9b2e41ae
4 changed files with 43 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
"""refined answer improvement
Revision ID: 211b14ab5a91
Revises: 925b58bd75b6
Create Date: 2025-01-24 14:05:03.334309
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "211b14ab5a91"
down_revision = "925b58bd75b6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
server_default=sa.true(),
nullable=False,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")

View File

@@ -30,6 +30,7 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -829,6 +830,7 @@ def stream_chat_message_objects(
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
level, level_question_nr = (
@@ -954,6 +956,9 @@ def stream_chat_message_objects(
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
yield packet
elif isinstance(packet, RefinedAnswerImprovement):
refined_answer_improvement = packet.refined_answer_improvement
yield packet
else:
if isinstance(packet, ToolCallFinalResult):
level, level_question_nr = (
@@ -1067,6 +1072,7 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
refined_answer_improvement=refined_answer_improvement,
)
next_level += 1
prev_message = next_answer_message

View File

@@ -616,6 +616,7 @@ def create_new_chat_message(
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
refined_answer_improvement: bool = True,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message
@@ -636,6 +637,7 @@ def create_new_chat_message(
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
existing_message.refined_answer_improvement = refined_answer_improvement
new_chat_message = existing_message
else:
@@ -655,6 +657,7 @@ def create_new_chat_message(
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement,
)
db_session.add(new_chat_message)

View File

@@ -1204,6 +1204,8 @@ class ChatMessage(Base):
DateTime(timezone=True), server_default=func.now()
)
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, default=True)
chat_session: Mapped[ChatSession] = relationship("ChatSession")
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")