mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
persisting refined answer improvement
This commit is contained in:
@@ -0,0 +1,32 @@
|
|||||||
|
"""refined answer improvement
|
||||||
|
|
||||||
|
Revision ID: 211b14ab5a91
|
||||||
|
Revises: 925b58bd75b6
|
||||||
|
Create Date: 2025-01-24 14:05:03.334309
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "211b14ab5a91"
|
||||||
|
down_revision = "925b58bd75b6"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"chat_message",
|
||||||
|
sa.Column(
|
||||||
|
"refined_answer_improvement",
|
||||||
|
sa.Boolean(),
|
||||||
|
server_default=sa.true(),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("chat_message", "refined_answer_improvement")
|
@@ -30,6 +30,7 @@ from onyx.chat.models import OnyxAnswerPiece
|
|||||||
from onyx.chat.models import OnyxContexts
|
from onyx.chat.models import OnyxContexts
|
||||||
from onyx.chat.models import PromptConfig
|
from onyx.chat.models import PromptConfig
|
||||||
from onyx.chat.models import QADocsResponse
|
from onyx.chat.models import QADocsResponse
|
||||||
|
from onyx.chat.models import RefinedAnswerImprovement
|
||||||
from onyx.chat.models import StreamingError
|
from onyx.chat.models import StreamingError
|
||||||
from onyx.chat.models import StreamStopInfo
|
from onyx.chat.models import StreamStopInfo
|
||||||
from onyx.chat.models import StreamStopReason
|
from onyx.chat.models import StreamStopReason
|
||||||
@@ -829,6 +830,7 @@ def stream_chat_message_objects(
|
|||||||
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
|
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
|
||||||
lambda: AnswerPostInfo(ai_message_files=[])
|
lambda: AnswerPostInfo(ai_message_files=[])
|
||||||
)
|
)
|
||||||
|
refined_answer_improvement = True
|
||||||
for packet in answer.processed_streamed_output:
|
for packet in answer.processed_streamed_output:
|
||||||
if isinstance(packet, ToolResponse):
|
if isinstance(packet, ToolResponse):
|
||||||
level, level_question_nr = (
|
level, level_question_nr = (
|
||||||
@@ -954,6 +956,9 @@ def stream_chat_message_objects(
|
|||||||
elif isinstance(packet, StreamStopInfo):
|
elif isinstance(packet, StreamStopInfo):
|
||||||
if packet.stop_reason == StreamStopReason.FINISHED:
|
if packet.stop_reason == StreamStopReason.FINISHED:
|
||||||
yield packet
|
yield packet
|
||||||
|
elif isinstance(packet, RefinedAnswerImprovement):
|
||||||
|
refined_answer_improvement = packet.refined_answer_improvement
|
||||||
|
yield packet
|
||||||
else:
|
else:
|
||||||
if isinstance(packet, ToolCallFinalResult):
|
if isinstance(packet, ToolCallFinalResult):
|
||||||
level, level_question_nr = (
|
level, level_question_nr = (
|
||||||
@@ -1067,6 +1072,7 @@ def stream_chat_message_objects(
|
|||||||
citations=info.message_specific_citations.citation_map
|
citations=info.message_specific_citations.citation_map
|
||||||
if info.message_specific_citations
|
if info.message_specific_citations
|
||||||
else None,
|
else None,
|
||||||
|
refined_answer_improvement=refined_answer_improvement,
|
||||||
)
|
)
|
||||||
next_level += 1
|
next_level += 1
|
||||||
prev_message = next_answer_message
|
prev_message = next_answer_message
|
||||||
|
@@ -616,6 +616,7 @@ def create_new_chat_message(
|
|||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
reserved_message_id: int | None = None,
|
reserved_message_id: int | None = None,
|
||||||
overridden_model: str | None = None,
|
overridden_model: str | None = None,
|
||||||
|
refined_answer_improvement: bool = True,
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
if reserved_message_id is not None:
|
if reserved_message_id is not None:
|
||||||
# Edit existing message
|
# Edit existing message
|
||||||
@@ -636,6 +637,7 @@ def create_new_chat_message(
|
|||||||
existing_message.error = error
|
existing_message.error = error
|
||||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||||
existing_message.overridden_model = overridden_model
|
existing_message.overridden_model = overridden_model
|
||||||
|
existing_message.refined_answer_improvement = refined_answer_improvement
|
||||||
|
|
||||||
new_chat_message = existing_message
|
new_chat_message = existing_message
|
||||||
else:
|
else:
|
||||||
@@ -655,6 +657,7 @@ def create_new_chat_message(
|
|||||||
error=error,
|
error=error,
|
||||||
alternate_assistant_id=alternate_assistant_id,
|
alternate_assistant_id=alternate_assistant_id,
|
||||||
overridden_model=overridden_model,
|
overridden_model=overridden_model,
|
||||||
|
refined_answer_improvement=refined_answer_improvement,
|
||||||
)
|
)
|
||||||
db_session.add(new_chat_message)
|
db_session.add(new_chat_message)
|
||||||
|
|
||||||
|
@@ -1204,6 +1204,8 @@ class ChatMessage(Base):
|
|||||||
DateTime(timezone=True), server_default=func.now()
|
DateTime(timezone=True), server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
|
||||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||||
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
|
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user