Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-05 12:39:33 +02:00)

Commit 68c6c1f4f8 (parent a5edc8aa0f)
refactor to use stricter typing (#4513)

* refactor to use stricter typing
* older version of ruff
@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
                (
                    or_(
                        ChatMessageFeedback.is_positive.is_(False),
-                        ChatMessageFeedback.required_followup,
+                        ChatMessageFeedback.required_followup.is_(True),
                    ),
                    1,
                ),
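Note on the hunk above: passing the bare Boolean column to or_() works at runtime, but under the stricter typing setup (and the older ruff version mentioned in the commit message) the explicit .is_(True) comparison is the unambiguous form. A minimal, self-contained sketch of the pattern, not taken from the repository (the Feedback model below is a hypothetical stand-in for ChatMessageFeedback):

# Minimal sketch, assuming a SQLAlchemy 2.0 declarative model; Feedback is a
# hypothetical stand-in for ChatMessageFeedback.
from sqlalchemy import case, func, or_, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Feedback(Base):
    __tablename__ = "feedback"
    id: Mapped[int] = mapped_column(primary_key=True)
    is_positive: Mapped[bool | None] = mapped_column()
    required_followup: Mapped[bool] = mapped_column()


# The bare column relies on implicit truthiness; .is_(True) builds an explicit
# "required_followup IS true" clause, which the type checker and linter accept
# without special-casing Boolean columns.
negatives = func.sum(
    case(
        (
            or_(
                Feedback.is_positive.is_(False),
                Feedback.required_followup.is_(True),
            ),
            1,
        ),
        else_=0,
    )
)
stmt = select(negatives)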
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
        .all()
    )

-    return results
+    return [tuple(row) for row in results]


def fetch_persona_message_analytics(
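Note on the hunk above: .all() returns SQLAlchemy Row objects, which unpack like tuples but are not tuples, so a declared return type of list[tuple[...]] only holds after an explicit conversion. A hedged sketch with a hypothetical model:

# Hedged sketch; Feedback is a hypothetical model standing in for the real one.
from sqlalchemy import func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Feedback(Base):
    __tablename__ = "feedback"
    id: Mapped[int] = mapped_column(primary_key=True)
    is_positive: Mapped[bool] = mapped_column()


def fetch_feedback_counts(db_session: Session) -> list[tuple[bool, int]]:
    stmt = select(Feedback.is_positive, func.count()).group_by(Feedback.is_positive)
    rows = db_session.execute(stmt).all()
    # .all() yields Row objects; converting each one makes the declared
    # list[tuple[bool, int]] return type actually true for the checker.
    return [tuple(row) for row in rows]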
@@ -1,3 +1,5 @@
+from typing import cast
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -39,10 +41,10 @@ logger = setup_logger()

 router = APIRouter(prefix="/custom")

-_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
+_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
 _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None

-_INTENT_TOKENIZER: AutoTokenizer | None = None
+_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
 _INTENT_MODEL: HybridClassifier | None = None

 _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -50,13 +52,14 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
 _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!


-def get_connector_classifier_tokenizer() -> AutoTokenizer:
+def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
     global _CONNECTOR_CLASSIFIER_TOKENIZER
     if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
         # The tokenizer details are not uploaded to the HF hub since it's just the
         # unmodified distilbert tokenizer.
-        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
-            "distilbert-base-uncased"
+        _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
+            PreTrainedTokenizer,
+            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
         )
     return _CONNECTOR_CLASSIFIER_TOKENIZER

@@ -92,12 +95,15 @@ def get_local_connector_classifier(
     return _CONNECTOR_CLASSIFIER_MODEL


-def get_intent_model_tokenizer() -> AutoTokenizer:
+def get_intent_model_tokenizer() -> PreTrainedTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
         # The tokenizer details are not uploaded to the HF hub since it's just the
         # unmodified distilbert tokenizer.
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+        _INTENT_TOKENIZER = cast(
+            PreTrainedTokenizer,
+            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
+        )
     return _INTENT_TOKENIZER


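Note on the two hunks above: AutoTokenizer.from_pretrained is effectively untyped for mypy, so the result is cast to PreTrainedTokenizer, the interface the callers actually use. typing.cast is a runtime no-op; it only narrows the type for the checker. A small sketch (the model name matches the diff, the rest is illustrative):

# Sketch only; the model name matches the diff, everything else is illustrative.
from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

# AutoTokenizer.from_pretrained is loosely typed in the transformers stubs, so
# the result is cast to the PreTrainedTokenizer interface the callers rely on.
# cast() does nothing at runtime; it only narrows the type for mypy.
tokenizer = cast(
    PreTrainedTokenizer,
    AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
print(tokenizer.convert_ids_to_tokens([101]))  # id 101 is '[CLS]' in this vocab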
@@ -395,9 +401,9 @@ def run_content_classification_inference(


 def map_keywords(
-    input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
+    input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
 ) -> list[str]:
-    tokens = tokenizer.convert_ids_to_tokens(input_ids)
+    tokens = tokenizer.convert_ids_to_tokens(input_ids)  # type: ignore

     if not len(tokens) == len(is_keyword):
         raise ValueError("Length of tokens and keyword predictions must match")
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import cast

 import torch
 import torch.nn as nn
@@ -13,15 +14,14 @@ class HybridClassifier(nn.Module):
         super().__init__()
         config = DistilBertConfig()
         self.distilbert = DistilBertModel(config)
+        config = self.distilbert.config  # type: ignore

         # Keyword tokenwise binary classification layer
-        self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)
+        self.keyword_classifier = nn.Linear(config.dim, 2)

         # Intent Classifier layers
-        self.pre_classifier = nn.Linear(
-            self.distilbert.config.dim, self.distilbert.config.dim
-        )
-        self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
+        self.pre_classifier = nn.Linear(config.dim, config.dim)
+        self.intent_classifier = nn.Linear(config.dim, 2)

         self.device = torch.device("cpu")

@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
         query_ids: torch.Tensor,
         query_mask: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
-        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
+        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)  # type: ignore
         sequence_output = outputs.last_hidden_state

         # Intent classification on the CLS token
@@ -79,8 +79,9 @@ class ConnectorClassifier(nn.Module):

         self.config = config
         self.distilbert = DistilBertModel(config)
-        self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
-        self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        config = self.distilbert.config  # type: ignore
+        self.connector_global_classifier = nn.Linear(config.dim, 1)
+        self.connector_match_classifier = nn.Linear(config.dim, 1)
         self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

         # Token indicating end of connector name, and on which classifier is used
@@ -95,7 +96,7 @@ class ConnectorClassifier(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        hidden_states = self.distilbert(
+        hidden_states = self.distilbert(  # type: ignore
             input_ids=input_ids, attention_mask=attention_mask
         ).last_hidden_state

@@ -114,7 +115,10 @@ class ConnectorClassifier(nn.Module):

     @classmethod
     def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
-        config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
+        config = cast(
+            DistilBertConfig,
+            DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
+        )
         device = (
             torch.device("cuda")
             if torch.cuda.is_available()
@@ -2,9 +2,11 @@ import time
 import traceback
 from collections import defaultdict
 from collections.abc import Callable
+from collections.abc import Generator
 from collections.abc import Iterator
-from functools import partial
 from typing import cast
+from typing import Protocol
+from uuid import UUID

 from sqlalchemy.orm import Session

@@ -82,6 +84,8 @@ from onyx.db.engine import get_session_context_manager
 from onyx.db.milestone import check_multi_assistant_milestone
 from onyx.db.milestone import create_milestone_if_not_exists
 from onyx.db.milestone import update_user_assistant_milestone
+from onyx.db.models import ChatMessage
+from onyx.db.models import Persona
 from onyx.db.models import SearchDoc as DbSearchDoc
 from onyx.db.models import ToolCall
 from onyx.db.models import User
@@ -159,6 +163,25 @@ from shared_configs.contextvars import get_current_tenant_id
 logger = setup_logger()
 ERROR_TYPE_CANCELLED = "cancelled"

+COMMON_TOOL_RESPONSE_TYPES = {
+    "image": ChatFileType.IMAGE,
+    "csv": ChatFileType.CSV,
+}
+
+
+class PartialResponse(Protocol):
+    def __call__(
+        self,
+        message: str,
+        rephrased_query: str | None,
+        reference_docs: list[DbSearchDoc] | None,
+        files: list[FileDescriptor],
+        token_count: int,
+        citations: dict[int, int] | None,
+        error: str | None,
+        tool_call: ToolCall | None,
+    ) -> ChatMessage: ...
+

 def _translate_citations(
     citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
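Note on the hunk above: the PartialResponse Protocol spells out, as a __call__ signature, the keyword arguments the chat pipeline supplies later; this is the typed replacement for the old functools.partial(create_new_chat_message, ...), whose call signature mypy cannot express precisely. A simplified, self-contained sketch of the pattern (names are illustrative, not from the repository):

# Simplified sketch; names below are illustrative and not from the repository.
from typing import Protocol


class SaveMessage(Protocol):
    def __call__(self, message: str, token_count: int, error: str | None) -> int: ...


def save_to_db(message: str, token_count: int, error: str | None) -> int:
    # stand-in for the real message-persisting function; returns a fake row id
    return 42


# A Protocol is structural: any callable with a matching signature can be
# assigned, and mypy then checks every call site against the declared __call__.
saver: SaveMessage = save_to_db
new_row_id = saver(message="hello", token_count=1, error=None)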
@@ -211,25 +234,25 @@ def _handle_search_tool_response_summary(
         reference_db_search_docs = selected_search_docs

     doc_ids = {doc.id for doc in reference_db_search_docs}
-    if user_files is not None:
+    if user_files is not None and loaded_user_files is not None:
         for user_file in user_files:
-            if user_file.id not in doc_ids:
-                associated_chat_file = None
-                if loaded_user_files is not None:
-                    associated_chat_file = next(
-                        (
-                            file
-                            for file in loaded_user_files
-                            if file.file_id == str(user_file.file_id)
-                        ),
-                        None,
-                    )
-                # Use create_search_doc_from_user_file to properly add the document to the database
-                if associated_chat_file is not None:
-                    db_doc = create_search_doc_from_user_file(
-                        user_file, associated_chat_file, db_session
-                    )
-                    reference_db_search_docs.append(db_doc)
+            if user_file.id in doc_ids:
+                continue
+
+            associated_chat_file = next(
+                (
+                    file
+                    for file in loaded_user_files
+                    if file.file_id == str(user_file.file_id)
+                ),
+                None,
+            )
+            # Use create_search_doc_from_user_file to properly add the document to the database
+            if associated_chat_file is not None:
+                db_doc = create_search_doc_from_user_file(
+                    user_file, associated_chat_file, db_session
+                )
+                reference_db_search_docs.append(db_doc)

     response_docs = [
         translate_db_search_doc_to_server_search_doc(db_search_doc)
@@ -357,6 +380,86 @@ def _get_force_search_settings(
     )


+def _get_user_knowledge_files(
+    info: AnswerPostInfo,
+    user_files: list[InMemoryChatFile],
+    file_id_to_user_file: dict[str, InMemoryChatFile],
+) -> Generator[UserKnowledgeFilePacket, None, None]:
+    if not info.qa_docs_response:
+        return
+
+    logger.info(
+        f"ORDERING: Processing search results for ordering {len(user_files)} user files"
+    )
+
+    # Extract document order from search results
+    doc_order = []
+    for doc in info.qa_docs_response.top_documents:
+        doc_id = doc.document_id
+        if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
+            file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
+            if file_id in file_id_to_user_file:
+                doc_order.append(file_id)
+
+    logger.info(f"ORDERING: Found {len(doc_order)} files from search results")
+
+    # Add any files that weren't in search results at the end
+    missing_files = [
+        f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
+    ]
+
+    missing_files.extend(doc_order)
+    doc_order = missing_files
+
+    logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")
+
+    # Reorder user files based on search results
+    ordered_user_files = [
+        file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
+    ]
+
+    yield UserKnowledgeFilePacket(
+        user_files=[
+            FileDescriptor(
+                id=str(file.file_id),
+                type=ChatFileType.USER_KNOWLEDGE,
+            )
+            for file in ordered_user_files
+        ]
+    )
+
+
+def _get_persona_for_chat_session(
+    new_msg_req: CreateChatMessageRequest,
+    user: User | None,
+    db_session: Session,
+    default_persona: Persona,
+) -> Persona:
+    if new_msg_req.alternate_assistant_id is not None:
+        # Allows users to specify a temporary persona (assistant) in the chat session
+        # this takes highest priority since it's user specified
+        persona = get_persona_by_id(
+            new_msg_req.alternate_assistant_id,
+            user=user,
+            db_session=db_session,
+            is_for_edit=False,
+        )
+    elif new_msg_req.persona_override_config:
+        # Certain endpoints allow users to specify arbitrary persona settings
+        # this should never conflict with the alternate_assistant_id
+        persona = create_temporary_persona(
+            db_session=db_session,
+            persona_config=new_msg_req.persona_override_config,
+            user=user,
+        )
+    else:
+        persona = default_persona
+
+    if not persona:
+        raise RuntimeError("No persona specified or found for chat session")
+    return persona
+
+
 ChatPacket = (
     StreamingError
     | QADocsResponse
@@ -378,6 +481,149 @@ ChatPacket = (
 ChatPacketStream = Iterator[ChatPacket]


+def _process_tool_response(
+    packet: ToolResponse,
+    db_session: Session,
+    selected_db_search_docs: list[DbSearchDoc] | None,
+    info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
+    retrieval_options: RetrievalDetails | None,
+    user_file_files: list[UserFile] | None,
+    user_files: list[InMemoryChatFile] | None,
+    file_id_to_user_file: dict[str, InMemoryChatFile],
+    search_for_ordering_only: bool,
+) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
+    level, level_question_num = (
+        (packet.level, packet.level_question_num)
+        if isinstance(packet, ExtendedToolResponse)
+        else BASIC_KEY
+    )
+
+    assert level is not None
+    assert level_question_num is not None
+    info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
+
+    # Skip LLM relevance processing entirely for ordering-only mode
+    if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
+        logger.info(
+            "Fast path: Completely bypassing section relevance processing for ordering-only mode"
+        )
+        # Skip this packet entirely since it would trigger LLM processing
+        return info_by_subq
+
+    # TODO: don't need to dedupe here when we do it in agent flow
+    if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
+        if search_for_ordering_only:
+            logger.info(
+                "Fast path: Skipping document deduplication for ordering-only mode"
+            )
+
+        (
+            info.qa_docs_response,
+            info.reference_db_search_docs,
+            info.dropped_indices,
+        ) = _handle_search_tool_response_summary(
+            packet=packet,
+            db_session=db_session,
+            selected_search_docs=selected_db_search_docs,
+            # Deduping happens at the last step to avoid harming quality by dropping content early on
+            # Skip deduping completely for ordering-only mode to save time
+            dedupe_docs=bool(
+                not search_for_ordering_only
+                and retrieval_options
+                and retrieval_options.dedupe_docs
+            ),
+            user_files=user_file_files if search_for_ordering_only else [],
+            loaded_user_files=(user_files if search_for_ordering_only else []),
+        )
+
+        # If we're using search just for ordering user files
+        if search_for_ordering_only and user_files:
+            yield from _get_user_knowledge_files(
+                info=info,
+                user_files=user_files,
+                file_id_to_user_file=file_id_to_user_file,
+            )
+
+        yield info.qa_docs_response
+    elif packet.id == SECTION_RELEVANCE_LIST_ID:
+        relevance_sections = packet.response
+
+        if search_for_ordering_only:
+            logger.info(
+                "Performance: Skipping relevance filtering for ordering-only mode"
+            )
+            return info_by_subq
+
+        if info.reference_db_search_docs is None:
+            logger.warning("No reference docs found for relevance filtering")
+            return info_by_subq
+
+        llm_indices = relevant_sections_to_indices(
+            relevance_sections=relevance_sections,
+            items=[
+                translate_db_search_doc_to_server_search_doc(doc)
+                for doc in info.reference_db_search_docs
+            ],
+        )
+
+        if info.dropped_indices:
+            llm_indices = drop_llm_indices(
+                llm_indices=llm_indices,
+                search_docs=info.reference_db_search_docs,
+                dropped_indices=info.dropped_indices,
+            )
+
+        yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
+    elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
+        yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
+
+    elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
+        img_generation_response = cast(list[ImageGenerationResponse], packet.response)
+
+        file_ids = save_files(
+            urls=[img.url for img in img_generation_response if img.url],
+            base64_files=[
+                img.image_data for img in img_generation_response if img.image_data
+            ],
+        )
+        info.ai_message_files.extend(
+            [
+                FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
+                for file_id in file_ids
+            ]
+        )
+        yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
+    elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
+        (
+            info.qa_docs_response,
+            info.reference_db_search_docs,
+        ) = _handle_internet_search_tool_response_summary(
+            packet=packet,
+            db_session=db_session,
+        )
+        yield info.qa_docs_response
+    elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
+        custom_tool_response = cast(CustomToolCallSummary, packet.response)
+        response_type = custom_tool_response.response_type
+        if response_type in COMMON_TOOL_RESPONSE_TYPES:
+            file_ids = custom_tool_response.tool_result.file_ids
+            file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
+            info.ai_message_files.extend(
+                [
+                    FileDescriptor(id=str(file_id), type=file_type)
+                    for file_id in file_ids
+                ]
+            )
+            yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
+        else:
+            yield CustomToolResponse(
+                response=custom_tool_response.tool_result,
+                tool_name=custom_tool_response.tool_name,
+            )
+
+    return info_by_subq
+
+
 def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
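Note on the hunk above: _process_tool_response is annotated Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]], so it yields packets to the stream and also returns the updated bookkeeping dict; a later hunk captures that value with `info_by_subq = yield from _process_tool_response(...)`. A minimal sketch of the Generator[Yield, Send, Return] mechanics, unrelated to the repository's actual types:

# Minimal sketch of Generator[Yield, Send, Return]; not the repository's types.
from collections.abc import Generator


def process_packet(
    packet: str, counts: dict[str, int]
) -> Generator[str, None, dict[str, int]]:
    yield f"processed:{packet}"  # streamed to the consumer
    counts[packet] = counts.get(packet, 0) + 1
    return counts  # becomes the value of the `yield from` expression


def stream(packets: list[str]) -> Generator[str, None, None]:
    counts: dict[str, int] = {}
    for packet in packets:
        # `yield from` forwards every yielded item and evaluates to the
        # sub-generator's return value, so per-packet state flows back here.
        counts = yield from process_packet(packet, counts)
    yield f"done:{counts}"


print(list(stream(["a", "b", "a"])))
# ['processed:a', 'processed:b', 'processed:a', "done:{'a': 2, 'b': 1}"]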
@@ -421,7 +667,6 @@ def stream_chat_message_objects(
     try:
         # Move these variables inside the try block
         file_id_to_user_file = {}
-        ordered_user_files = None

         user_id = user.id if user is not None else None

@@ -436,35 +681,19 @@ def stream_chat_message_objects(
         parent_id = new_msg_req.parent_message_id
         reference_doc_ids = new_msg_req.search_doc_ids
         retrieval_options = new_msg_req.retrieval_options
-        alternate_assistant_id = new_msg_req.alternate_assistant_id
+        new_msg_req.alternate_assistant_id

         # permanent "log" store, used primarily for debugging
         long_term_logger = LongTermLogger(
             metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
         )

-        if alternate_assistant_id is not None:
-            # Allows users to specify a temporary persona (assistant) in the chat session
-            # this takes highest priority since it's user specified
-            persona = get_persona_by_id(
-                alternate_assistant_id,
-                user=user,
-                db_session=db_session,
-                is_for_edit=False,
-            )
-        elif new_msg_req.persona_override_config:
-            # Certain endpoints allow users to specify arbitrary persona settings
-            # this should never conflict with the alternate_assistant_id
-            persona = persona = create_temporary_persona(
-                db_session=db_session,
-                persona_config=new_msg_req.persona_override_config,
-                user=user,
-            )
-        else:
-            persona = chat_session.persona
-
-        if not persona:
-            raise RuntimeError("No persona specified or found for chat session")
+        persona = _get_persona_for_chat_session(
+            new_msg_req=new_msg_req,
+            user=user,
+            db_session=db_session,
+            default_persona=chat_session.persona,
+        )

         multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
             user=user,
@@ -746,31 +975,42 @@ def stream_chat_message_objects(
             new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
         )

-        # Cannot determine these without the LLM step or breaking out early
-        partial_response = partial(
-            create_new_chat_message,
-            chat_session_id=chat_session_id,
-            # if we're using an existing assistant message, then this will just be an
-            # update operation, in which case the parent should be the parent of
-            # the latest. If we're creating a new assistant message, then the parent
-            # should be the latest message (latest user message)
-            parent_message=(
-                final_msg if existing_assistant_message_id is None else parent_message
-            ),
-            prompt_id=prompt_id,
-            overridden_model=overridden_model,
-            # message=,
-            # rephrased_query=,
-            # token_count=,
-            message_type=MessageType.ASSISTANT,
-            alternate_assistant_id=new_msg_req.alternate_assistant_id,
-            # error=,
-            # reference_docs=,
-            db_session=db_session,
-            commit=False,
-            reserved_message_id=reserved_message_id,
-            is_agentic=new_msg_req.use_agentic_search,
-        )
+        def create_response(
+            message: str,
+            rephrased_query: str | None,
+            reference_docs: list[DbSearchDoc] | None,
+            files: list[FileDescriptor],
+            token_count: int,
+            citations: dict[int, int] | None,
+            error: str | None,
+            tool_call: ToolCall | None,
+        ) -> ChatMessage:
+            return create_new_chat_message(
+                chat_session_id=chat_session_id,
+                parent_message=(
+                    final_msg
+                    if existing_assistant_message_id is None
+                    else parent_message
+                ),
+                prompt_id=prompt_id,
+                overridden_model=overridden_model,
+                message=message,
+                rephrased_query=rephrased_query,
+                token_count=token_count,
+                message_type=MessageType.ASSISTANT,
+                alternate_assistant_id=new_msg_req.alternate_assistant_id,
+                error=error,
+                reference_docs=reference_docs,
+                files=files,
+                citations=citations,
+                tool_call=tool_call,
+                db_session=db_session,
+                commit=False,
+                reserved_message_id=reserved_message_id,
+                is_agentic=new_msg_req.use_agentic_search,
+            )
+
+        partial_response = create_response

         prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
         if new_msg_req.persona_override_config:
@@ -1041,220 +1281,23 @@ def stream_chat_message_objects(
             use_agentic_search=new_msg_req.use_agentic_search,
         )

-        # reference_db_search_docs = None
-        # qa_docs_response = None
-        # # any files to associate with the AI message e.g. dall-e generated images
-        # ai_message_files = []
-        # dropped_indices = None
-        # tool_result = None
-
-        # TODO: different channels for stored info when it's coming from the agent flow
         info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
             lambda: AnswerPostInfo(ai_message_files=[])
         )
         refined_answer_improvement = True
         for packet in answer.processed_streamed_output:
             if isinstance(packet, ToolResponse):
-                level, level_question_num = (
-                    (packet.level, packet.level_question_num)
-                    if isinstance(packet, ExtendedToolResponse)
-                    else BASIC_KEY
+                info_by_subq = yield from _process_tool_response(
+                    packet=packet,
+                    db_session=db_session,
+                    selected_db_search_docs=selected_db_search_docs,
+                    info_by_subq=info_by_subq,
+                    retrieval_options=retrieval_options,
+                    user_file_files=user_file_files,
+                    user_files=user_files,
+                    file_id_to_user_file=file_id_to_user_file,
+                    search_for_ordering_only=search_for_ordering_only,
                 )
-                assert level is not None
-                assert level_question_num is not None
-                info = info_by_subq[
-                    SubQuestionKey(level=level, question_num=level_question_num)
-                ]
-
-                # Skip LLM relevance processing entirely for ordering-only mode
-                if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
-                    logger.info(
-                        "Fast path: Completely bypassing section relevance processing for ordering-only mode"
-                    )
-                    # Skip this packet entirely since it would trigger LLM processing
-                    continue
-
-                # TODO: don't need to dedupe here when we do it in agent flow
-                if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
-                    if search_for_ordering_only:
-                        logger.info(
-                            "Fast path: Skipping document deduplication for ordering-only mode"
-                        )
-
-                    (
-                        info.qa_docs_response,
-                        info.reference_db_search_docs,
-                        info.dropped_indices,
-                    ) = _handle_search_tool_response_summary(
-                        packet=packet,
-                        db_session=db_session,
-                        selected_search_docs=selected_db_search_docs,
-                        # Deduping happens at the last step to avoid harming quality by dropping content early on
-                        # Skip deduping completely for ordering-only mode to save time
-                        dedupe_docs=bool(
-                            not search_for_ordering_only
-                            and retrieval_options
-                            and retrieval_options.dedupe_docs
-                        ),
-                        user_files=user_file_files if search_for_ordering_only else [],
-                        loaded_user_files=(
-                            user_files if search_for_ordering_only else []
-                        ),
-                    )
-
-                    # If we're using search just for ordering user files
-                    if (
-                        search_for_ordering_only
-                        and user_files
-                        and info.qa_docs_response
-                    ):
-                        logger.info(
-                            f"ORDERING: Processing search results for ordering {len(user_files)} user files"
-                        )
-
-                        # Extract document order from search results
-                        doc_order = []
-                        for doc in info.qa_docs_response.top_documents:
-                            doc_id = doc.document_id
-                            if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
-                                file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
-                                if file_id in file_id_to_user_file:
-                                    doc_order.append(file_id)
-
-                        logger.info(
-                            f"ORDERING: Found {len(doc_order)} files from search results"
-                        )
-
-                        # Add any files that weren't in search results at the end
-                        missing_files = [
-                            f_id
-                            for f_id in file_id_to_user_file.keys()
-                            if f_id not in doc_order
-                        ]
-
-                        missing_files.extend(doc_order)
-                        doc_order = missing_files
-
-                        logger.info(
-                            f"ORDERING: Added {len(missing_files)} missing files to the end"
-                        )
-
-                        # Reorder user files based on search results
-                        ordered_user_files = [
-                            file_id_to_user_file[f_id]
-                            for f_id in doc_order
-                            if f_id in file_id_to_user_file
-                        ]
-
-                        yield UserKnowledgeFilePacket(
-                            user_files=[
-                                FileDescriptor(
-                                    id=str(file.file_id),
-                                    type=ChatFileType.USER_KNOWLEDGE,
-                                )
-                                for file in ordered_user_files
-                            ]
-                        )
-
-                    yield info.qa_docs_response
-                elif packet.id == SECTION_RELEVANCE_LIST_ID:
-                    relevance_sections = packet.response
-
-                    if search_for_ordering_only:
-                        logger.info(
-                            "Performance: Skipping relevance filtering for ordering-only mode"
-                        )
-                        continue
-
-                    if info.reference_db_search_docs is None:
-                        logger.warning(
-                            "No reference docs found for relevance filtering"
-                        )
-                        continue
-
-                    llm_indices = relevant_sections_to_indices(
-                        relevance_sections=relevance_sections,
-                        items=[
-                            translate_db_search_doc_to_server_search_doc(doc)
-                            for doc in info.reference_db_search_docs
-                        ],
-                    )
-
-                    if info.dropped_indices:
-                        llm_indices = drop_llm_indices(
-                            llm_indices=llm_indices,
-                            search_docs=info.reference_db_search_docs,
-                            dropped_indices=info.dropped_indices,
-                        )
-
-                    yield LLMRelevanceFilterResponse(
-                        llm_selected_doc_indices=llm_indices
-                    )
-                elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
-                    yield FinalUsedContextDocsResponse(
-                        final_context_docs=packet.response
-                    )
-
-                elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
-                    img_generation_response = cast(
-                        list[ImageGenerationResponse], packet.response
-                    )
-
-                    file_ids = save_files(
-                        urls=[img.url for img in img_generation_response if img.url],
-                        base64_files=[
-                            img.image_data
-                            for img in img_generation_response
-                            if img.image_data
-                        ],
-                    )
-                    info.ai_message_files.extend(
-                        [
-                            FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
-                            for file_id in file_ids
-                        ]
-                    )
-                    yield FileChatDisplay(
-                        file_ids=[str(file_id) for file_id in file_ids]
-                    )
-                elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
-                    (
-                        info.qa_docs_response,
-                        info.reference_db_search_docs,
-                    ) = _handle_internet_search_tool_response_summary(
-                        packet=packet,
-                        db_session=db_session,
-                    )
-                    yield info.qa_docs_response
-                elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
-                    custom_tool_response = cast(CustomToolCallSummary, packet.response)
-
-                    if (
-                        custom_tool_response.response_type == "image"
-                        or custom_tool_response.response_type == "csv"
-                    ):
-                        file_ids = custom_tool_response.tool_result.file_ids
-                        info.ai_message_files.extend(
-                            [
-                                FileDescriptor(
-                                    id=str(file_id),
-                                    type=(
-                                        ChatFileType.IMAGE
-                                        if custom_tool_response.response_type == "image"
-                                        else ChatFileType.CSV
-                                    ),
-                                )
-                                for file_id in file_ids
-                            ]
-                        )
-                        yield FileChatDisplay(
-                            file_ids=[str(file_id) for file_id in file_ids]
-                        )
-                    else:
-                        yield CustomToolResponse(
-                            response=custom_tool_response.tool_result,
-                            tool_name=custom_tool_response.tool_name,
-                        )

             elif isinstance(packet, StreamStopInfo):
                 if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1291,22 +1334,46 @@ def stream_chat_message_objects(

         if isinstance(e, ToolCallException):
             yield StreamingError(error=error_msg, stack_trace=stack_trace)
-        else:
-            if llm:
-                client_error_msg = litellm_exception_to_error_msg(e, llm)
-                if llm.config.api_key and len(llm.config.api_key) > 2:
-                    error_msg = error_msg.replace(
-                        llm.config.api_key, "[REDACTED_API_KEY]"
-                    )
-                    stack_trace = stack_trace.replace(
-                        llm.config.api_key, "[REDACTED_API_KEY]"
-                    )
+        elif llm:
+            client_error_msg = litellm_exception_to_error_msg(e, llm)
+            if llm.config.api_key and len(llm.config.api_key) > 2:
+                client_error_msg = client_error_msg.replace(
+                    llm.config.api_key, "[REDACTED_API_KEY]"
+                )
+                stack_trace = stack_trace.replace(
+                    llm.config.api_key, "[REDACTED_API_KEY]"
+                )

             yield StreamingError(error=client_error_msg, stack_trace=stack_trace)

         db_session.rollback()
         return

+    yield from _post_llm_answer_processing(
+        answer=answer,
+        info_by_subq=info_by_subq,
+        tool_dict=tool_dict,
+        partial_response=partial_response,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
+        db_session=db_session,
+        chat_session_id=chat_session_id,
+        refined_answer_improvement=refined_answer_improvement,
+    )
+
+
+def _post_llm_answer_processing(
+    answer: Answer,
+    info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
+    tool_dict: dict[int, list[Tool]],
+    partial_response: PartialResponse,
+    llm_tokenizer_encode_func: Callable[[str], list[int]],
+    db_session: Session,
+    chat_session_id: UUID,
+    refined_answer_improvement: bool | None,
+) -> Generator[ChatPacket, None, None]:
+    """
+    Stores messages in the db and yields some final packets to the frontend
+    """
     # Post-LLM answer processing
     try:
         tool_name_to_tool_id: dict[str, int] = {}
@@ -56,7 +56,8 @@ def get_total_users_count(db_session: Session) -> int:

 async def get_user_count(only_admin_users: bool = False) -> int:
     async with get_async_session_with_tenant() as session:
-        stmt = select(func.count(User.id))
+        count_stmt = func.count(User.id)  # type: ignore
+        stmt = select(count_stmt)
         if only_admin_users:
             stmt = stmt.where(User.role == UserRole.ADMIN)
         result = await session.execute(stmt)
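Note on the hunk above: splitting func.count(User.id) onto its own line keeps the `# type: ignore` scoped to the single expression the older stubs flag, instead of silencing the whole select/where chain. A hedged sketch with a hypothetical model and an AsyncSession:

# Hedged sketch; Account is a hypothetical stand-in for the User model.
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "account"
    id: Mapped[int] = mapped_column(primary_key=True)
    is_admin: Mapped[bool] = mapped_column(default=False)


async def count_admins(session: AsyncSession) -> int:
    # Keeping the count expression on its own line means any needed
    # "# type: ignore" covers only this expression, not the whole statement.
    count_expr = func.count(Account.id)
    stmt = select(count_expr).where(Account.is_admin.is_(True))
    result = await session.execute(stmt)
    return int(result.scalar_one())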