From 68c6c1f4f8d53f0fdd96b6803117017cd92f52f2 Mon Sep 17 00:00:00 2001 From: evan-danswer Date: Mon, 14 Apr 2025 10:23:07 -0700 Subject: [PATCH] refactor to use stricter typing (#4513) * refactor to use stricter typing * older version of ruff --- backend/ee/onyx/db/analytics.py | 4 +- backend/model_server/custom_models.py | 24 +- backend/model_server/onyx_torch_model.py | 24 +- backend/onyx/chat/process_message.py | 639 +++++++++++++---------- backend/onyx/db/auth.py | 3 +- 5 files changed, 386 insertions(+), 308 deletions(-) diff --git a/backend/ee/onyx/db/analytics.py b/backend/ee/onyx/db/analytics.py index b9ae0005d..cb54f232d 100644 --- a/backend/ee/onyx/db/analytics.py +++ b/backend/ee/onyx/db/analytics.py @@ -140,7 +140,7 @@ def fetch_onyxbot_analytics( ( or_( ChatMessageFeedback.is_positive.is_(False), - ChatMessageFeedback.required_followup, + ChatMessageFeedback.required_followup.is_(True), ), 1, ), @@ -173,7 +173,7 @@ def fetch_onyxbot_analytics( .all() ) - return results + return [tuple(row) for row in results] def fetch_persona_message_analytics( diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py index ea15721ae..86c3e9380 100644 --- a/backend/model_server/custom_models.py +++ b/backend/model_server/custom_models.py @@ -1,3 +1,5 @@ +from typing import cast + import numpy as np import torch import torch.nn.functional as F @@ -39,10 +41,10 @@ logger = setup_logger() router = APIRouter(prefix="/custom") -_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None +_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None -_INTENT_TOKENIZER: AutoTokenizer | None = None +_INTENT_TOKENIZER: PreTrainedTokenizer | None = None _INTENT_MODEL: HybridClassifier | None = None _INFORMATION_CONTENT_MODEL: SetFitModel | None = None @@ -50,13 +52,14 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version! -def get_connector_classifier_tokenizer() -> AutoTokenizer: +def get_connector_classifier_tokenizer() -> PreTrainedTokenizer: global _CONNECTOR_CLASSIFIER_TOKENIZER if _CONNECTOR_CLASSIFIER_TOKENIZER is None: # The tokenizer details are not uploaded to the HF hub since it's just the # unmodified distilbert tokenizer. - _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained( - "distilbert-base-uncased" + _CONNECTOR_CLASSIFIER_TOKENIZER = cast( + PreTrainedTokenizer, + AutoTokenizer.from_pretrained("distilbert-base-uncased"), ) return _CONNECTOR_CLASSIFIER_TOKENIZER @@ -92,12 +95,15 @@ def get_local_connector_classifier( return _CONNECTOR_CLASSIFIER_MODEL -def get_intent_model_tokenizer() -> AutoTokenizer: +def get_intent_model_tokenizer() -> PreTrainedTokenizer: global _INTENT_TOKENIZER if _INTENT_TOKENIZER is None: # The tokenizer details are not uploaded to the HF hub since it's just the # unmodified distilbert tokenizer. 
- _INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased") + _INTENT_TOKENIZER = cast( + PreTrainedTokenizer, + AutoTokenizer.from_pretrained("distilbert-base-uncased"), + ) return _INTENT_TOKENIZER @@ -395,9 +401,9 @@ def run_content_classification_inference( def map_keywords( - input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool] + input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool] ) -> list[str]: - tokens = tokenizer.convert_ids_to_tokens(input_ids) + tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore if not len(tokens) == len(is_keyword): raise ValueError("Length of tokens and keyword predictions must match") diff --git a/backend/model_server/onyx_torch_model.py b/backend/model_server/onyx_torch_model.py index 1bb5544ac..6429537af 100644 --- a/backend/model_server/onyx_torch_model.py +++ b/backend/model_server/onyx_torch_model.py @@ -1,5 +1,6 @@ import json import os +from typing import cast import torch import torch.nn as nn @@ -13,15 +14,14 @@ class HybridClassifier(nn.Module): super().__init__() config = DistilBertConfig() self.distilbert = DistilBertModel(config) + config = self.distilbert.config # type: ignore # Keyword tokenwise binary classification layer - self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2) + self.keyword_classifier = nn.Linear(config.dim, 2) # Intent Classifier layers - self.pre_classifier = nn.Linear( - self.distilbert.config.dim, self.distilbert.config.dim - ) - self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.intent_classifier = nn.Linear(config.dim, 2) self.device = torch.device("cpu") @@ -30,7 +30,7 @@ class HybridClassifier(nn.Module): query_ids: torch.Tensor, query_mask: torch.Tensor, ) -> dict[str, torch.Tensor]: - outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) + outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore sequence_output = outputs.last_hidden_state # Intent classification on the CLS token @@ -79,8 +79,9 @@ class ConnectorClassifier(nn.Module): self.config = config self.distilbert = DistilBertModel(config) - self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1) - self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1) + config = self.distilbert.config # type: ignore + self.connector_global_classifier = nn.Linear(config.dim, 1) + self.connector_match_classifier = nn.Linear(config.dim, 1) self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") # Token indicating end of connector name, and on which classifier is used @@ -95,7 +96,7 @@ class ConnectorClassifier(nn.Module): input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - hidden_states = self.distilbert( + hidden_states = self.distilbert( # type: ignore input_ids=input_ids, attention_mask=attention_mask ).last_hidden_state @@ -114,7 +115,10 @@ class ConnectorClassifier(nn.Module): @classmethod def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier": - config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")) + config = cast( + DistilBertConfig, + DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")), + ) device = ( torch.device("cuda") if torch.cuda.is_available() diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index bfce443b8..19eb3e2b0 100644 --- 
a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -2,9 +2,11 @@ import time import traceback from collections import defaultdict from collections.abc import Callable +from collections.abc import Generator from collections.abc import Iterator -from functools import partial from typing import cast +from typing import Protocol +from uuid import UUID from sqlalchemy.orm import Session @@ -82,6 +84,8 @@ from onyx.db.engine import get_session_context_manager from onyx.db.milestone import check_multi_assistant_milestone from onyx.db.milestone import create_milestone_if_not_exists from onyx.db.milestone import update_user_assistant_milestone +from onyx.db.models import ChatMessage +from onyx.db.models import Persona from onyx.db.models import SearchDoc as DbSearchDoc from onyx.db.models import ToolCall from onyx.db.models import User @@ -159,6 +163,25 @@ from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() ERROR_TYPE_CANCELLED = "cancelled" +COMMON_TOOL_RESPONSE_TYPES = { + "image": ChatFileType.IMAGE, + "csv": ChatFileType.CSV, +} + + +class PartialResponse(Protocol): + def __call__( + self, + message: str, + rephrased_query: str | None, + reference_docs: list[DbSearchDoc] | None, + files: list[FileDescriptor], + token_count: int, + citations: dict[int, int] | None, + error: str | None, + tool_call: ToolCall | None, + ) -> ChatMessage: ... + def _translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] @@ -211,25 +234,25 @@ def _handle_search_tool_response_summary( reference_db_search_docs = selected_search_docs doc_ids = {doc.id for doc in reference_db_search_docs} - if user_files is not None: + if user_files is not None and loaded_user_files is not None: for user_file in user_files: - if user_file.id not in doc_ids: - associated_chat_file = None - if loaded_user_files is not None: - associated_chat_file = next( - ( - file - for file in loaded_user_files - if file.file_id == str(user_file.file_id) - ), - None, - ) - # Use create_search_doc_from_user_file to properly add the document to the database - if associated_chat_file is not None: - db_doc = create_search_doc_from_user_file( - user_file, associated_chat_file, db_session - ) - reference_db_search_docs.append(db_doc) + if user_file.id in doc_ids: + continue + + associated_chat_file = next( + ( + file + for file in loaded_user_files + if file.file_id == str(user_file.file_id) + ), + None, + ) + # Use create_search_doc_from_user_file to properly add the document to the database + if associated_chat_file is not None: + db_doc = create_search_doc_from_user_file( + user_file, associated_chat_file, db_session + ) + reference_db_search_docs.append(db_doc) response_docs = [ translate_db_search_doc_to_server_search_doc(db_search_doc) @@ -357,6 +380,86 @@ def _get_force_search_settings( ) +def _get_user_knowledge_files( + info: AnswerPostInfo, + user_files: list[InMemoryChatFile], + file_id_to_user_file: dict[str, InMemoryChatFile], +) -> Generator[UserKnowledgeFilePacket, None, None]: + if not info.qa_docs_response: + return + + logger.info( + f"ORDERING: Processing search results for ordering {len(user_files)} user files" + ) + + # Extract document order from search results + doc_order = [] + for doc in info.qa_docs_response.top_documents: + doc_id = doc.document_id + if str(doc_id).startswith("USER_FILE_CONNECTOR__"): + file_id = doc_id.replace("USER_FILE_CONNECTOR__", "") + if file_id in file_id_to_user_file: + doc_order.append(file_id) + + 
logger.info(f"ORDERING: Found {len(doc_order)} files from search results") + + # Add any files that weren't in search results at the end + missing_files = [ + f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order + ] + + missing_files.extend(doc_order) + doc_order = missing_files + + logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end") + + # Reorder user files based on search results + ordered_user_files = [ + file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file + ] + + yield UserKnowledgeFilePacket( + user_files=[ + FileDescriptor( + id=str(file.file_id), + type=ChatFileType.USER_KNOWLEDGE, + ) + for file in ordered_user_files + ] + ) + + +def _get_persona_for_chat_session( + new_msg_req: CreateChatMessageRequest, + user: User | None, + db_session: Session, + default_persona: Persona, +) -> Persona: + if new_msg_req.alternate_assistant_id is not None: + # Allows users to specify a temporary persona (assistant) in the chat session + # this takes highest priority since it's user specified + persona = get_persona_by_id( + new_msg_req.alternate_assistant_id, + user=user, + db_session=db_session, + is_for_edit=False, + ) + elif new_msg_req.persona_override_config: + # Certain endpoints allow users to specify arbitrary persona settings + # this should never conflict with the alternate_assistant_id + persona = create_temporary_persona( + db_session=db_session, + persona_config=new_msg_req.persona_override_config, + user=user, + ) + else: + persona = default_persona + + if not persona: + raise RuntimeError("No persona specified or found for chat session") + return persona + + ChatPacket = ( StreamingError | QADocsResponse @@ -378,6 +481,149 @@ ChatPacket = ( ChatPacketStream = Iterator[ChatPacket] +def _process_tool_response( + packet: ToolResponse, + db_session: Session, + selected_db_search_docs: list[DbSearchDoc] | None, + info_by_subq: dict[SubQuestionKey, AnswerPostInfo], + retrieval_options: RetrievalDetails | None, + user_file_files: list[UserFile] | None, + user_files: list[InMemoryChatFile] | None, + file_id_to_user_file: dict[str, InMemoryChatFile], + search_for_ordering_only: bool, +) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]: + level, level_question_num = ( + (packet.level, packet.level_question_num) + if isinstance(packet, ExtendedToolResponse) + else BASIC_KEY + ) + + assert level is not None + assert level_question_num is not None + info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)] + + # Skip LLM relevance processing entirely for ordering-only mode + if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID: + logger.info( + "Fast path: Completely bypassing section relevance processing for ordering-only mode" + ) + # Skip this packet entirely since it would trigger LLM processing + return info_by_subq + + # TODO: don't need to dedupe here when we do it in agent flow + if packet.id == SEARCH_RESPONSE_SUMMARY_ID: + if search_for_ordering_only: + logger.info( + "Fast path: Skipping document deduplication for ordering-only mode" + ) + + ( + info.qa_docs_response, + info.reference_db_search_docs, + info.dropped_indices, + ) = _handle_search_tool_response_summary( + packet=packet, + db_session=db_session, + selected_search_docs=selected_db_search_docs, + # Deduping happens at the last step to avoid harming quality by dropping content early on + # Skip deduping completely for ordering-only mode to save time + dedupe_docs=bool( + not 
search_for_ordering_only + and retrieval_options + and retrieval_options.dedupe_docs + ), + user_files=user_file_files if search_for_ordering_only else [], + loaded_user_files=(user_files if search_for_ordering_only else []), + ) + + # If we're using search just for ordering user files + if search_for_ordering_only and user_files: + yield from _get_user_knowledge_files( + info=info, + user_files=user_files, + file_id_to_user_file=file_id_to_user_file, + ) + + yield info.qa_docs_response + elif packet.id == SECTION_RELEVANCE_LIST_ID: + relevance_sections = packet.response + + if search_for_ordering_only: + logger.info( + "Performance: Skipping relevance filtering for ordering-only mode" + ) + return info_by_subq + + if info.reference_db_search_docs is None: + logger.warning("No reference docs found for relevance filtering") + return info_by_subq + + llm_indices = relevant_sections_to_indices( + relevance_sections=relevance_sections, + items=[ + translate_db_search_doc_to_server_search_doc(doc) + for doc in info.reference_db_search_docs + ], + ) + + if info.dropped_indices: + llm_indices = drop_llm_indices( + llm_indices=llm_indices, + search_docs=info.reference_db_search_docs, + dropped_indices=info.dropped_indices, + ) + + yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices) + elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: + yield FinalUsedContextDocsResponse(final_context_docs=packet.response) + + elif packet.id == IMAGE_GENERATION_RESPONSE_ID: + img_generation_response = cast(list[ImageGenerationResponse], packet.response) + + file_ids = save_files( + urls=[img.url for img in img_generation_response if img.url], + base64_files=[ + img.image_data for img in img_generation_response if img.image_data + ], + ) + info.ai_message_files.extend( + [ + FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) + for file_id in file_ids + ] + ) + yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids]) + elif packet.id == INTERNET_SEARCH_RESPONSE_ID: + ( + info.qa_docs_response, + info.reference_db_search_docs, + ) = _handle_internet_search_tool_response_summary( + packet=packet, + db_session=db_session, + ) + yield info.qa_docs_response + elif packet.id == CUSTOM_TOOL_RESPONSE_ID: + custom_tool_response = cast(CustomToolCallSummary, packet.response) + response_type = custom_tool_response.response_type + if response_type in COMMON_TOOL_RESPONSE_TYPES: + file_ids = custom_tool_response.tool_result.file_ids + file_type = COMMON_TOOL_RESPONSE_TYPES[response_type] + info.ai_message_files.extend( + [ + FileDescriptor(id=str(file_id), type=file_type) + for file_id in file_ids + ] + ) + yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids]) + else: + yield CustomToolResponse( + response=custom_tool_response.tool_result, + tool_name=custom_tool_response.tool_name, + ) + + return info_by_subq + + def stream_chat_message_objects( new_msg_req: CreateChatMessageRequest, user: User | None, @@ -421,7 +667,6 @@ def stream_chat_message_objects( try: # Move these variables inside the try block file_id_to_user_file = {} - ordered_user_files = None user_id = user.id if user is not None else None @@ -436,35 +681,19 @@ def stream_chat_message_objects( parent_id = new_msg_req.parent_message_id reference_doc_ids = new_msg_req.search_doc_ids retrieval_options = new_msg_req.retrieval_options - alternate_assistant_id = new_msg_req.alternate_assistant_id + new_msg_req.alternate_assistant_id # permanent "log" store, used primarily for debugging long_term_logger = LongTermLogger( 
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)} ) - if alternate_assistant_id is not None: - # Allows users to specify a temporary persona (assistant) in the chat session - # this takes highest priority since it's user specified - persona = get_persona_by_id( - alternate_assistant_id, - user=user, - db_session=db_session, - is_for_edit=False, - ) - elif new_msg_req.persona_override_config: - # Certain endpoints allow users to specify arbitrary persona settings - # this should never conflict with the alternate_assistant_id - persona = persona = create_temporary_persona( - db_session=db_session, - persona_config=new_msg_req.persona_override_config, - user=user, - ) - else: - persona = chat_session.persona - - if not persona: - raise RuntimeError("No persona specified or found for chat session") + persona = _get_persona_for_chat_session( + new_msg_req=new_msg_req, + user=user, + db_session=db_session, + default_persona=chat_session.persona, + ) multi_assistant_milestone, _is_new = create_milestone_if_not_exists( user=user, @@ -746,31 +975,42 @@ def stream_chat_message_objects( new_msg_req.llm_override.model_version if new_msg_req.llm_override else None ) - # Cannot determine these without the LLM step or breaking out early - partial_response = partial( - create_new_chat_message, - chat_session_id=chat_session_id, - # if we're using an existing assistant message, then this will just be an - # update operation, in which case the parent should be the parent of - # the latest. If we're creating a new assistant message, then the parent - # should be the latest message (latest user message) - parent_message=( - final_msg if existing_assistant_message_id is None else parent_message - ), - prompt_id=prompt_id, - overridden_model=overridden_model, - # message=, - # rephrased_query=, - # token_count=, - message_type=MessageType.ASSISTANT, - alternate_assistant_id=new_msg_req.alternate_assistant_id, - # error=, - # reference_docs=, - db_session=db_session, - commit=False, - reserved_message_id=reserved_message_id, - is_agentic=new_msg_req.use_agentic_search, - ) + def create_response( + message: str, + rephrased_query: str | None, + reference_docs: list[DbSearchDoc] | None, + files: list[FileDescriptor], + token_count: int, + citations: dict[int, int] | None, + error: str | None, + tool_call: ToolCall | None, + ) -> ChatMessage: + return create_new_chat_message( + chat_session_id=chat_session_id, + parent_message=( + final_msg + if existing_assistant_message_id is None + else parent_message + ), + prompt_id=prompt_id, + overridden_model=overridden_model, + message=message, + rephrased_query=rephrased_query, + token_count=token_count, + message_type=MessageType.ASSISTANT, + alternate_assistant_id=new_msg_req.alternate_assistant_id, + error=error, + reference_docs=reference_docs, + files=files, + citations=citations, + tool_call=tool_call, + db_session=db_session, + commit=False, + reserved_message_id=reserved_message_id, + is_agentic=new_msg_req.use_agentic_search, + ) + + partial_response = create_response prompt_override = new_msg_req.prompt_override or chat_session.prompt_override if new_msg_req.persona_override_config: @@ -1041,220 +1281,23 @@ def stream_chat_message_objects( use_agentic_search=new_msg_req.use_agentic_search, ) - # reference_db_search_docs = None - # qa_docs_response = None - # # any files to associate with the AI message e.g. 
dall-e generated images - # ai_message_files = [] - # dropped_indices = None - # tool_result = None - - # TODO: different channels for stored info when it's coming from the agent flow info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict( lambda: AnswerPostInfo(ai_message_files=[]) ) refined_answer_improvement = True for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): - level, level_question_num = ( - (packet.level, packet.level_question_num) - if isinstance(packet, ExtendedToolResponse) - else BASIC_KEY + info_by_subq = yield from _process_tool_response( + packet=packet, + db_session=db_session, + selected_db_search_docs=selected_db_search_docs, + info_by_subq=info_by_subq, + retrieval_options=retrieval_options, + user_file_files=user_file_files, + user_files=user_files, + file_id_to_user_file=file_id_to_user_file, + search_for_ordering_only=search_for_ordering_only, ) - assert level is not None - assert level_question_num is not None - info = info_by_subq[ - SubQuestionKey(level=level, question_num=level_question_num) - ] - - # Skip LLM relevance processing entirely for ordering-only mode - if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID: - logger.info( - "Fast path: Completely bypassing section relevance processing for ordering-only mode" - ) - # Skip this packet entirely since it would trigger LLM processing - continue - - # TODO: don't need to dedupe here when we do it in agent flow - if packet.id == SEARCH_RESPONSE_SUMMARY_ID: - if search_for_ordering_only: - logger.info( - "Fast path: Skipping document deduplication for ordering-only mode" - ) - - ( - info.qa_docs_response, - info.reference_db_search_docs, - info.dropped_indices, - ) = _handle_search_tool_response_summary( - packet=packet, - db_session=db_session, - selected_search_docs=selected_db_search_docs, - # Deduping happens at the last step to avoid harming quality by dropping content early on - # Skip deduping completely for ordering-only mode to save time - dedupe_docs=bool( - not search_for_ordering_only - and retrieval_options - and retrieval_options.dedupe_docs - ), - user_files=user_file_files if search_for_ordering_only else [], - loaded_user_files=( - user_files if search_for_ordering_only else [] - ), - ) - - # If we're using search just for ordering user files - if ( - search_for_ordering_only - and user_files - and info.qa_docs_response - ): - logger.info( - f"ORDERING: Processing search results for ordering {len(user_files)} user files" - ) - - # Extract document order from search results - doc_order = [] - for doc in info.qa_docs_response.top_documents: - doc_id = doc.document_id - if str(doc_id).startswith("USER_FILE_CONNECTOR__"): - file_id = doc_id.replace("USER_FILE_CONNECTOR__", "") - if file_id in file_id_to_user_file: - doc_order.append(file_id) - - logger.info( - f"ORDERING: Found {len(doc_order)} files from search results" - ) - - # Add any files that weren't in search results at the end - missing_files = [ - f_id - for f_id in file_id_to_user_file.keys() - if f_id not in doc_order - ] - - missing_files.extend(doc_order) - doc_order = missing_files - - logger.info( - f"ORDERING: Added {len(missing_files)} missing files to the end" - ) - - # Reorder user files based on search results - ordered_user_files = [ - file_id_to_user_file[f_id] - for f_id in doc_order - if f_id in file_id_to_user_file - ] - - yield UserKnowledgeFilePacket( - user_files=[ - FileDescriptor( - id=str(file.file_id), - type=ChatFileType.USER_KNOWLEDGE, - ) - for file 
in ordered_user_files - ] - ) - - yield info.qa_docs_response - elif packet.id == SECTION_RELEVANCE_LIST_ID: - relevance_sections = packet.response - - if search_for_ordering_only: - logger.info( - "Performance: Skipping relevance filtering for ordering-only mode" - ) - continue - - if info.reference_db_search_docs is None: - logger.warning( - "No reference docs found for relevance filtering" - ) - continue - - llm_indices = relevant_sections_to_indices( - relevance_sections=relevance_sections, - items=[ - translate_db_search_doc_to_server_search_doc(doc) - for doc in info.reference_db_search_docs - ], - ) - - if info.dropped_indices: - llm_indices = drop_llm_indices( - llm_indices=llm_indices, - search_docs=info.reference_db_search_docs, - dropped_indices=info.dropped_indices, - ) - - yield LLMRelevanceFilterResponse( - llm_selected_doc_indices=llm_indices - ) - elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: - yield FinalUsedContextDocsResponse( - final_context_docs=packet.response - ) - - elif packet.id == IMAGE_GENERATION_RESPONSE_ID: - img_generation_response = cast( - list[ImageGenerationResponse], packet.response - ) - - file_ids = save_files( - urls=[img.url for img in img_generation_response if img.url], - base64_files=[ - img.image_data - for img in img_generation_response - if img.image_data - ], - ) - info.ai_message_files.extend( - [ - FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) - for file_id in file_ids - ] - ) - yield FileChatDisplay( - file_ids=[str(file_id) for file_id in file_ids] - ) - elif packet.id == INTERNET_SEARCH_RESPONSE_ID: - ( - info.qa_docs_response, - info.reference_db_search_docs, - ) = _handle_internet_search_tool_response_summary( - packet=packet, - db_session=db_session, - ) - yield info.qa_docs_response - elif packet.id == CUSTOM_TOOL_RESPONSE_ID: - custom_tool_response = cast(CustomToolCallSummary, packet.response) - - if ( - custom_tool_response.response_type == "image" - or custom_tool_response.response_type == "csv" - ): - file_ids = custom_tool_response.tool_result.file_ids - info.ai_message_files.extend( - [ - FileDescriptor( - id=str(file_id), - type=( - ChatFileType.IMAGE - if custom_tool_response.response_type == "image" - else ChatFileType.CSV - ), - ) - for file_id in file_ids - ] - ) - yield FileChatDisplay( - file_ids=[str(file_id) for file_id in file_ids] - ) - else: - yield CustomToolResponse( - response=custom_tool_response.tool_result, - tool_name=custom_tool_response.tool_name, - ) elif isinstance(packet, StreamStopInfo): if packet.stop_reason == StreamStopReason.FINISHED: @@ -1291,22 +1334,46 @@ def stream_chat_message_objects( if isinstance(e, ToolCallException): yield StreamingError(error=error_msg, stack_trace=stack_trace) - else: - if llm: - client_error_msg = litellm_exception_to_error_msg(e, llm) - if llm.config.api_key and len(llm.config.api_key) > 2: - error_msg = error_msg.replace( - llm.config.api_key, "[REDACTED_API_KEY]" - ) - stack_trace = stack_trace.replace( - llm.config.api_key, "[REDACTED_API_KEY]" - ) + elif llm: + client_error_msg = litellm_exception_to_error_msg(e, llm) + if llm.config.api_key and len(llm.config.api_key) > 2: + client_error_msg = client_error_msg.replace( + llm.config.api_key, "[REDACTED_API_KEY]" + ) + stack_trace = stack_trace.replace( + llm.config.api_key, "[REDACTED_API_KEY]" + ) - yield StreamingError(error=client_error_msg, stack_trace=stack_trace) + yield StreamingError(error=client_error_msg, stack_trace=stack_trace) db_session.rollback() return + yield from 
_post_llm_answer_processing( + answer=answer, + info_by_subq=info_by_subq, + tool_dict=tool_dict, + partial_response=partial_response, + llm_tokenizer_encode_func=llm_tokenizer_encode_func, + db_session=db_session, + chat_session_id=chat_session_id, + refined_answer_improvement=refined_answer_improvement, + ) + + +def _post_llm_answer_processing( + answer: Answer, + info_by_subq: dict[SubQuestionKey, AnswerPostInfo], + tool_dict: dict[int, list[Tool]], + partial_response: PartialResponse, + llm_tokenizer_encode_func: Callable[[str], list[int]], + db_session: Session, + chat_session_id: UUID, + refined_answer_improvement: bool | None, +) -> Generator[ChatPacket, None, None]: + """ + Stores messages in the db and yields some final packets to the frontend + """ # Post-LLM answer processing try: tool_name_to_tool_id: dict[str, int] = {} diff --git a/backend/onyx/db/auth.py b/backend/onyx/db/auth.py index cbc1dcc0f..e890c2ec7 100644 --- a/backend/onyx/db/auth.py +++ b/backend/onyx/db/auth.py @@ -56,7 +56,8 @@ def get_total_users_count(db_session: Session) -> int: async def get_user_count(only_admin_users: bool = False) -> int: async with get_async_session_with_tenant() as session: - stmt = select(func.count(User.id)) + count_stmt = func.count(User.id) # type: ignore + stmt = select(count_stmt) if only_admin_users: stmt = stmt.where(User.role == UserRole.ADMIN) result = await session.execute(stmt)
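
Reviewer note on the PartialResponse Protocol: the functools.partial it replaces is effectively untyped at call sites, while a typing.Protocol with an explicit __call__ lets mypy check every keyword argument wherever the callback is invoked. A self-contained sketch of the pattern under illustrative names (not the codebase's actual signature):

    from typing import Protocol


    class MessageFactory(Protocol):
        def __call__(self, message: str, error: str | None) -> str: ...


    def make_factory(prefix: str) -> MessageFactory:
        # A closure satisfies the Protocol structurally (no inheritance
        # needed), and every call through MessageFactory is fully checked.
        def build(message: str, error: str | None) -> str:
            suffix = f" [error: {error}]" if error else ""
            return f"{prefix}{message}{suffix}"

        return build


    factory: MessageFactory = make_factory("assistant: ")
    print(factory(message="hello", error=None))  # assistant: hello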
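Reviewer note on `info_by_subq = yield from _process_tool_response(...)`: a generator annotated Generator[YieldT, SendT, ReturnT] may return a value, and `yield from` both forwards every yielded packet and evaluates to that return value; this is how the extracted helper streams ChatPackets while handing its updated state back to the caller. A runnable sketch of the mechanism (names are illustrative):

    from collections.abc import Generator


    def helper(state: dict[str, int]) -> Generator[str, None, dict[str, int]]:
        yield "packet"  # streamed through to the consumer
        state["seen"] = state.get("seen", 0) + 1
        return state  # becomes the value of the `yield from` expression


    def driver() -> Generator[str, None, None]:
        state: dict[str, int] = {}
        state = yield from helper(state)  # forwards "packet", captures return
        yield f"done after {state['seen']} packet(s)"


    print(list(driver()))  # ['packet', 'done after 1 packet(s)']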
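Reviewer note on the cast() idiom in custom_models.py and onyx_torch_model.py: AutoTokenizer.from_pretrained is annotated loosely upstream, so the value is narrowed once at the assignment instead of scattering ignores at every call site. A minimal sketch mirroring the casts above (the load_tokenizer helper is illustrative, not part of this change):

    from typing import cast

    from transformers import AutoTokenizer, PreTrainedTokenizer


    def load_tokenizer(name: str) -> PreTrainedTokenizer:
        # cast() is a no-op at runtime; it only tells the type checker which
        # concrete type to assume, so later calls such as
        # convert_ids_to_tokens() type-check without per-line ignores.
        return cast(PreTrainedTokenizer, AutoTokenizer.from_pretrained(name))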