refactor to use stricter typing (#4513)

* refactor to use stricter typing

* older version of ruff
evan-danswer 2025-04-14 10:23:07 -07:00 committed by GitHub
parent a5edc8aa0f
commit 68c6c1f4f8
5 changed files with 386 additions and 308 deletions
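
The recurring pattern in the model-server files below is to cast the loosely typed return value of AutoTokenizer.from_pretrained to PreTrainedTokenizer so the stricter annotations type-check. A minimal sketch of that pattern, assuming the transformers package is installed; the get_tokenizer name and _TOKENIZER global are illustrative, not taken from the diff:

from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

_TOKENIZER: PreTrainedTokenizer | None = None

def get_tokenizer() -> PreTrainedTokenizer:
    # AutoTokenizer.from_pretrained is loosely typed, so cast the result to the
    # PreTrainedTokenizer base class that callers are annotated against.
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _TOKENIZER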

View File

@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
(
or_(
ChatMessageFeedback.is_positive.is_(False),
ChatMessageFeedback.required_followup,
ChatMessageFeedback.required_followup.is_(True),
),
1,
),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
.all()
)
return results
return [tuple(row) for row in results]
def fetch_persona_message_analytics(

View File

@@ -1,3 +1,5 @@
from typing import cast
import numpy as np
import torch
import torch.nn.functional as F
@@ -39,10 +41,10 @@ logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -50,13 +52,14 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> AutoTokenizer:
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
"distilbert-base-uncased"
_CONNECTOR_CLASSIFIER_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -92,12 +95,15 @@ def get_local_connector_classifier(
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> AutoTokenizer:
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
_INTENT_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
return _INTENT_TOKENIZER
@@ -395,9 +401,9 @@ def run_content_classification_inference(
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids)
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
if not len(tokens) == len(is_keyword):
raise ValueError("Length of tokens and keyword predictions must match")

View File

@@ -1,5 +1,6 @@
import json
import os
from typing import cast
import torch
import torch.nn as nn
@@ -13,15 +14,14 @@ class HybridClassifier(nn.Module):
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
# Keyword tokenwise binary classification layer
self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)
self.keyword_classifier = nn.Linear(config.dim, 2)
# Intent Classifier layers
self.pre_classifier = nn.Linear(
self.distilbert.config.dim, self.distilbert.config.dim
)
self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.intent_classifier = nn.Linear(config.dim, 2)
self.device = torch.device("cpu")
@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
query_ids: torch.Tensor,
query_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
sequence_output = outputs.last_hidden_state
# Intent classification on the CLS token
@@ -79,8 +79,9 @@ class ConnectorClassifier(nn.Module):
self.config = config
self.distilbert = DistilBertModel(config)
self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
config = self.distilbert.config # type: ignore
self.connector_global_classifier = nn.Linear(config.dim, 1)
self.connector_match_classifier = nn.Linear(config.dim, 1)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Token indicating end of connector name, and on which classifier is used
@@ -95,7 +96,7 @@ class ConnectorClassifier(nn.Module):
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert(
hidden_states = self.distilbert( # type: ignore
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
@@ -114,7 +115,10 @@ class ConnectorClassifier(nn.Module):
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
)
device = (
torch.device("cuda")
if torch.cuda.is_available()

View File

@@ -2,9 +2,11 @@ import time
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import partial
from typing import cast
from typing import Protocol
from uuid import UUID
from sqlalchemy.orm import Session
@@ -82,6 +84,8 @@ from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import ChatMessage
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -159,6 +163,25 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
COMMON_TOOL_RESPONSE_TYPES = {
"image": ChatFileType.IMAGE,
"csv": ChatFileType.CSV,
}
class PartialResponse(Protocol):
def __call__(
self,
message: str,
rephrased_query: str | None,
reference_docs: list[DbSearchDoc] | None,
files: list[FileDescriptor],
token_count: int,
citations: dict[int, int] | None,
error: str | None,
tool_call: ToolCall | None,
) -> ChatMessage: ...
def _translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
@@ -211,25 +234,25 @@ def _handle_search_tool_response_summary(
reference_db_search_docs = selected_search_docs
doc_ids = {doc.id for doc in reference_db_search_docs}
if user_files is not None:
if user_files is not None and loaded_user_files is not None:
for user_file in user_files:
if user_file.id not in doc_ids:
associated_chat_file = None
if loaded_user_files is not None:
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
if user_file.id in doc_ids:
continue
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
@@ -357,6 +380,86 @@ def _get_force_search_settings(
)
def _get_user_knowledge_files(
info: AnswerPostInfo,
user_files: list[InMemoryChatFile],
file_id_to_user_file: dict[str, InMemoryChatFile],
) -> Generator[UserKnowledgeFilePacket, None, None]:
if not info.qa_docs_response:
return
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(f"ORDERING: Found {len(doc_order)} files from search results")
# Add any files that weren't in search results at the end
missing_files = [
f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
]
missing_files.extend(doc_order)
doc_order = missing_files
logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
]
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
)
def _get_persona_for_chat_session(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
default_persona: Persona,
) -> Persona:
if new_msg_req.alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
new_msg_req.alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = default_persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
return persona
ChatPacket = (
StreamingError
| QADocsResponse
@@ -378,6 +481,149 @@ ChatPacket = (
ChatPacketStream = Iterator[ChatPacket]
def _process_tool_response(
packet: ToolResponse,
db_session: Session,
selected_db_search_docs: list[DbSearchDoc] | None,
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
retrieval_options: RetrievalDetails | None,
user_file_files: list[UserFile] | None,
user_files: list[InMemoryChatFile] | None,
file_id_to_user_file: dict[str, InMemoryChatFile],
search_for_ordering_only: bool,
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
level, level_question_num = (
(packet.level, packet.level_question_num)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
assert level is not None
assert level_question_num is not None
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
return info_by_subq
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=bool(
not search_for_ordering_only
and retrieval_options
and retrieval_options.dedupe_docs
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=(user_files if search_for_ordering_only else []),
)
# If we're using search just for ordering user files
if search_for_ordering_only and user_files:
yield from _get_user_knowledge_files(
info=info,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
return info_by_subq
if info.reference_db_search_docs is None:
logger.warning("No reference docs found for relevance filtering")
return info_by_subq
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)
if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)
yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(list[ImageGenerationResponse], packet.response)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data for img in img_generation_response if img.image_data
],
)
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
response_type = custom_tool_response.response_type
if response_type in COMMON_TOOL_RESPONSE_TYPES:
file_ids = custom_tool_response.tool_result.file_ids
file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=file_type)
for file_id in file_ids
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
return info_by_subq
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
@@ -421,7 +667,6 @@ def stream_chat_message_objects(
try:
# Move these variables inside the try block
file_id_to_user_file = {}
ordered_user_files = None
user_id = user.id if user is not None else None
@@ -436,35 +681,19 @@ def stream_chat_message_objects(
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
alternate_assistant_id = new_msg_req.alternate_assistant_id
new_msg_req.alternate_assistant_id
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
if alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = chat_session.persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
persona = _get_persona_for_chat_session(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
default_persona=chat_session.persona,
)
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
user=user,
@@ -746,31 +975,42 @@ def stream_chat_message_objects(
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
def create_response(
message: str,
rephrased_query: str | None,
reference_docs: list[DbSearchDoc] | None,
files: list[FileDescriptor],
token_count: int,
citations: dict[int, int] | None,
error: str | None,
tool_call: ToolCall | None,
) -> ChatMessage:
return create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=(
final_msg
if existing_assistant_message_id is None
else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
message=message,
rephrased_query=rephrased_query,
token_count=token_count,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
error=error,
reference_docs=reference_docs,
files=files,
citations=citations,
tool_call=tool_call,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
partial_response = create_response
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
@@ -1041,220 +1281,23 @@ def stream_chat_message_objects(
use_agentic_search=new_msg_req.use_agentic_search,
)
# reference_db_search_docs = None
# qa_docs_response = None
# # any files to associate with the AI message e.g. dall-e generated images
# ai_message_files = []
# dropped_indices = None
# tool_result = None
# TODO: different channels for stored info when it's coming from the agent flow
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
level, level_question_num = (
(packet.level, packet.level_question_num)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
info_by_subq = yield from _process_tool_response(
packet=packet,
db_session=db_session,
selected_db_search_docs=selected_db_search_docs,
info_by_subq=info_by_subq,
retrieval_options=retrieval_options,
user_file_files=user_file_files,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
search_for_ordering_only=search_for_ordering_only,
)
assert level is not None
assert level_question_num is not None
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
continue
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=bool(
not search_for_ordering_only
and retrieval_options
and retrieval_options.dedupe_docs
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=(
user_files if search_for_ordering_only else []
),
)
# If we're using search just for ordering user files
if (
search_for_ordering_only
and user_files
and info.qa_docs_response
):
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(
f"ORDERING: Found {len(doc_order)} files from search results"
)
# Add any files that weren't in search results at the end
missing_files = [
f_id
for f_id in file_id_to_user_file.keys()
if f_id not in doc_order
]
missing_files.extend(doc_order)
doc_order = missing_files
logger.info(
f"ORDERING: Added {len(missing_files)} missing files to the end"
)
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id]
for f_id in doc_order
if f_id in file_id_to_user_file
]
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
continue
if info.reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
)
continue
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)
if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data
for img in img_generation_response
if img.image_data
],
)
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
info.ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV
),
)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1291,22 +1334,46 @@ def stream_chat_message_objects(
if isinstance(e, ToolCallException):
yield StreamingError(error=error_msg, stack_trace=stack_trace)
else:
if llm:
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
elif llm:
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
client_error_msg = client_error_msg.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
db_session.rollback()
return
yield from _post_llm_answer_processing(
answer=answer,
info_by_subq=info_by_subq,
tool_dict=tool_dict,
partial_response=partial_response,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
db_session=db_session,
chat_session_id=chat_session_id,
refined_answer_improvement=refined_answer_improvement,
)
def _post_llm_answer_processing(
answer: Answer,
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
tool_dict: dict[int, list[Tool]],
partial_response: PartialResponse,
llm_tokenizer_encode_func: Callable[[str], list[int]],
db_session: Session,
chat_session_id: UUID,
refined_answer_improvement: bool | None,
) -> Generator[ChatPacket, None, None]:
"""
Stores messages in the db and yields some final packets to the frontend
"""
# Post-LLM answer processing
try:
tool_name_to_tool_id: dict[str, int] = {}

View File

@@ -56,7 +56,8 @@ def get_total_users_count(db_session: Session) -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
count_stmt = func.count(User.id) # type: ignore
stmt = select(count_stmt)
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)