mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-21 14:12:42 +02:00)
Model inference for connector classifier on queries (#2137)
danswer/configs/chat_configs.py
@@ -88,3 +88,6 @@ HARD_DELETE_CHATS = False
 # Internet Search
 BING_API_KEY = os.environ.get("BING_API_KEY") or None
 
+# Enable in-house model for detecting connector-based filtering in queries
+ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
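Aside (not part of the diff): os.environ.get(..., False) returns the raw string whenever the variable is set, so any non-empty value, even "false", enables the flag in a truthiness check. A minimal sketch of that pitfall:

import os

# Any non-empty string is truthy, including "false".
os.environ["ENABLE_CONNECTOR_CLASSIFIER"] = "false"
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
print(bool(ENABLE_CONNECTOR_CLASSIFIER))  # True; set the variable only to enable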
danswer/natural_language_processing/search_nlp_models.py
@@ -24,6 +24,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbeddingProvider
 from shared_configs.enums import EmbedTextType
 from shared_configs.enums import RerankerProvider
+from shared_configs.model_server_models import ConnectorClassificationRequest
+from shared_configs.model_server_models import ConnectorClassificationResponse
 from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
@@ -301,6 +303,37 @@ class QueryAnalysisModel:
         return response_model.is_keyword, response_model.keywords
 
 
+class ConnectorClassificationModel:
+    def __init__(
+        self,
+        model_server_host: str = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ):
+        model_server_url = build_model_server_url(model_server_host, model_server_port)
+        self.connector_classification_endpoint = (
+            model_server_url + "/custom/connector-classification"
+        )
+
+    def predict(
+        self,
+        query: str,
+        available_connectors: list[str],
+    ) -> list[str]:
+        connector_classification_request = ConnectorClassificationRequest(
+            available_connectors=available_connectors,
+            query=query,
+        )
+        response = requests.post(
+            self.connector_classification_endpoint,
+            json=connector_classification_request.dict(),
+        )
+        response.raise_for_status()
+
+        response_model = ConnectorClassificationResponse(**response.json())
+
+        return response_model.connectors
+
+
 def warm_up_retry(
     func: Callable[..., Any],
     tries: int = 20,
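A minimal usage sketch of the new client (assumes a model server is running at the default host/port and that the listed connector names are deployed sources):

from danswer.natural_language_processing.search_nlp_models import (
    ConnectorClassificationModel,
)

# Posts the query plus candidate connectors to the model server and returns
# only the connectors the classifier is confident about.
model = ConnectorClassificationModel()
relevant = model.predict(
    query="open github pull requests about auth",
    available_connectors=["github", "confluence"],
)
print(relevant)  # e.g. ["github"] when the classifier is confident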
danswer/secondary_llm_flows/source_filter.py
@@ -3,12 +3,16 @@ import random
 
 from sqlalchemy.orm import Session
 
+from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
+from danswer.natural_language_processing.search_nlp_models import (
+    ConnectorClassificationModel,
+)
 from danswer.prompts.constants import SOURCES_KEY
 from danswer.prompts.filter_extration import FILE_SOURCE_WARNING
 from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT
@@ -42,11 +46,38 @@ def _sample_document_sources(
     return random.sample(valid_sources, num_sample)
 
 
+def _sample_documents_using_custom_connector_classifier(
+    query: str,
+    valid_sources: list[DocumentSource],
+) -> list[DocumentSource] | None:
+    query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
+    available_connectors = list(
+        filter(
+            lambda conn: conn.lower() in query_joined,
+            [item.value for item in valid_sources],
+        )
+    )
+
+    if not available_connectors:
+        return None
+
+    connectors = ConnectorClassificationModel().predict(query, available_connectors)
+
+    return strings_to_document_sources(connectors) if connectors else None
+
+
 def extract_source_filter(
     query: str, llm: LLM, db_session: Session
 ) -> list[DocumentSource] | None:
     """Returns a list of valid sources for search or None if no specific sources were detected"""
+    valid_sources = fetch_unique_document_sources(db_session)
+    if not valid_sources:
+        return None
+
+    if ENABLE_CONNECTOR_CLASSIFIER:
+        return _sample_documents_using_custom_connector_classifier(query, valid_sources)
 
     def _get_source_filter_messages(
         query: str,
         valid_sources: list[DocumentSource],
@@ -146,10 +177,6 @@ def extract_source_filter(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None
 
-    valid_sources = fetch_unique_document_sources(db_session)
-    if not valid_sources:
-        return None
-
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     model_output = message_to_string(llm.invoke(filled_llm_prompt))
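To make the prefilter concrete (illustrative values only): the query is lower-cased and stripped to alphanumerics, and a connector survives as a candidate only if its name appears as a substring of that normalized string.

# Mirrors the normalization in _sample_documents_using_custom_connector_classifier.
query = "Any GitHub PRs about auth?"
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
print(query_joined)                  # "anygithubprsaboutauth"
print("github" in query_joined)      # True, kept as a candidate
print("confluence" in query_joined)  # False, dropped before the model call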
model_server/custom_models.py
@@ -3,15 +3,21 @@ import torch.nn.functional as F
 from fastapi import APIRouter
 from huggingface_hub import snapshot_download  # type: ignore
 from transformers import AutoTokenizer  # type: ignore
-from transformers import BatchEncoding
+from transformers import BatchEncoding  # type: ignore
+from transformers import PreTrainedTokenizer  # type: ignore
 
 from danswer.utils.logger import setup_logger
 from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.danswer_torch_model import ConnectorClassifier
 from model_server.danswer_torch_model import HybridClassifier
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
+from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import INTENT_MODEL_TAG
 from shared_configs.configs import INTENT_MODEL_VERSION
+from shared_configs.model_server_models import ConnectorClassificationRequest
+from shared_configs.model_server_models import ConnectorClassificationResponse
 from shared_configs.model_server_models import IntentRequest
 from shared_configs.model_server_models import IntentResponse
@@ -19,10 +25,55 @@ logger = setup_logger()
 
 router = APIRouter(prefix="/custom")
 
+_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
+_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
+
 _INTENT_TOKENIZER: AutoTokenizer | None = None
 _INTENT_MODEL: HybridClassifier | None = None
 
 
+def get_connector_classifier_tokenizer() -> AutoTokenizer:
+    global _CONNECTOR_CLASSIFIER_TOKENIZER
+    if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
+        # The tokenizer details are not uploaded to the HF hub since it's just the
+        # unmodified distilbert tokenizer.
+        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
+            "distilbert-base-uncased"
+        )
+    return _CONNECTOR_CLASSIFIER_TOKENIZER
+
+
+def get_local_connector_classifier(
+    model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
+    tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
+) -> ConnectorClassifier:
+    global _CONNECTOR_CLASSIFIER_MODEL
+    if _CONNECTOR_CLASSIFIER_MODEL is None:
+        try:
+            # Calculate where the cache should be, then load from local if available
+            local_path = snapshot_download(
+                repo_id=model_name_or_path, revision=tag, local_files_only=True
+            )
+            _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
+                local_path
+            )
+        except Exception as e:
+            logger.warning(f"Failed to load model directly: {e}")
+            try:
+                # Attempt to download the model snapshot
+                logger.info(f"Downloading model snapshot for {model_name_or_path}")
+                local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
+                _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
+                    local_path
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to load model even after attempted snapshot download: {e}"
+                )
+                raise
+    return _CONNECTOR_CLASSIFIER_MODEL
+
+
 def get_intent_model_tokenizer() -> AutoTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
@@ -61,6 +112,74 @@ def get_local_intent_model(
     return _INTENT_MODEL
 
 
+def tokenize_connector_classification_query(
+    connectors: list[str],
+    query: str,
+    tokenizer: PreTrainedTokenizer,
+    connector_token_end_id: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
+
+    The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
+    token and then the user query.
+    """
+
+    input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
+
+    for connector in connectors:
+        connector_token_ids = tokenizer(
+            connector,
+            add_special_tokens=False,
+            return_tensors="pt",
+        )
+
+        input_ids = torch.cat(
+            (
+                input_ids,
+                connector_token_ids["input_ids"].squeeze(dim=0),
+                torch.tensor([connector_token_end_id], dtype=torch.long),
+            ),
+            dim=-1,
+        )
+    query_token_ids = tokenizer(
+        query,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )
+
+    input_ids = torch.cat(
+        (
+            input_ids,
+            query_token_ids["input_ids"].squeeze(dim=0),
+            torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
+        ),
+        dim=-1,
+    )
+    attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
+
+    return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
+
+
+def warm_up_connector_classifier_model() -> None:
+    logger.info(
+        f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
+    )
+    connector_classifier_tokenizer = get_connector_classifier_tokenizer()
+    connector_classifier = get_local_connector_classifier()
+
+    input_ids, attention_mask = tokenize_connector_classification_query(
+        ["GitHub"],
+        "danswer classifier query google doc",
+        connector_classifier_tokenizer,
+        connector_classifier.connector_end_token_id,
+    )
+    input_ids = input_ids.to(connector_classifier.device)
+    attention_mask = attention_mask.to(connector_classifier.device)
+
+    connector_classifier(input_ids, attention_mask)
+
+
 def warm_up_intent_model() -> None:
     logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
     intent_tokenizer = get_intent_model_tokenizer()
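Per the docstring, the assembled prompt reads [CLS] connector [END] connector [END] ... query [SEP]. A hedged shape check (uses the stock distilbert tokenizer and an arbitrary stand-in id 999 for the real connector-end token, which normally comes from the model config):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Both outputs gain a leading batch dimension of 1.
input_ids, attention_mask = tokenize_connector_classification_query(
    ["github", "confluence"], "open github pull requests about auth", tokenizer, 999
)
print(input_ids.shape, attention_mask.shape)  # both (1, sequence_length)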
@@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]:
     return cleaned_words
 
 
+def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
+    tokenizer = get_connector_classifier_tokenizer()
+    model = get_local_connector_classifier()
+
+    connector_names = req.available_connectors
+
+    input_ids, attention_mask = tokenize_connector_classification_query(
+        connector_names,
+        req.query,
+        tokenizer,
+        model.connector_end_token_id,
+    )
+    input_ids = input_ids.to(model.device)
+    attention_mask = attention_mask.to(model.device)
+
+    global_confidence, classifier_confidence = model(input_ids, attention_mask)
+
+    if global_confidence.item() < 0.5:
+        return []
+
+    passed_connectors = []
+
+    for i, connector_name in enumerate(connector_names):
+        if classifier_confidence.view(-1)[i].item() > 0.5:
+            passed_connectors.append(connector_name)
+
+    return passed_connectors
+
+
 def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
     tokenizer = get_intent_model_tokenizer()
     model_input = tokenizer(
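The gating is two-stage: the global score decides whether the query names any connector at all, then per-connector scores (one per connector-end token, in input order) select the survivors. With illustrative numbers:

global_confidence = 0.82              # above 0.5, so proceed
classifier_confidence = [0.91, 0.12]  # one score per connector, in input order
connector_names = ["github", "confluence"]
passed = [
    name
    for name, score in zip(connector_names, classifier_confidence)
    if score > 0.5
]
print(passed)  # ["github"]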
@@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
     return is_keyword_sequence, cleaned_keywords
 
 
+@router.post("/connector-classification")
+async def process_connector_classification_request(
+    classification_request: ConnectorClassificationRequest,
+) -> ConnectorClassificationResponse:
+    if INDEXING_ONLY:
+        raise RuntimeError(
+            "Indexing model server should not call connector classification endpoint"
+        )
+
+    if len(classification_request.available_connectors) == 0:
+        return ConnectorClassificationResponse(connectors=[])
+
+    connectors = run_connector_classification(classification_request)
+    return ConnectorClassificationResponse(connectors=connectors)
+
+
 @router.post("/query-analysis")
 async def process_analysis_request(
     intent_request: IntentRequest,
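For reference, a raw request against the new route might look like this (host and port are assumptions for a locally running model server; the payload shape follows the pydantic models added below):

import requests

resp = requests.post(
    "http://localhost:9000/custom/connector-classification",
    json={
        "available_connectors": ["github", "confluence"],
        "query": "open github pull requests about auth",
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. {"connectors": ["github"]}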
model_server/danswer_torch_model.py
@@ -4,7 +4,8 @@ import os
 import torch
 import torch.nn as nn
 from transformers import DistilBertConfig  # type: ignore
-from transformers import DistilBertModel
+from transformers import DistilBertModel  # type: ignore
+from transformers import DistilBertTokenizer  # type: ignore
 
 
 class HybridClassifier(nn.Module):
@@ -21,7 +22,6 @@ class HybridClassifier(nn.Module):
             self.distilbert.config.dim, self.distilbert.config.dim
         )
         self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
-        self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
 
         self.device = torch.device("cpu")
 
@@ -36,8 +36,7 @@ class HybridClassifier(nn.Module):
         # Intent classification on the CLS token
         cls_token_state = sequence_output[:, 0, :]
         pre_classifier_out = self.pre_classifier(cls_token_state)
-        dropout_out = self.dropout(pre_classifier_out)
-        intent_logits = self.intent_classifier(dropout_out)
+        intent_logits = self.intent_classifier(pre_classifier_out)
 
         # Keyword classification on all tokens
         token_logits = self.keyword_classifier(sequence_output)
@@ -72,3 +71,70 @@ class HybridClassifier(nn.Module):
             param.requires_grad = False
 
         return model
+
+
+class ConnectorClassifier(nn.Module):
+    def __init__(self, config: DistilBertConfig) -> None:
+        super().__init__()
+
+        self.config = config
+        self.distilbert = DistilBertModel(config)
+        self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
+
+        # Token indicating end of connector name, and on which classifier is used
+        self.connector_end_token_id = self.tokenizer.get_vocab()[
+            self.config.connector_end_token
+        ]
+
+        self.device = torch.device("cpu")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.distilbert(
+            input_ids=input_ids, attention_mask=attention_mask
+        ).last_hidden_state
+
+        cls_hidden_states = hidden_states[
+            :, 0, :
+        ]  # Take leap of faith that first token is always [CLS]
+        global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
+        global_confidence = torch.sigmoid(global_logits).view(-1)
+
+        connector_end_position_ids = input_ids == self.connector_end_token_id
+        connector_end_hidden_states = hidden_states[connector_end_position_ids]
+        classifier_output = self.connector_match_classifier(connector_end_hidden_states)
+        classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
+
+        return global_confidence, classifier_confidence
+
+    @classmethod
+    def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
+        config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
+        device = (
+            torch.device("cuda")
+            if torch.cuda.is_available()
+            else torch.device("mps")
+            if torch.backends.mps.is_available()
+            else torch.device("cpu")
+        )
+        state_dict = torch.load(
+            os.path.join(repo_dir, "pytorch_model.pt"),
+            map_location=device,
+            weights_only=True,
+        )
+
+        model = cls(config)
+        model.load_state_dict(state_dict)
+        model.to(device)
+        model.device = device
+        model.eval()
+
+        for param in model.parameters():
+            param.requires_grad = False
+
+        return model
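The per-connector readout relies on boolean-mask indexing: input_ids == connector_end_token_id gives a (batch, seq) mask, and indexing the hidden states with it flattens to one row per connector-end token. A toy-shape sketch (stand-in id 999, not the real vocab):

import torch

hidden_states = torch.randn(1, 5, 4)               # (batch, seq, hidden)
input_ids = torch.tensor([[101, 7, 999, 7, 102]])  # 999 marks one connector end
mask = input_ids == 999                            # (1, 5) boolean mask
print(hidden_states[mask].shape)                   # torch.Size([1, 4]), one row per end token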
shared_configs/configs.py
@@ -16,9 +16,12 @@ INDEXING_MODEL_SERVER_PORT = int(
 )
 
 # Danswer custom Deep Learning Models
+CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
+CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
 INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
 INTENT_MODEL_TAG = "v1.0.3"
 
 
 # Bi-Encoder, other details
 DOC_EMBEDDING_CONTEXT_SIZE = 512
shared_configs/model_server_models.py
@@ -7,6 +7,15 @@ from shared_configs.enums import RerankerProvider
 Embedding = list[float]
 
 
+class ConnectorClassificationRequest(BaseModel):
+    available_connectors: list[str]
+    query: str
+
+
+class ConnectorClassificationResponse(BaseModel):
+    connectors: list[str]
+
+
 class EmbedRequest(BaseModel):
     texts: list[str]
     # Can be none for cloud embedding model requests, error handling logic exists for other cases
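A round-trip sketch of the new wire models (pydantic v1 style .dict(), matching the client code above):

from shared_configs.model_server_models import (
    ConnectorClassificationRequest,
    ConnectorClassificationResponse,
)

req = ConnectorClassificationRequest(
    available_connectors=["github", "confluence"],
    query="open github pull requests about auth",
)
print(req.dict())  # {"available_connectors": [...], "query": "..."}
resp = ConnectorClassificationResponse(connectors=["github"])
print(resp.connectors)  # ["github"]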