refactor to use stricter typing (#4513)

* refactor to use stricter typing

* older version of ruff
evan-danswer 2025-04-14 10:23:07 -07:00 committed by GitHub
parent a5edc8aa0f
commit 68c6c1f4f8
5 changed files with 386 additions and 308 deletions

View File

@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
             (
                 or_(
                     ChatMessageFeedback.is_positive.is_(False),
-                    ChatMessageFeedback.required_followup,
+                    ChatMessageFeedback.required_followup.is_(True),
                 ),
                 1,
             ),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
         .all()
     )

-    return results
+    return [tuple(row) for row in results]


 def fetch_persona_message_analytics(
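
Both edits in this file serve the stricter stubs: boolean columns are compared with .is_(True)/.is_(False) instead of being handed bare to or_(), and the result rows are converted to plain tuples before being returned. A minimal, self-contained sketch of the same two moves on an invented model (names below are illustrative, not taken from this repo):

from sqlalchemy import Boolean, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Feedback(Base):
    __tablename__ = "feedback"  # illustrative table for the sketch
    id: Mapped[int] = mapped_column(primary_key=True)
    required_followup: Mapped[bool | None] = mapped_column(Boolean)


def followup_ids(db_session: Session) -> list[tuple[int]]:
    # .is_(True) builds an explicit boolean comparison; handing the bare column
    # to a filter works at runtime but trips stricter SQLAlchemy type stubs.
    stmt = select(Feedback.id).where(Feedback.required_followup.is_(True))
    # execute(...).all() returns Row objects, not tuples; converting keeps the
    # declared return type honest for downstream code.
    return [tuple(row) for row in db_session.execute(stmt).all()]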

View File

@@ -1,3 +1,5 @@
+from typing import cast
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -39,10 +41,10 @@ logger = setup_logger()
 router = APIRouter(prefix="/custom")

-_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
+_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
 _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None

-_INTENT_TOKENIZER: AutoTokenizer | None = None
+_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
 _INTENT_MODEL: HybridClassifier | None = None

 _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -50,13 +52,14 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
 _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!


-def get_connector_classifier_tokenizer() -> AutoTokenizer:
+def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
     global _CONNECTOR_CLASSIFIER_TOKENIZER
     if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
         # The tokenizer details are not uploaded to the HF hub since it's just the
         # unmodified distilbert tokenizer.
-        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
-            "distilbert-base-uncased"
+        _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
+            PreTrainedTokenizer,
+            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
         )
     return _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -92,12 +95,15 @@ def get_local_connector_classifier(
     return _CONNECTOR_CLASSIFIER_MODEL


-def get_intent_model_tokenizer() -> AutoTokenizer:
+def get_intent_model_tokenizer() -> PreTrainedTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
         # The tokenizer details are not uploaded to the HF hub since it's just the
         # unmodified distilbert tokenizer.
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+        _INTENT_TOKENIZER = cast(
+            PreTrainedTokenizer,
+            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
+        )
     return _INTENT_TOKENIZER
@@ -395,9 +401,9 @@ def run_content_classification_inference(

 def map_keywords(
-    input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
+    input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
 ) -> list[str]:
-    tokens = tokenizer.convert_ids_to_tokens(input_ids)
+    tokens = tokenizer.convert_ids_to_tokens(input_ids)  # type: ignore
     if not len(tokens) == len(is_keyword):
         raise ValueError("Length of tokens and keyword predictions must match")
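
The model-server changes swap AutoTokenizer for PreTrainedTokenizer in annotations and wrap the factory call in typing.cast, since AutoTokenizer is a dispatch class whose from_pretrained() is typed too loosely to be useful as a return type. A self-contained sketch of the cached-tokenizer pattern as it looks after the change (assumes the transformers package; the checkpoint name is the one used above):

from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

_TOKENIZER: PreTrainedTokenizer | None = None


def get_tokenizer() -> PreTrainedTokenizer:
    """Lazily load and cache a tokenizer with a concrete type for callers."""
    global _TOKENIZER
    if _TOKENIZER is None:
        # cast() has no runtime effect; it only narrows the loosely typed
        # factory result for the type checker, mirroring the hunks above.
        _TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _TOKENIZER

Worth noting: from_pretrained() will usually hand back a fast tokenizer at runtime, so the cast narrows the type for the checker rather than guaranteeing the concrete class.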

View File

@@ -1,5 +1,6 @@
 import json
 import os
+from typing import cast

 import torch
 import torch.nn as nn
@@ -13,15 +14,14 @@ class HybridClassifier(nn.Module):
         super().__init__()

         config = DistilBertConfig()
         self.distilbert = DistilBertModel(config)
+        config = self.distilbert.config  # type: ignore
         # Keyword tokenwise binary classification layer
-        self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)
+        self.keyword_classifier = nn.Linear(config.dim, 2)
         # Intent Classifier layers
-        self.pre_classifier = nn.Linear(
-            self.distilbert.config.dim, self.distilbert.config.dim
-        )
-        self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
+        self.pre_classifier = nn.Linear(config.dim, config.dim)
+        self.intent_classifier = nn.Linear(config.dim, 2)

         self.device = torch.device("cpu")
@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
         query_ids: torch.Tensor,
         query_mask: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
-        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
+        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)  # type: ignore
         sequence_output = outputs.last_hidden_state

         # Intent classification on the CLS token
@@ -79,8 +79,9 @@ class ConnectorClassifier(nn.Module):
         self.config = config
         self.distilbert = DistilBertModel(config)
-        self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
-        self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        config = self.distilbert.config  # type: ignore
+        self.connector_global_classifier = nn.Linear(config.dim, 1)
+        self.connector_match_classifier = nn.Linear(config.dim, 1)
         self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

         # Token indicating end of connector name, and on which classifier is used
@@ -95,7 +96,7 @@ class ConnectorClassifier(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        hidden_states = self.distilbert(
+        hidden_states = self.distilbert(  # type: ignore
             input_ids=input_ids, attention_mask=attention_mask
         ).last_hidden_state
@@ -114,7 +115,10 @@ class ConnectorClassifier(nn.Module):
     @classmethod
     def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
-        config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
+        config = cast(
+            DistilBertConfig,
+            DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
+        )
         device = (
             torch.device("cuda")
             if torch.cuda.is_available()
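
Two small techniques recur in this file: the loosely typed DistilBertModel.config is read into a local config once (behind a single # type: ignore) instead of annotating every .dim access, and the broadly typed from_pretrained() is narrowed with one cast. A hedged sketch of the latter, using only transformers classes this module already relies on:

import os
from typing import cast

from transformers import DistilBertConfig


def load_distilbert_config(repo_dir: str) -> DistilBertConfig:
    # from_pretrained() is annotated very broadly (depending on arguments it can
    # even return a (config, unused_kwargs) tuple), so a single cast at the load
    # site keeps every caller fully typed.
    return cast(
        DistilBertConfig,
        DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
    )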

View File

@@ -2,9 +2,11 @@ import time
 import traceback
 from collections import defaultdict
 from collections.abc import Callable
+from collections.abc import Generator
 from collections.abc import Iterator
-from functools import partial
 from typing import cast
+from typing import Protocol
+from uuid import UUID

 from sqlalchemy.orm import Session
@@ -82,6 +84,8 @@ from onyx.db.engine import get_session_context_manager
 from onyx.db.milestone import check_multi_assistant_milestone
 from onyx.db.milestone import create_milestone_if_not_exists
 from onyx.db.milestone import update_user_assistant_milestone
+from onyx.db.models import ChatMessage
+from onyx.db.models import Persona
 from onyx.db.models import SearchDoc as DbSearchDoc
 from onyx.db.models import ToolCall
 from onyx.db.models import User
@@ -159,6 +163,25 @@ from shared_configs.contextvars import get_current_tenant_id

 logger = setup_logger()

 ERROR_TYPE_CANCELLED = "cancelled"

+COMMON_TOOL_RESPONSE_TYPES = {
+    "image": ChatFileType.IMAGE,
+    "csv": ChatFileType.CSV,
+}
+
+
+class PartialResponse(Protocol):
+    def __call__(
+        self,
+        message: str,
+        rephrased_query: str | None,
+        reference_docs: list[DbSearchDoc] | None,
+        files: list[FileDescriptor],
+        token_count: int,
+        citations: dict[int, int] | None,
+        error: str | None,
+        tool_call: ToolCall | None,
+    ) -> ChatMessage: ...
+
+
 def _translate_citations(
     citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
@@ -211,25 +234,25 @@ def _handle_search_tool_response_summary(
         reference_db_search_docs = selected_search_docs

     doc_ids = {doc.id for doc in reference_db_search_docs}
-    if user_files is not None:
+    if user_files is not None and loaded_user_files is not None:
         for user_file in user_files:
-            if user_file.id not in doc_ids:
-                associated_chat_file = None
-                if loaded_user_files is not None:
-                    associated_chat_file = next(
-                        (
-                            file
-                            for file in loaded_user_files
-                            if file.file_id == str(user_file.file_id)
-                        ),
-                        None,
-                    )
-                # Use create_search_doc_from_user_file to properly add the document to the database
-                if associated_chat_file is not None:
-                    db_doc = create_search_doc_from_user_file(
-                        user_file, associated_chat_file, db_session
-                    )
-                    reference_db_search_docs.append(db_doc)
+            if user_file.id in doc_ids:
+                continue
+            associated_chat_file = next(
+                (
+                    file
+                    for file in loaded_user_files
+                    if file.file_id == str(user_file.file_id)
+                ),
+                None,
+            )
+            # Use create_search_doc_from_user_file to properly add the document to the database
+            if associated_chat_file is not None:
+                db_doc = create_search_doc_from_user_file(
+                    user_file, associated_chat_file, db_session
+                )
+                reference_db_search_docs.append(db_doc)

     response_docs = [
         translate_db_search_doc_to_server_search_doc(db_search_doc)
@@ -357,6 +380,86 @@ def _get_force_search_settings(
     )


+def _get_user_knowledge_files(
+    info: AnswerPostInfo,
+    user_files: list[InMemoryChatFile],
+    file_id_to_user_file: dict[str, InMemoryChatFile],
+) -> Generator[UserKnowledgeFilePacket, None, None]:
+    if not info.qa_docs_response:
+        return
+
+    logger.info(
+        f"ORDERING: Processing search results for ordering {len(user_files)} user files"
+    )
+
+    # Extract document order from search results
+    doc_order = []
+    for doc in info.qa_docs_response.top_documents:
+        doc_id = doc.document_id
+        if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
+            file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
+            if file_id in file_id_to_user_file:
+                doc_order.append(file_id)
+
+    logger.info(f"ORDERING: Found {len(doc_order)} files from search results")
+
+    # Add any files that weren't in search results at the end
+    missing_files = [
+        f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
+    ]
+    missing_files.extend(doc_order)
+    doc_order = missing_files
+    logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")
+
+    # Reorder user files based on search results
+    ordered_user_files = [
+        file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
+    ]
+
+    yield UserKnowledgeFilePacket(
+        user_files=[
+            FileDescriptor(
+                id=str(file.file_id),
+                type=ChatFileType.USER_KNOWLEDGE,
+            )
+            for file in ordered_user_files
+        ]
+    )
+
+
+def _get_persona_for_chat_session(
+    new_msg_req: CreateChatMessageRequest,
+    user: User | None,
+    db_session: Session,
+    default_persona: Persona,
+) -> Persona:
+    if new_msg_req.alternate_assistant_id is not None:
+        # Allows users to specify a temporary persona (assistant) in the chat session
+        # this takes highest priority since it's user specified
+        persona = get_persona_by_id(
+            new_msg_req.alternate_assistant_id,
+            user=user,
+            db_session=db_session,
+            is_for_edit=False,
+        )
+    elif new_msg_req.persona_override_config:
+        # Certain endpoints allow users to specify arbitrary persona settings
+        # this should never conflict with the alternate_assistant_id
+        persona = create_temporary_persona(
+            db_session=db_session,
+            persona_config=new_msg_req.persona_override_config,
+            user=user,
+        )
+    else:
+        persona = default_persona
+
+    if not persona:
+        raise RuntimeError("No persona specified or found for chat session")
+
+    return persona
+
+
 ChatPacket = (
     StreamingError
     | QADocsResponse
@@ -378,6 +481,149 @@ ChatPacket = (
 ChatPacketStream = Iterator[ChatPacket]


+def _process_tool_response(
+    packet: ToolResponse,
+    db_session: Session,
+    selected_db_search_docs: list[DbSearchDoc] | None,
+    info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
+    retrieval_options: RetrievalDetails | None,
+    user_file_files: list[UserFile] | None,
+    user_files: list[InMemoryChatFile] | None,
+    file_id_to_user_file: dict[str, InMemoryChatFile],
+    search_for_ordering_only: bool,
+) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
+    level, level_question_num = (
+        (packet.level, packet.level_question_num)
+        if isinstance(packet, ExtendedToolResponse)
+        else BASIC_KEY
+    )
+    assert level is not None
+    assert level_question_num is not None
+    info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
+
+    # Skip LLM relevance processing entirely for ordering-only mode
+    if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
+        logger.info(
+            "Fast path: Completely bypassing section relevance processing for ordering-only mode"
+        )
+        # Skip this packet entirely since it would trigger LLM processing
+        return info_by_subq
+
+    # TODO: don't need to dedupe here when we do it in agent flow
+    if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
+        if search_for_ordering_only:
+            logger.info(
+                "Fast path: Skipping document deduplication for ordering-only mode"
+            )
+        (
+            info.qa_docs_response,
+            info.reference_db_search_docs,
+            info.dropped_indices,
+        ) = _handle_search_tool_response_summary(
+            packet=packet,
+            db_session=db_session,
+            selected_search_docs=selected_db_search_docs,
+            # Deduping happens at the last step to avoid harming quality by dropping content early on
+            # Skip deduping completely for ordering-only mode to save time
+            dedupe_docs=bool(
+                not search_for_ordering_only
+                and retrieval_options
+                and retrieval_options.dedupe_docs
+            ),
+            user_files=user_file_files if search_for_ordering_only else [],
+            loaded_user_files=(user_files if search_for_ordering_only else []),
+        )
+
+        # If we're using search just for ordering user files
+        if search_for_ordering_only and user_files:
+            yield from _get_user_knowledge_files(
+                info=info,
+                user_files=user_files,
+                file_id_to_user_file=file_id_to_user_file,
+            )
+
+        yield info.qa_docs_response
+    elif packet.id == SECTION_RELEVANCE_LIST_ID:
+        relevance_sections = packet.response
+
+        if search_for_ordering_only:
+            logger.info(
+                "Performance: Skipping relevance filtering for ordering-only mode"
+            )
+            return info_by_subq
+
+        if info.reference_db_search_docs is None:
+            logger.warning("No reference docs found for relevance filtering")
+            return info_by_subq
+
+        llm_indices = relevant_sections_to_indices(
+            relevance_sections=relevance_sections,
+            items=[
+                translate_db_search_doc_to_server_search_doc(doc)
+                for doc in info.reference_db_search_docs
+            ],
+        )
+
+        if info.dropped_indices:
+            llm_indices = drop_llm_indices(
+                llm_indices=llm_indices,
+                search_docs=info.reference_db_search_docs,
+                dropped_indices=info.dropped_indices,
+            )
+
+        yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
+    elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
+        yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
+    elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
+        img_generation_response = cast(list[ImageGenerationResponse], packet.response)
+
+        file_ids = save_files(
+            urls=[img.url for img in img_generation_response if img.url],
+            base64_files=[
+                img.image_data for img in img_generation_response if img.image_data
+            ],
+        )
+        info.ai_message_files.extend(
+            [
+                FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
+                for file_id in file_ids
+            ]
+        )
+        yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
+    elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
+        (
+            info.qa_docs_response,
+            info.reference_db_search_docs,
+        ) = _handle_internet_search_tool_response_summary(
+            packet=packet,
+            db_session=db_session,
+        )
+        yield info.qa_docs_response
+    elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
+        custom_tool_response = cast(CustomToolCallSummary, packet.response)
+        response_type = custom_tool_response.response_type
+
+        if response_type in COMMON_TOOL_RESPONSE_TYPES:
+            file_ids = custom_tool_response.tool_result.file_ids
+            file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
+            info.ai_message_files.extend(
+                [
+                    FileDescriptor(id=str(file_id), type=file_type)
+                    for file_id in file_ids
+                ]
+            )
+            yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
+        else:
+            yield CustomToolResponse(
+                response=custom_tool_response.tool_result,
+                tool_name=custom_tool_response.tool_name,
+            )
+
+    return info_by_subq
+
+
 def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
@@ -421,7 +667,6 @@ def stream_chat_message_objects(
     try:
         # Move these variables inside the try block
         file_id_to_user_file = {}
-        ordered_user_files = None

         user_id = user.id if user is not None else None
@@ -436,35 +681,19 @@ def stream_chat_message_objects(
         parent_id = new_msg_req.parent_message_id
         reference_doc_ids = new_msg_req.search_doc_ids
         retrieval_options = new_msg_req.retrieval_options
-        alternate_assistant_id = new_msg_req.alternate_assistant_id
+        new_msg_req.alternate_assistant_id

         # permanent "log" store, used primarily for debugging
         long_term_logger = LongTermLogger(
             metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
         )

-        if alternate_assistant_id is not None:
-            # Allows users to specify a temporary persona (assistant) in the chat session
-            # this takes highest priority since it's user specified
-            persona = get_persona_by_id(
-                alternate_assistant_id,
-                user=user,
-                db_session=db_session,
-                is_for_edit=False,
-            )
-        elif new_msg_req.persona_override_config:
-            # Certain endpoints allow users to specify arbitrary persona settings
-            # this should never conflict with the alternate_assistant_id
-            persona = create_temporary_persona(
-                db_session=db_session,
-                persona_config=new_msg_req.persona_override_config,
-                user=user,
-            )
-        else:
-            persona = chat_session.persona
-
-        if not persona:
-            raise RuntimeError("No persona specified or found for chat session")
+        persona = _get_persona_for_chat_session(
+            new_msg_req=new_msg_req,
+            user=user,
+            db_session=db_session,
+            default_persona=chat_session.persona,
+        )

         multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
             user=user,
@@ -746,31 +975,42 @@ def stream_chat_message_objects(
             new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
         )

-        # Cannot determine these without the LLM step or breaking out early
-        partial_response = partial(
-            create_new_chat_message,
-            chat_session_id=chat_session_id,
-            # if we're using an existing assistant message, then this will just be an
-            # update operation, in which case the parent should be the parent of
-            # the latest. If we're creating a new assistant message, then the parent
-            # should be the latest message (latest user message)
-            parent_message=(
-                final_msg if existing_assistant_message_id is None else parent_message
-            ),
-            prompt_id=prompt_id,
-            overridden_model=overridden_model,
-            # message=,
-            # rephrased_query=,
-            # token_count=,
-            message_type=MessageType.ASSISTANT,
-            alternate_assistant_id=new_msg_req.alternate_assistant_id,
-            # error=,
-            # reference_docs=,
-            db_session=db_session,
-            commit=False,
-            reserved_message_id=reserved_message_id,
-            is_agentic=new_msg_req.use_agentic_search,
-        )
+        def create_response(
+            message: str,
+            rephrased_query: str | None,
+            reference_docs: list[DbSearchDoc] | None,
+            files: list[FileDescriptor],
+            token_count: int,
+            citations: dict[int, int] | None,
+            error: str | None,
+            tool_call: ToolCall | None,
+        ) -> ChatMessage:
+            return create_new_chat_message(
+                chat_session_id=chat_session_id,
+                parent_message=(
+                    final_msg
+                    if existing_assistant_message_id is None
+                    else parent_message
+                ),
+                prompt_id=prompt_id,
+                overridden_model=overridden_model,
+                message=message,
+                rephrased_query=rephrased_query,
+                token_count=token_count,
+                message_type=MessageType.ASSISTANT,
+                alternate_assistant_id=new_msg_req.alternate_assistant_id,
+                error=error,
+                reference_docs=reference_docs,
+                files=files,
+                citations=citations,
+                tool_call=tool_call,
+                db_session=db_session,
+                commit=False,
+                reserved_message_id=reserved_message_id,
+                is_agentic=new_msg_req.use_agentic_search,
+            )
+
+        partial_response = create_response

         prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
         if new_msg_req.persona_override_config:
@@ -1041,220 +1281,23 @@ def stream_chat_message_objects(
             use_agentic_search=new_msg_req.use_agentic_search,
         )

-        # reference_db_search_docs = None
-        # qa_docs_response = None
-        # # any files to associate with the AI message e.g. dall-e generated images
-        # ai_message_files = []
-        # dropped_indices = None
-        # tool_result = None
-
-        # TODO: different channels for stored info when it's coming from the agent flow
         info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
             lambda: AnswerPostInfo(ai_message_files=[])
         )
         refined_answer_improvement = True
         for packet in answer.processed_streamed_output:
             if isinstance(packet, ToolResponse):
-                level, level_question_num = (
-                    (packet.level, packet.level_question_num)
-                    if isinstance(packet, ExtendedToolResponse)
-                    else BASIC_KEY
+                info_by_subq = yield from _process_tool_response(
+                    packet=packet,
+                    db_session=db_session,
+                    selected_db_search_docs=selected_db_search_docs,
+                    info_by_subq=info_by_subq,
+                    retrieval_options=retrieval_options,
+                    user_file_files=user_file_files,
+                    user_files=user_files,
+                    file_id_to_user_file=file_id_to_user_file,
+                    search_for_ordering_only=search_for_ordering_only,
                 )
-                assert level is not None
-                assert level_question_num is not None
-                info = info_by_subq[
-                    SubQuestionKey(level=level, question_num=level_question_num)
-                ]
-
-                # Skip LLM relevance processing entirely for ordering-only mode
-                if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
-                    logger.info(
-                        "Fast path: Completely bypassing section relevance processing for ordering-only mode"
-                    )
-                    # Skip this packet entirely since it would trigger LLM processing
-                    continue
-
-                # TODO: don't need to dedupe here when we do it in agent flow
-                if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
-                    if search_for_ordering_only:
-                        logger.info(
-                            "Fast path: Skipping document deduplication for ordering-only mode"
-                        )
-                    (
-                        info.qa_docs_response,
-                        info.reference_db_search_docs,
-                        info.dropped_indices,
-                    ) = _handle_search_tool_response_summary(
-                        packet=packet,
-                        db_session=db_session,
-                        selected_search_docs=selected_db_search_docs,
-                        # Deduping happens at the last step to avoid harming quality by dropping content early on
-                        # Skip deduping completely for ordering-only mode to save time
-                        dedupe_docs=bool(
-                            not search_for_ordering_only
-                            and retrieval_options
-                            and retrieval_options.dedupe_docs
-                        ),
-                        user_files=user_file_files if search_for_ordering_only else [],
-                        loaded_user_files=(
-                            user_files if search_for_ordering_only else []
-                        ),
-                    )
-
-                    # If we're using search just for ordering user files
-                    if (
-                        search_for_ordering_only
-                        and user_files
-                        and info.qa_docs_response
-                    ):
-                        logger.info(
-                            f"ORDERING: Processing search results for ordering {len(user_files)} user files"
-                        )
-
-                        # Extract document order from search results
-                        doc_order = []
-                        for doc in info.qa_docs_response.top_documents:
-                            doc_id = doc.document_id
-                            if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
-                                file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
-                                if file_id in file_id_to_user_file:
-                                    doc_order.append(file_id)
-
-                        logger.info(
-                            f"ORDERING: Found {len(doc_order)} files from search results"
-                        )
-
-                        # Add any files that weren't in search results at the end
-                        missing_files = [
-                            f_id
-                            for f_id in file_id_to_user_file.keys()
-                            if f_id not in doc_order
-                        ]
-                        missing_files.extend(doc_order)
-                        doc_order = missing_files
-                        logger.info(
-                            f"ORDERING: Added {len(missing_files)} missing files to the end"
-                        )
-
-                        # Reorder user files based on search results
-                        ordered_user_files = [
-                            file_id_to_user_file[f_id]
-                            for f_id in doc_order
-                            if f_id in file_id_to_user_file
-                        ]
-
-                        yield UserKnowledgeFilePacket(
-                            user_files=[
-                                FileDescriptor(
-                                    id=str(file.file_id),
-                                    type=ChatFileType.USER_KNOWLEDGE,
-                                )
-                                for file in ordered_user_files
-                            ]
-                        )
-
-                    yield info.qa_docs_response
-                elif packet.id == SECTION_RELEVANCE_LIST_ID:
-                    relevance_sections = packet.response
-
-                    if search_for_ordering_only:
-                        logger.info(
-                            "Performance: Skipping relevance filtering for ordering-only mode"
-                        )
-                        continue
-
-                    if info.reference_db_search_docs is None:
-                        logger.warning(
-                            "No reference docs found for relevance filtering"
-                        )
-                        continue
-
-                    llm_indices = relevant_sections_to_indices(
-                        relevance_sections=relevance_sections,
-                        items=[
-                            translate_db_search_doc_to_server_search_doc(doc)
-                            for doc in info.reference_db_search_docs
-                        ],
-                    )
-
-                    if info.dropped_indices:
-                        llm_indices = drop_llm_indices(
-                            llm_indices=llm_indices,
-                            search_docs=info.reference_db_search_docs,
-                            dropped_indices=info.dropped_indices,
-                        )
-
-                    yield LLMRelevanceFilterResponse(
-                        llm_selected_doc_indices=llm_indices
-                    )
-                elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
-                    yield FinalUsedContextDocsResponse(
-                        final_context_docs=packet.response
-                    )
-                elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
-                    img_generation_response = cast(
-                        list[ImageGenerationResponse], packet.response
-                    )
-
-                    file_ids = save_files(
-                        urls=[img.url for img in img_generation_response if img.url],
-                        base64_files=[
-                            img.image_data
-                            for img in img_generation_response
-                            if img.image_data
-                        ],
-                    )
-                    info.ai_message_files.extend(
-                        [
-                            FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
-                            for file_id in file_ids
-                        ]
-                    )
-                    yield FileChatDisplay(
-                        file_ids=[str(file_id) for file_id in file_ids]
-                    )
-                elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
-                    (
-                        info.qa_docs_response,
-                        info.reference_db_search_docs,
-                    ) = _handle_internet_search_tool_response_summary(
-                        packet=packet,
-                        db_session=db_session,
-                    )
-                    yield info.qa_docs_response
-                elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
-                    custom_tool_response = cast(CustomToolCallSummary, packet.response)
-
-                    if (
-                        custom_tool_response.response_type == "image"
-                        or custom_tool_response.response_type == "csv"
-                    ):
-                        file_ids = custom_tool_response.tool_result.file_ids
-                        info.ai_message_files.extend(
-                            [
-                                FileDescriptor(
-                                    id=str(file_id),
-                                    type=(
-                                        ChatFileType.IMAGE
-                                        if custom_tool_response.response_type == "image"
-                                        else ChatFileType.CSV
-                                    ),
-                                )
-                                for file_id in file_ids
-                            ]
-                        )
-                        yield FileChatDisplay(
-                            file_ids=[str(file_id) for file_id in file_ids]
-                        )
-                    else:
-                        yield CustomToolResponse(
-                            response=custom_tool_response.tool_result,
-                            tool_name=custom_tool_response.tool_name,
-                        )
             elif isinstance(packet, StreamStopInfo):
                 if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1291,22 +1334,46 @@ def stream_chat_message_objects(
         if isinstance(e, ToolCallException):
             yield StreamingError(error=error_msg, stack_trace=stack_trace)
-        else:
-            if llm:
-                client_error_msg = litellm_exception_to_error_msg(e, llm)
-                if llm.config.api_key and len(llm.config.api_key) > 2:
-                    error_msg = error_msg.replace(
-                        llm.config.api_key, "[REDACTED_API_KEY]"
-                    )
-                    stack_trace = stack_trace.replace(
-                        llm.config.api_key, "[REDACTED_API_KEY]"
-                    )
+        elif llm:
+            client_error_msg = litellm_exception_to_error_msg(e, llm)
+            if llm.config.api_key and len(llm.config.api_key) > 2:
+                client_error_msg = client_error_msg.replace(
+                    llm.config.api_key, "[REDACTED_API_KEY]"
+                )
+                stack_trace = stack_trace.replace(
+                    llm.config.api_key, "[REDACTED_API_KEY]"
+                )

             yield StreamingError(error=client_error_msg, stack_trace=stack_trace)

         db_session.rollback()
         return

+    yield from _post_llm_answer_processing(
+        answer=answer,
+        info_by_subq=info_by_subq,
+        tool_dict=tool_dict,
+        partial_response=partial_response,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
+        db_session=db_session,
+        chat_session_id=chat_session_id,
+        refined_answer_improvement=refined_answer_improvement,
+    )
+
+
+def _post_llm_answer_processing(
+    answer: Answer,
+    info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
+    tool_dict: dict[int, list[Tool]],
+    partial_response: PartialResponse,
+    llm_tokenizer_encode_func: Callable[[str], list[int]],
+    db_session: Session,
+    chat_session_id: UUID,
+    refined_answer_improvement: bool | None,
+) -> Generator[ChatPacket, None, None]:
+    """
+    Stores messages in the db and yields some final packets to the frontend
+    """
     # Post-LLM answer processing
     try:
         tool_name_to_tool_id: dict[str, int] = {}
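
The structural change in this file is that the old functools.partial over create_new_chat_message, whose still-missing arguments were only hinted at in comments, becomes a nested create_response function checked against the new PartialResponse Protocol, and the giant per-packet if/elif block moves into _process_tool_response, a generator whose return value the caller recovers with "info_by_subq = yield from _process_tool_response(...)" while still forwarding every yielded packet. A stripped-down sketch of the callback-Protocol half of that pattern, with invented placeholder types standing in for the real chat models:

from typing import Protocol


class SaveMessage(Protocol):
    # Illustrative stand-in for PartialResponse: callers must supply exactly
    # these keyword arguments, and the type checker verifies each call site.
    def __call__(self, message: str, token_count: int, error: str | None) -> int: ...


def make_saver(session_id: int) -> SaveMessage:
    def _save(message: str, token_count: int, error: str | None) -> int:
        # Pretend persistence; the closure captures session-scoped state the way
        # create_response captures chat_session_id and db_session in the hunk above.
        print(f"[{session_id}] {message} ({token_count} tokens, error={error})")
        return len(message)

    return _save


saver: SaveMessage = make_saver(42)
saver(message="hello", token_count=1, error=None)

Because the Protocol spells out every parameter, a forgotten or misnamed keyword becomes a type error instead of a runtime TypeError, which the commented-out placeholders in the removed partial() call could never provide.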

View File

@@ -56,7 +56,8 @@ def get_total_users_count(db_session: Session) -> int:
 async def get_user_count(only_admin_users: bool = False) -> int:
     async with get_async_session_with_tenant() as session:
-        stmt = select(func.count(User.id))
+        count_stmt = func.count(User.id)  # type: ignore
+        stmt = select(count_stmt)
         if only_admin_users:
             stmt = stmt.where(User.role == UserRole.ADMIN)
         result = await session.execute(stmt)
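
Building the func.count(User.id) element in its own variable keeps the # type: ignore on a single line while select() and the rest of the statement stay fully typed. A rough sketch of the resulting shape, with a plain AsyncSession standing in for this repo's get_async_session_with_tenant helper:

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.db.models import User  # the same model queried above


async def count_users(session: AsyncSession) -> int:
    count_stmt = func.count(User.id)  # type: ignore  # count() element, isolated for the checker
    stmt = select(count_stmt)
    result = await session.execute(stmt)
    return result.scalar() or 0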