Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-30 09:40:35 +02:00)

refactor to use stricter typing (#4513)

* refactor to use stricter typing
* older version of ruff

This commit is contained in:
parent a5edc8aa0f
commit 68c6c1f4f8
@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
             (
                 or_(
                     ChatMessageFeedback.is_positive.is_(False),
-                    ChatMessageFeedback.required_followup,
+                    ChatMessageFeedback.required_followup.is_(True),
                 ),
                 1,
             ),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
         .all()
     )

-    return results
+    return [tuple(row) for row in results]


 def fetch_persona_message_analytics(
@@ -1,3 +1,5 @@
+from typing import cast
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -39,10 +41,10 @@ logger = setup_logger()

 router = APIRouter(prefix="/custom")

-_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
+_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
 _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None

-_INTENT_TOKENIZER: AutoTokenizer | None = None
+_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
 _INTENT_MODEL: HybridClassifier | None = None

 _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -50,13 +52,14 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
 _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!


-def get_connector_classifier_tokenizer() -> AutoTokenizer:
+def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
     global _CONNECTOR_CLASSIFIER_TOKENIZER
     if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
         # The tokenizer details are not uploaded to the HF hub since it's just the
         # unmodified distilbert tokenizer.
-        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
-            "distilbert-base-uncased"
+        _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
+            PreTrainedTokenizer,
+            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
         )
     return _CONNECTOR_CLASSIFIER_TOKENIZER

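The same pattern recurs throughout this commit: `AutoTokenizer.from_pretrained` is loosely typed, so its result is narrowed with `cast` to keep a precise module-level annotation. A minimal standalone sketch of the idea (an illustrative module, not part of the diff; `cast` performs no runtime check):

```python
from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

_TOKENIZER: PreTrainedTokenizer | None = None


def get_tokenizer() -> PreTrainedTokenizer:
    """Lazily load and cache a tokenizer with a concrete static type."""
    global _TOKENIZER
    if _TOKENIZER is None:
        # AutoTokenizer.from_pretrained is loosely typed; cast() narrows the
        # result for the type checker without any runtime effect.
        _TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _TOKENIZER
```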
@@ -92,12 +95,15 @@ def get_local_connector_classifier(
     return _CONNECTOR_CLASSIFIER_MODEL


-def get_intent_model_tokenizer() -> AutoTokenizer:
+def get_intent_model_tokenizer() -> PreTrainedTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
         # The tokenizer details are not uploaded to the HF hub since it's just the
         # unmodified distilbert tokenizer.
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+        _INTENT_TOKENIZER = cast(
+            PreTrainedTokenizer,
+            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
+        )
     return _INTENT_TOKENIZER


@@ -395,9 +401,9 @@ def run_content_classification_inference(


 def map_keywords(
-    input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
+    input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
 ) -> list[str]:
-    tokens = tokenizer.convert_ids_to_tokens(input_ids)
+    tokens = tokenizer.convert_ids_to_tokens(input_ids)  # type: ignore

     if not len(tokens) == len(is_keyword):
         raise ValueError("Length of tokens and keyword predictions must match")
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import cast

 import torch
 import torch.nn as nn
@@ -13,15 +14,14 @@ class HybridClassifier(nn.Module):
         super().__init__()
         config = DistilBertConfig()
         self.distilbert = DistilBertModel(config)
+        config = self.distilbert.config  # type: ignore

         # Keyword tokenwise binary classification layer
-        self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)
+        self.keyword_classifier = nn.Linear(config.dim, 2)

         # Intent Classifier layers
-        self.pre_classifier = nn.Linear(
-            self.distilbert.config.dim, self.distilbert.config.dim
-        )
-        self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
+        self.pre_classifier = nn.Linear(config.dim, config.dim)
+        self.intent_classifier = nn.Linear(config.dim, 2)

         self.device = torch.device("cpu")

@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
         query_ids: torch.Tensor,
         query_mask: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
-        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
+        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)  # type: ignore
         sequence_output = outputs.last_hidden_state

         # Intent classification on the CLS token
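The `# type: ignore` comments above sit on calls into `DistilBertModel`, whose `__call__` and config attributes are not precisely typed for mypy. A small runnable sketch of the call shape being annotated around (tiny, hypothetical config values chosen only so the example runs quickly on CPU):

```python
import torch
from transformers import DistilBertConfig, DistilBertModel

# Deliberately tiny config so construction and the forward pass are fast.
config = DistilBertConfig(dim=32, n_layers=1, n_heads=2, hidden_dim=64)
model = DistilBertModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

# nn.Module.__call__ is not precisely typed, hence the `# type: ignore`
# on the equivalent call sites in the diff above.
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0, :]
print(cls_embedding.shape)  # torch.Size([1, 32])
```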
@@ -79,8 +79,9 @@ class ConnectorClassifier(nn.Module):

         self.config = config
         self.distilbert = DistilBertModel(config)
-        self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
-        self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        config = self.distilbert.config  # type: ignore
+        self.connector_global_classifier = nn.Linear(config.dim, 1)
+        self.connector_match_classifier = nn.Linear(config.dim, 1)
         self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

         # Token indicating end of connector name, and on which classifier is used
@@ -95,7 +96,7 @@ class ConnectorClassifier(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        hidden_states = self.distilbert(
+        hidden_states = self.distilbert(  # type: ignore
             input_ids=input_ids, attention_mask=attention_mask
         ).last_hidden_state

@@ -114,7 +115,10 @@ class ConnectorClassifier(nn.Module):

     @classmethod
     def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
-        config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
+        config = cast(
+            DistilBertConfig,
+            DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
+        )
         device = (
             torch.device("cuda")
             if torch.cuda.is_available()
@@ -2,9 +2,11 @@ import time
 import traceback
 from collections import defaultdict
+from collections.abc import Callable
+from collections.abc import Generator
 from collections.abc import Iterator
-from functools import partial
 from typing import cast
+from typing import Protocol
 from uuid import UUID

 from sqlalchemy.orm import Session

@@ -82,6 +84,8 @@ from onyx.db.engine import get_session_context_manager
 from onyx.db.milestone import check_multi_assistant_milestone
 from onyx.db.milestone import create_milestone_if_not_exists
 from onyx.db.milestone import update_user_assistant_milestone
+from onyx.db.models import ChatMessage
+from onyx.db.models import Persona
 from onyx.db.models import SearchDoc as DbSearchDoc
 from onyx.db.models import ToolCall
 from onyx.db.models import User
@@ -159,6 +163,25 @@ from shared_configs.contextvars import get_current_tenant_id
 logger = setup_logger()
 ERROR_TYPE_CANCELLED = "cancelled"

+COMMON_TOOL_RESPONSE_TYPES = {
+    "image": ChatFileType.IMAGE,
+    "csv": ChatFileType.CSV,
+}
+
+
+class PartialResponse(Protocol):
+    def __call__(
+        self,
+        message: str,
+        rephrased_query: str | None,
+        reference_docs: list[DbSearchDoc] | None,
+        files: list[FileDescriptor],
+        token_count: int,
+        citations: dict[int, int] | None,
+        error: str | None,
+        tool_call: ToolCall | None,
+    ) -> ChatMessage: ...
+

 def _translate_citations(
     citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
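The `PartialResponse` protocol added here gives the message-persisting callback a checkable signature instead of an opaque callable. A self-contained sketch of the same structural-typing technique (toy names, not from the codebase):

```python
from typing import Protocol


class SaveMessage(Protocol):
    """Structural type for any callable that persists an assistant message."""

    def __call__(self, message: str, token_count: int, error: str | None) -> int: ...


def save_to_db(message: str, token_count: int, error: str | None) -> int:
    # Pretend this writes a row and returns its id.
    return 42


def finalize(save: SaveMessage) -> int:
    # mypy checks the keyword names and types against the Protocol's __call__.
    return save(message="done", token_count=3, error=None)


print(finalize(save_to_db))  # 42
```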
@@ -211,25 +234,25 @@ def _handle_search_tool_response_summary(
         reference_db_search_docs = selected_search_docs

     doc_ids = {doc.id for doc in reference_db_search_docs}
-    if user_files is not None:
+    if user_files is not None and loaded_user_files is not None:
         for user_file in user_files:
-            if user_file.id not in doc_ids:
-                associated_chat_file = None
-                if loaded_user_files is not None:
-                    associated_chat_file = next(
-                        (
-                            file
-                            for file in loaded_user_files
-                            if file.file_id == str(user_file.file_id)
-                        ),
-                        None,
-                    )
-                # Use create_search_doc_from_user_file to properly add the document to the database
-                if associated_chat_file is not None:
-                    db_doc = create_search_doc_from_user_file(
-                        user_file, associated_chat_file, db_session
-                    )
-                    reference_db_search_docs.append(db_doc)
+            if user_file.id in doc_ids:
+                continue
+
+            associated_chat_file = next(
+                (
+                    file
+                    for file in loaded_user_files
+                    if file.file_id == str(user_file.file_id)
+                ),
+                None,
+            )
+            # Use create_search_doc_from_user_file to properly add the document to the database
+            if associated_chat_file is not None:
+                db_doc = create_search_doc_from_user_file(
+                    user_file, associated_chat_file, db_session
+                )
+                reference_db_search_docs.append(db_doc)

     response_docs = [
         translate_db_search_doc_to_server_search_doc(db_search_doc)
@@ -357,6 +380,86 @@ def _get_force_search_settings(
     )


+def _get_user_knowledge_files(
+    info: AnswerPostInfo,
+    user_files: list[InMemoryChatFile],
+    file_id_to_user_file: dict[str, InMemoryChatFile],
+) -> Generator[UserKnowledgeFilePacket, None, None]:
+    if not info.qa_docs_response:
+        return
+
+    logger.info(
+        f"ORDERING: Processing search results for ordering {len(user_files)} user files"
+    )
+
+    # Extract document order from search results
+    doc_order = []
+    for doc in info.qa_docs_response.top_documents:
+        doc_id = doc.document_id
+        if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
+            file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
+            if file_id in file_id_to_user_file:
+                doc_order.append(file_id)
+
+    logger.info(f"ORDERING: Found {len(doc_order)} files from search results")
+
+    # Add any files that weren't in search results at the end
+    missing_files = [
+        f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
+    ]
+
+    missing_files.extend(doc_order)
+    doc_order = missing_files
+
+    logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")
+
+    # Reorder user files based on search results
+    ordered_user_files = [
+        file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
+    ]
+
+    yield UserKnowledgeFilePacket(
+        user_files=[
+            FileDescriptor(
+                id=str(file.file_id),
+                type=ChatFileType.USER_KNOWLEDGE,
+            )
+            for file in ordered_user_files
+        ]
+    )
+
+
+def _get_persona_for_chat_session(
+    new_msg_req: CreateChatMessageRequest,
+    user: User | None,
+    db_session: Session,
+    default_persona: Persona,
+) -> Persona:
+    if new_msg_req.alternate_assistant_id is not None:
+        # Allows users to specify a temporary persona (assistant) in the chat session
+        # this takes highest priority since it's user specified
+        persona = get_persona_by_id(
+            new_msg_req.alternate_assistant_id,
+            user=user,
+            db_session=db_session,
+            is_for_edit=False,
+        )
+    elif new_msg_req.persona_override_config:
+        # Certain endpoints allow users to specify arbitrary persona settings
+        # this should never conflict with the alternate_assistant_id
+        persona = create_temporary_persona(
+            db_session=db_session,
+            persona_config=new_msg_req.persona_override_config,
+            user=user,
+        )
+    else:
+        persona = default_persona
+
+    if not persona:
+        raise RuntimeError("No persona specified or found for chat session")
+    return persona
+
+
 ChatPacket = (
     StreamingError
     | QADocsResponse
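In `_get_user_knowledge_files`, the final order merges ids ranked by the search tool with ids the results never mentioned. A tiny illustration of that merge with stand-in data (note that, as written in the diff, the unranked ids end up ahead of the ranked ones):

```python
# Illustrative stand-ins for the real objects in the diff above.
file_id_to_user_file = {"a": "file-a", "b": "file-b", "c": "file-c"}
ranked_ids = ["c", "a"]  # order coming back from the search tool

# Ids that the search results never mentioned...
missing = [f_id for f_id in file_id_to_user_file if f_id not in ranked_ids]
# ...are placed ahead of the ranked ids, mirroring the diff's extend/reassign.
missing.extend(ranked_ids)
ordered = [file_id_to_user_file[f_id] for f_id in missing]

print(ordered)  # ['file-b', 'file-c', 'file-a']
```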
@@ -378,6 +481,149 @@ ChatPacket = (
 ChatPacketStream = Iterator[ChatPacket]


+def _process_tool_response(
+    packet: ToolResponse,
+    db_session: Session,
+    selected_db_search_docs: list[DbSearchDoc] | None,
+    info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
+    retrieval_options: RetrievalDetails | None,
+    user_file_files: list[UserFile] | None,
+    user_files: list[InMemoryChatFile] | None,
+    file_id_to_user_file: dict[str, InMemoryChatFile],
+    search_for_ordering_only: bool,
+) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
+    level, level_question_num = (
+        (packet.level, packet.level_question_num)
+        if isinstance(packet, ExtendedToolResponse)
+        else BASIC_KEY
+    )
+
+    assert level is not None
+    assert level_question_num is not None
+    info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
+
+    # Skip LLM relevance processing entirely for ordering-only mode
+    if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
+        logger.info(
+            "Fast path: Completely bypassing section relevance processing for ordering-only mode"
+        )
+        # Skip this packet entirely since it would trigger LLM processing
+        return info_by_subq
+
+    # TODO: don't need to dedupe here when we do it in agent flow
+    if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
+        if search_for_ordering_only:
+            logger.info(
+                "Fast path: Skipping document deduplication for ordering-only mode"
+            )
+
+        (
+            info.qa_docs_response,
+            info.reference_db_search_docs,
+            info.dropped_indices,
+        ) = _handle_search_tool_response_summary(
+            packet=packet,
+            db_session=db_session,
+            selected_search_docs=selected_db_search_docs,
+            # Deduping happens at the last step to avoid harming quality by dropping content early on
+            # Skip deduping completely for ordering-only mode to save time
+            dedupe_docs=bool(
+                not search_for_ordering_only
+                and retrieval_options
+                and retrieval_options.dedupe_docs
+            ),
+            user_files=user_file_files if search_for_ordering_only else [],
+            loaded_user_files=(user_files if search_for_ordering_only else []),
+        )
+
+        # If we're using search just for ordering user files
+        if search_for_ordering_only and user_files:
+            yield from _get_user_knowledge_files(
+                info=info,
+                user_files=user_files,
+                file_id_to_user_file=file_id_to_user_file,
+            )
+
+        yield info.qa_docs_response
+    elif packet.id == SECTION_RELEVANCE_LIST_ID:
+        relevance_sections = packet.response
+
+        if search_for_ordering_only:
+            logger.info(
+                "Performance: Skipping relevance filtering for ordering-only mode"
+            )
+            return info_by_subq
+
+        if info.reference_db_search_docs is None:
+            logger.warning("No reference docs found for relevance filtering")
+            return info_by_subq
+
+        llm_indices = relevant_sections_to_indices(
+            relevance_sections=relevance_sections,
+            items=[
+                translate_db_search_doc_to_server_search_doc(doc)
+                for doc in info.reference_db_search_docs
+            ],
+        )
+
+        if info.dropped_indices:
+            llm_indices = drop_llm_indices(
+                llm_indices=llm_indices,
+                search_docs=info.reference_db_search_docs,
+                dropped_indices=info.dropped_indices,
+            )
+
+        yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
+    elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
+        yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
+
+    elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
+        img_generation_response = cast(list[ImageGenerationResponse], packet.response)
+
+        file_ids = save_files(
+            urls=[img.url for img in img_generation_response if img.url],
+            base64_files=[
+                img.image_data for img in img_generation_response if img.image_data
+            ],
+        )
+        info.ai_message_files.extend(
+            [
+                FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
+                for file_id in file_ids
+            ]
+        )
+        yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
+    elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
+        (
+            info.qa_docs_response,
+            info.reference_db_search_docs,
+        ) = _handle_internet_search_tool_response_summary(
+            packet=packet,
+            db_session=db_session,
+        )
+        yield info.qa_docs_response
+    elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
+        custom_tool_response = cast(CustomToolCallSummary, packet.response)
+        response_type = custom_tool_response.response_type
+        if response_type in COMMON_TOOL_RESPONSE_TYPES:
+            file_ids = custom_tool_response.tool_result.file_ids
+            file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
+            info.ai_message_files.extend(
+                [
+                    FileDescriptor(id=str(file_id), type=file_type)
+                    for file_id in file_ids
+                ]
+            )
+            yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
+        else:
+            yield CustomToolResponse(
+                response=custom_tool_response.tool_result,
+                tool_name=custom_tool_response.tool_name,
+            )
+
+    return info_by_subq
+
+
 def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
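`_process_tool_response` both streams packets and hands back the updated `info_by_subq` mapping; the caller picks the return value up with `yield from`. A minimal sketch of that generator-return pattern, independent of the chat code:

```python
from collections.abc import Generator


def process(state: dict[str, int]) -> Generator[str, None, dict[str, int]]:
    yield "packet-1"
    state["seen"] = state.get("seen", 0) + 1
    yield "packet-2"
    return state  # becomes the value of the `yield from` expression


def stream() -> Generator[str, None, None]:
    state: dict[str, int] = {}
    state = yield from process(state)  # re-yields packets, captures the return
    yield f"done, seen={state['seen']}"


print(list(stream()))  # ['packet-1', 'packet-2', 'done, seen=1']
```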
@@ -421,7 +667,6 @@ def stream_chat_message_objects(
     try:
         # Move these variables inside the try block
         file_id_to_user_file = {}
-        ordered_user_files = None

         user_id = user.id if user is not None else None

@@ -436,35 +681,19 @@ def stream_chat_message_objects(
         parent_id = new_msg_req.parent_message_id
         reference_doc_ids = new_msg_req.search_doc_ids
         retrieval_options = new_msg_req.retrieval_options
-        alternate_assistant_id = new_msg_req.alternate_assistant_id
+        new_msg_req.alternate_assistant_id

         # permanent "log" store, used primarily for debugging
         long_term_logger = LongTermLogger(
             metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
         )

-        if alternate_assistant_id is not None:
-            # Allows users to specify a temporary persona (assistant) in the chat session
-            # this takes highest priority since it's user specified
-            persona = get_persona_by_id(
-                alternate_assistant_id,
-                user=user,
-                db_session=db_session,
-                is_for_edit=False,
-            )
-        elif new_msg_req.persona_override_config:
-            # Certain endpoints allow users to specify arbitrary persona settings
-            # this should never conflict with the alternate_assistant_id
-            persona = persona = create_temporary_persona(
-                db_session=db_session,
-                persona_config=new_msg_req.persona_override_config,
-                user=user,
-            )
-        else:
-            persona = chat_session.persona
-
-        if not persona:
-            raise RuntimeError("No persona specified or found for chat session")
+        persona = _get_persona_for_chat_session(
+            new_msg_req=new_msg_req,
+            user=user,
+            db_session=db_session,
+            default_persona=chat_session.persona,
+        )

         multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
             user=user,
@@ -746,31 +975,42 @@ def stream_chat_message_objects(
             new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
         )

-        # Cannot determine these without the LLM step or breaking out early
-        partial_response = partial(
-            create_new_chat_message,
-            chat_session_id=chat_session_id,
-            # if we're using an existing assistant message, then this will just be an
-            # update operation, in which case the parent should be the parent of
-            # the latest. If we're creating a new assistant message, then the parent
-            # should be the latest message (latest user message)
-            parent_message=(
-                final_msg if existing_assistant_message_id is None else parent_message
-            ),
-            prompt_id=prompt_id,
-            overridden_model=overridden_model,
-            # message=,
-            # rephrased_query=,
-            # token_count=,
-            message_type=MessageType.ASSISTANT,
-            alternate_assistant_id=new_msg_req.alternate_assistant_id,
-            # error=,
-            # reference_docs=,
-            db_session=db_session,
-            commit=False,
-            reserved_message_id=reserved_message_id,
-            is_agentic=new_msg_req.use_agentic_search,
-        )
+        def create_response(
+            message: str,
+            rephrased_query: str | None,
+            reference_docs: list[DbSearchDoc] | None,
+            files: list[FileDescriptor],
+            token_count: int,
+            citations: dict[int, int] | None,
+            error: str | None,
+            tool_call: ToolCall | None,
+        ) -> ChatMessage:
+            return create_new_chat_message(
+                chat_session_id=chat_session_id,
+                parent_message=(
+                    final_msg
+                    if existing_assistant_message_id is None
+                    else parent_message
+                ),
+                prompt_id=prompt_id,
+                overridden_model=overridden_model,
+                message=message,
+                rephrased_query=rephrased_query,
+                token_count=token_count,
+                message_type=MessageType.ASSISTANT,
+                alternate_assistant_id=new_msg_req.alternate_assistant_id,
+                error=error,
+                reference_docs=reference_docs,
+                files=files,
+                citations=citations,
+                tool_call=tool_call,
+                db_session=db_session,
+                commit=False,
+                reserved_message_id=reserved_message_id,
+                is_agentic=new_msg_req.use_agentic_search,
+            )
+
+        partial_response = create_response

         prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
         if new_msg_req.persona_override_config:
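Swapping the old `functools.partial(...)` for the fully annotated `create_response` gives mypy a concrete signature to check against the `PartialResponse` protocol; a `partial` object typically exposes a much looser `__call__`. A toy sketch of the wrapper style (hypothetical names, not from the codebase):

```python
from typing import Protocol


class MakeGreeting(Protocol):
    def __call__(self, name: str) -> str: ...


def greet(prefix: str, name: str) -> str:
    return f"{prefix}{name}"


# Spelling the wrapper out (instead of using functools.partial) gives the
# checker an explicit signature to match structurally against the Protocol.
def make_greeting(name: str) -> str:
    return greet("Hello, ", name)


strict: MakeGreeting = make_greeting
print(strict(name="world"))  # Hello, world
```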
@@ -1041,220 +1281,23 @@ def stream_chat_message_objects(
             use_agentic_search=new_msg_req.use_agentic_search,
         )

-        # reference_db_search_docs = None
-        # qa_docs_response = None
-        # # any files to associate with the AI message e.g. dall-e generated images
-        # ai_message_files = []
-        # dropped_indices = None
-        # tool_result = None
-
         # TODO: different channels for stored info when it's coming from the agent flow
         info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
             lambda: AnswerPostInfo(ai_message_files=[])
         )
         refined_answer_improvement = True
         for packet in answer.processed_streamed_output:
             if isinstance(packet, ToolResponse):
-                level, level_question_num = (
-                    (packet.level, packet.level_question_num)
-                    if isinstance(packet, ExtendedToolResponse)
-                    else BASIC_KEY
+                info_by_subq = yield from _process_tool_response(
+                    packet=packet,
+                    db_session=db_session,
+                    selected_db_search_docs=selected_db_search_docs,
+                    info_by_subq=info_by_subq,
+                    retrieval_options=retrieval_options,
+                    user_file_files=user_file_files,
+                    user_files=user_files,
+                    file_id_to_user_file=file_id_to_user_file,
+                    search_for_ordering_only=search_for_ordering_only,
                 )
-                assert level is not None
-                assert level_question_num is not None
-                info = info_by_subq[
-                    SubQuestionKey(level=level, question_num=level_question_num)
-                ]
-
-                # Skip LLM relevance processing entirely for ordering-only mode
-                if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
-                    logger.info(
-                        "Fast path: Completely bypassing section relevance processing for ordering-only mode"
-                    )
-                    # Skip this packet entirely since it would trigger LLM processing
-                    continue
-
-                # TODO: don't need to dedupe here when we do it in agent flow
-                if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
-                    if search_for_ordering_only:
-                        logger.info(
-                            "Fast path: Skipping document deduplication for ordering-only mode"
-                        )
-
-                    (
-                        info.qa_docs_response,
-                        info.reference_db_search_docs,
-                        info.dropped_indices,
-                    ) = _handle_search_tool_response_summary(
-                        packet=packet,
-                        db_session=db_session,
-                        selected_search_docs=selected_db_search_docs,
-                        # Deduping happens at the last step to avoid harming quality by dropping content early on
-                        # Skip deduping completely for ordering-only mode to save time
-                        dedupe_docs=bool(
-                            not search_for_ordering_only
-                            and retrieval_options
-                            and retrieval_options.dedupe_docs
-                        ),
-                        user_files=user_file_files if search_for_ordering_only else [],
-                        loaded_user_files=(
-                            user_files if search_for_ordering_only else []
-                        ),
-                    )
-
-                    # If we're using search just for ordering user files
-                    if (
-                        search_for_ordering_only
-                        and user_files
-                        and info.qa_docs_response
-                    ):
-                        logger.info(
-                            f"ORDERING: Processing search results for ordering {len(user_files)} user files"
-                        )
-
-                        # Extract document order from search results
-                        doc_order = []
-                        for doc in info.qa_docs_response.top_documents:
-                            doc_id = doc.document_id
-                            if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
-                                file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
-                                if file_id in file_id_to_user_file:
-                                    doc_order.append(file_id)
-
-                        logger.info(
-                            f"ORDERING: Found {len(doc_order)} files from search results"
-                        )
-
-                        # Add any files that weren't in search results at the end
-                        missing_files = [
-                            f_id
-                            for f_id in file_id_to_user_file.keys()
-                            if f_id not in doc_order
-                        ]
-
-                        missing_files.extend(doc_order)
-                        doc_order = missing_files
-
-                        logger.info(
-                            f"ORDERING: Added {len(missing_files)} missing files to the end"
-                        )
-
-                        # Reorder user files based on search results
-                        ordered_user_files = [
-                            file_id_to_user_file[f_id]
-                            for f_id in doc_order
-                            if f_id in file_id_to_user_file
-                        ]
-
-                        yield UserKnowledgeFilePacket(
-                            user_files=[
-                                FileDescriptor(
-                                    id=str(file.file_id),
-                                    type=ChatFileType.USER_KNOWLEDGE,
-                                )
-                                for file in ordered_user_files
-                            ]
-                        )
-
-                    yield info.qa_docs_response
-                elif packet.id == SECTION_RELEVANCE_LIST_ID:
-                    relevance_sections = packet.response
-
-                    if search_for_ordering_only:
-                        logger.info(
-                            "Performance: Skipping relevance filtering for ordering-only mode"
-                        )
-                        continue
-
-                    if info.reference_db_search_docs is None:
-                        logger.warning(
-                            "No reference docs found for relevance filtering"
-                        )
-                        continue
-
-                    llm_indices = relevant_sections_to_indices(
-                        relevance_sections=relevance_sections,
-                        items=[
-                            translate_db_search_doc_to_server_search_doc(doc)
-                            for doc in info.reference_db_search_docs
-                        ],
-                    )
-
-                    if info.dropped_indices:
-                        llm_indices = drop_llm_indices(
-                            llm_indices=llm_indices,
-                            search_docs=info.reference_db_search_docs,
-                            dropped_indices=info.dropped_indices,
-                        )
-
-                    yield LLMRelevanceFilterResponse(
-                        llm_selected_doc_indices=llm_indices
-                    )
-                elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
-                    yield FinalUsedContextDocsResponse(
-                        final_context_docs=packet.response
-                    )
-
-                elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
-                    img_generation_response = cast(
-                        list[ImageGenerationResponse], packet.response
-                    )
-
-                    file_ids = save_files(
-                        urls=[img.url for img in img_generation_response if img.url],
-                        base64_files=[
-                            img.image_data
-                            for img in img_generation_response
-                            if img.image_data
-                        ],
-                    )
-                    info.ai_message_files.extend(
-                        [
-                            FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
-                            for file_id in file_ids
-                        ]
-                    )
-                    yield FileChatDisplay(
-                        file_ids=[str(file_id) for file_id in file_ids]
-                    )
-                elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
-                    (
-                        info.qa_docs_response,
-                        info.reference_db_search_docs,
-                    ) = _handle_internet_search_tool_response_summary(
-                        packet=packet,
-                        db_session=db_session,
-                    )
-                    yield info.qa_docs_response
-                elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
-                    custom_tool_response = cast(CustomToolCallSummary, packet.response)
-
-                    if (
-                        custom_tool_response.response_type == "image"
-                        or custom_tool_response.response_type == "csv"
-                    ):
-                        file_ids = custom_tool_response.tool_result.file_ids
-                        info.ai_message_files.extend(
-                            [
-                                FileDescriptor(
-                                    id=str(file_id),
-                                    type=(
-                                        ChatFileType.IMAGE
-                                        if custom_tool_response.response_type == "image"
-                                        else ChatFileType.CSV
-                                    ),
-                                )
-                                for file_id in file_ids
-                            ]
-                        )
-                        yield FileChatDisplay(
-                            file_ids=[str(file_id) for file_id in file_ids]
-                        )
-                    else:
-                        yield CustomToolResponse(
-                            response=custom_tool_response.tool_result,
-                            tool_name=custom_tool_response.tool_name,
-                        )
-
             elif isinstance(packet, StreamStopInfo):
                 if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1291,22 +1334,46 @@ def stream_chat_message_objects(

         if isinstance(e, ToolCallException):
             yield StreamingError(error=error_msg, stack_trace=stack_trace)
-        else:
-            if llm:
-                client_error_msg = litellm_exception_to_error_msg(e, llm)
-                if llm.config.api_key and len(llm.config.api_key) > 2:
-                    error_msg = error_msg.replace(
-                        llm.config.api_key, "[REDACTED_API_KEY]"
-                    )
-                    stack_trace = stack_trace.replace(
-                        llm.config.api_key, "[REDACTED_API_KEY]"
-                    )
+        elif llm:
+            client_error_msg = litellm_exception_to_error_msg(e, llm)
+            if llm.config.api_key and len(llm.config.api_key) > 2:
+                client_error_msg = client_error_msg.replace(
+                    llm.config.api_key, "[REDACTED_API_KEY]"
+                )
+                stack_trace = stack_trace.replace(
+                    llm.config.api_key, "[REDACTED_API_KEY]"
+                )

-                yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
+            yield StreamingError(error=client_error_msg, stack_trace=stack_trace)

         db_session.rollback()
         return

+    yield from _post_llm_answer_processing(
+        answer=answer,
+        info_by_subq=info_by_subq,
+        tool_dict=tool_dict,
+        partial_response=partial_response,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
+        db_session=db_session,
+        chat_session_id=chat_session_id,
+        refined_answer_improvement=refined_answer_improvement,
+    )
+
+
+def _post_llm_answer_processing(
+    answer: Answer,
+    info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
+    tool_dict: dict[int, list[Tool]],
+    partial_response: PartialResponse,
+    llm_tokenizer_encode_func: Callable[[str], list[int]],
+    db_session: Session,
+    chat_session_id: UUID,
+    refined_answer_improvement: bool | None,
+) -> Generator[ChatPacket, None, None]:
+    """
+    Stores messages in the db and yields some final packets to the frontend
+    """
+    # Post-LLM answer processing
+    try:
+        tool_name_to_tool_id: dict[str, int] = {}
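Both error paths above scrub the provider API key out of the user-facing message and the stack trace before streaming them. A tiny sketch of that redaction step (a hypothetical helper, not from the codebase):

```python
def redact(text: str, secret: str | None) -> str:
    """Replace a secret in diagnostic text, skipping trivially short values."""
    if secret and len(secret) > 2:
        return text.replace(secret, "[REDACTED_API_KEY]")
    return text


print(redact("auth failed for key sk-123", "sk-123"))
# auth failed for key [REDACTED_API_KEY]
```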
@@ -56,7 +56,8 @@ def get_total_users_count(db_session: Session) -> int:

 async def get_user_count(only_admin_users: bool = False) -> int:
     async with get_async_session_with_tenant() as session:
-        stmt = select(func.count(User.id))
+        count_stmt = func.count(User.id)  # type: ignore
+        stmt = select(count_stmt)
         if only_admin_users:
             stmt = stmt.where(User.role == UserRole.ADMIN)
         result = await session.execute(stmt)
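The last hunk splits the `func.count(...)` expression onto its own line so the `# type: ignore` attaches to it rather than to the whole `select()`. A minimal synchronous sketch of the same query shape (illustrative model and in-memory SQLite instead of the async session used in the diff):

```python
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user"
    id: Mapped[int] = mapped_column(primary_key=True)
    role: Mapped[str] = mapped_column(default="basic")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([User(role="admin"), User(role="basic")])
    session.commit()

    # Build the count expression separately, then select it.
    count_stmt = func.count(User.id)
    stmt = select(count_stmt).where(User.role == "admin")
    print(session.execute(stmt).scalar_one())  # 1
```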