Model inference for connector classifier on queries (#2137)

Shukant Pal
2024-09-08 14:46:00 -07:00
committed by GitHub
parent 3fa9676478
commit 362156f97e
7 changed files with 314 additions and 9 deletions

View File

@@ -88,3 +88,6 @@ HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = (
    os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", "").lower() == "true"
)
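
As a usage note (a sketch, not part of the commit): with the string parsing above, the classifier path is enabled only when the environment variable is set to the string "true", case-insensitively.

# Hypothetical enablement check; mirrors the parsing used above.
import os

os.environ["ENABLE_CONNECTOR_CLASSIFIER"] = "true"
assert os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", "").lower() == "true"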

View File

@@ -24,6 +24,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
@@ -301,6 +303,37 @@ class QueryAnalysisModel:
return response_model.is_keyword, response_model.keywords
class ConnectorClassificationModel:
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
):
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.connector_classification_endpoint = (
model_server_url + "/custom/connector-classification"
)
def predict(
self,
query: str,
available_connectors: list[str],
) -> list[str]:
connector_classification_request = ConnectorClassificationRequest(
available_connectors=available_connectors,
query=query,
)
response = requests.post(
self.connector_classification_endpoint,
json=connector_classification_request.dict(),
)
response.raise_for_status()
response_model = ConnectorClassificationResponse(**response.json())
return response_model.connectors
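
A minimal usage sketch for the client above (connector names are hypothetical; assumes a model server is reachable at the configured host and port):

# Ask the model server which of the available connectors the query filters on.
classifier = ConnectorClassificationModel()
connectors = classifier.predict(
    query="any github issues about rate limiting?",
    available_connectors=["github", "slack"],  # example values only
)
print(connectors)  # e.g. ["github"] when the model is confident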
def warm_up_retry(
func: Callable[..., Any],
tries: int = 20,

View File

@@ -3,12 +3,16 @@ import random
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_unique_document_sources
from danswer.db.engine import get_sqlalchemy_engine
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.natural_language_processing.search_nlp_models import (
ConnectorClassificationModel,
)
from danswer.prompts.constants import SOURCES_KEY
from danswer.prompts.filter_extration import FILE_SOURCE_WARNING
from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT
@@ -42,11 +46,38 @@ def _sample_document_sources(
return random.sample(valid_sources, num_sample)
def _sample_documents_using_custom_connector_classifier(
query: str,
valid_sources: list[DocumentSource],
) -> list[DocumentSource] | None:
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
available_connectors = list(
filter(
lambda conn: conn.lower() in query_joined,
[item.value for item in valid_sources],
)
)
if not available_connectors:
return None
connectors = ConnectorClassificationModel().predict(query, available_connectors)
return strings_to_document_sources(connectors) if connectors else None
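
A worked example of the substring prefilter above (illustrative only): the query is lowercased and stripped to alphanumerics before the containment check, so spacing and punctuation in the query do not matter.

# "Any GitHub issues?" collapses to "anygithubissues";
# "github" matches, "slack" does not.
query = "Any GitHub issues?"
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
assert query_joined == "anygithubissues"
assert "github" in query_joined
assert "slack" not in query_joined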
def extract_source_filter(
query: str, llm: LLM, db_session: Session
) -> list[DocumentSource] | None:
"""Returns a list of valid sources for search or None if no specific sources were detected"""
valid_sources = fetch_unique_document_sources(db_session)
if not valid_sources:
return None
if ENABLE_CONNECTOR_CLASSIFIER:
return _sample_documents_using_custom_connector_classifier(query, valid_sources)
def _get_source_filter_messages(
query: str,
valid_sources: list[DocumentSource],
@@ -146,10 +177,6 @@ def extract_source_filter(
logger.warning("LLM failed to provide a valid Source Filter output")
return None
valid_sources = fetch_unique_document_sources(db_session)
if not valid_sources:
return None
messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))

View File

@@ -3,15 +3,21 @@ import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from danswer.utils.logger import setup_logger
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.danswer_torch_model import ConnectorClassifier
from model_server.danswer_torch_model import HybridClassifier
from model_server.utils import simple_log_function_time
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@@ -19,10 +25,55 @@ logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
"distilbert-base-uncased"
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
def get_local_connector_classifier(
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
) -> ConnectorClassifier:
global _CONNECTOR_CLASSIFIER_MODEL
if _CONNECTOR_CLASSIFIER_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.warning(f"Failed to load model directly: {e}")
try:
# Attempt to download the model snapshot
logger.info(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.error(
f"Failed to load model even after attempted snapshot download: {e}"
)
raise
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> AutoTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
@@ -61,6 +112,74 @@ def get_local_intent_model(
return _INTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: PreTrainedTokenizer,
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
token and then the user query.
"""
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
for connector in connectors:
connector_token_ids = tokenizer(
connector,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
connector_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([connector_token_end_id], dtype=torch.long),
),
dim=-1,
)
query_token_ids = tokenizer(
query,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
query_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
),
dim=-1,
)
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
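
A layout sketch of the prompt assembled above ([CONNECTOR_END] stands in for the model's connector-end token; the concrete choice of [unused0] below is an assumption for illustration):

# For connectors ["github", "slack"] and query "any github issues", the
# assembled sequence is, schematically:
#   [CLS] github [CONNECTOR_END] slack [CONNECTOR_END] any github issues [SEP]
# with an attention mask of all 1s over the full sequence.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
end_id = tokenizer.get_vocab()["[unused0]"]  # hypothetical end-token choice
input_ids, attention_mask = tokenize_connector_classification_query(
    ["github", "slack"], "any github issues", tokenizer, end_id
)
assert input_ids.shape == attention_mask.shape  # both (1, seq_len)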
def warm_up_connector_classifier_model() -> None:
logger.info(
f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
)
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
connector_classifier = get_local_connector_classifier()
input_ids, attention_mask = tokenize_connector_classification_query(
["GitHub"],
"danswer classifier query google doc",
connector_classifier_tokenizer,
connector_classifier.connector_end_token_id,
)
input_ids = input_ids.to(connector_classifier.device)
attention_mask = attention_mask.to(connector_classifier.device)
connector_classifier(input_ids, attention_mask)
def warm_up_intent_model() -> None:
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
intent_tokenizer = get_intent_model_tokenizer()
@@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]:
return cleaned_words
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
tokenizer = get_connector_classifier_tokenizer()
model = get_local_connector_classifier()
connector_names = req.available_connectors
input_ids, attention_mask = tokenize_connector_classification_query(
connector_names,
req.query,
tokenizer,
model.connector_end_token_id,
)
input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)
global_confidence, classifier_confidence = model(input_ids, attention_mask)
if global_confidence.item() < 0.5:
return []
passed_connectors = []
for i, connector_name in enumerate(connector_names):
if classifier_confidence.view(-1)[i].item() > 0.5:
passed_connectors.append(connector_name)
return passed_connectors
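
The gating above is two-stage: the global head decides whether the query filters on any connector at all, and the per-connector scores (one per connector-end token, in input order) pick which ones. A numeric sketch with made-up confidences:

# Made-up confidences for ["github", "slack", "web"]:
global_confidence = 0.91           # > 0.5, so per-connector scores are consulted
per_connector = [0.83, 0.12, 0.64]
passed = [
    name
    for name, conf in zip(["github", "slack", "web"], per_connector)
    if conf > 0.5
]
assert passed == ["github", "web"]
# Had global_confidence been < 0.5, the result would be [] regardless.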
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
tokenizer = get_intent_model_tokenizer()
model_input = tokenizer(
@@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
return is_keyword_sequence, cleaned_keywords
@router.post("/connector-classification")
async def process_connector_classification_request(
classification_request: ConnectorClassificationRequest,
) -> ConnectorClassificationResponse:
if INDEXING_ONLY:
raise RuntimeError(
"Indexing model server should not call connector classification endpoint"
)
if len(classification_request.available_connectors) == 0:
return ConnectorClassificationResponse(connectors=[])
connectors = run_connector_classification(classification_request)
return ConnectorClassificationResponse(connectors=connectors)
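
For reference, a direct request against the endpoint above (the localhost address and port are assumptions; the route is mounted under the /custom prefix):

# Hypothetical direct call to the model server.
import requests

resp = requests.post(
    "http://localhost:9000/custom/connector-classification",  # assumed address
    json={"available_connectors": ["github", "slack"], "query": "github prs?"},
)
resp.raise_for_status()
print(resp.json())  # {"connectors": [...]}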
@router.post("/query-analysis")
async def process_analysis_request(
intent_request: IntentRequest,

View File

@@ -4,7 +4,8 @@ import os
import torch
import torch.nn as nn
from transformers import DistilBertConfig # type: ignore
from transformers import DistilBertModel
from transformers import DistilBertModel # type: ignore
from transformers import DistilBertTokenizer # type: ignore
class HybridClassifier(nn.Module):
@@ -21,7 +22,6 @@ class HybridClassifier(nn.Module):
self.distilbert.config.dim, self.distilbert.config.dim
)
self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
self.device = torch.device("cpu")
@@ -36,8 +36,7 @@ class HybridClassifier(nn.Module):
# Intent classification on the CLS token
cls_token_state = sequence_output[:, 0, :]
pre_classifier_out = self.pre_classifier(cls_token_state)
dropout_out = self.dropout(pre_classifier_out)
intent_logits = self.intent_classifier(dropout_out)
intent_logits = self.intent_classifier(pre_classifier_out)
# Keyword classification on all tokens
token_logits = self.keyword_classifier(sequence_output)
@@ -72,3 +71,70 @@ class HybridClassifier(nn.Module):
param.requires_grad = False
return model
class ConnectorClassifier(nn.Module):
def __init__(self, config: DistilBertConfig) -> None:
super().__init__()
self.config = config
self.distilbert = DistilBertModel(config)
self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Token marking the end of each connector name; the match classifier runs on its hidden state
self.connector_end_token_id = self.tokenizer.get_vocab()[
self.config.connector_end_token
]
self.device = torch.device("cpu")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
cls_hidden_states = hidden_states[
:, 0, :
]  # The first token is always [CLS]; it is prepended during tokenization
global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
global_confidence = torch.sigmoid(global_logits).view(-1)
connector_end_position_ids = input_ids == self.connector_end_token_id
connector_end_hidden_states = hidden_states[connector_end_position_ids]
classifier_output = self.connector_match_classifier(connector_end_hidden_states)
classifier_confidence = torch.sigmoid(classifier_output).view(-1)
return global_confidence, classifier_confidence
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
state_dict = torch.load(
os.path.join(repo_dir, "pytorch_model.pt"),
map_location=device,
weights_only=True,
)
model = cls(config)
model.load_state_dict(state_dict)
model.to(device)
model.device = device
model.eval()
for param in model.parameters():
param.requires_grad = False
return model
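
A minimal shape sketch of the two heads above (dummy tensors; the 768 hidden size and the token id are assumptions, not the trained model): the global head scores the [CLS] state once per query, while the match head scores each connector-end position.

# Dummy-tensor sketch of ConnectorClassifier's two heads (shapes only).
import torch

batch, seq_len, dim = 1, 12, 768           # assumed DistilBERT hidden size
hidden_states = torch.randn(batch, seq_len, dim)
connector_end_token_id = 999               # hypothetical token id
input_ids = torch.randint(0, 998, (batch, seq_len))
input_ids[0, 3] = connector_end_token_id   # pretend two connector names end
input_ids[0, 7] = connector_end_token_id   # at positions 3 and 7

global_head = torch.nn.Linear(dim, 1)      # scores the [CLS] state
match_head = torch.nn.Linear(dim, 1)       # scores each connector-end state

global_confidence = torch.sigmoid(global_head(hidden_states[:, 0, :])).view(-1)
end_positions = input_ids == connector_end_token_id
classifier_confidence = torch.sigmoid(
    match_head(hidden_states[end_positions])
).view(-1)

assert global_confidence.shape == (1,)      # one gate per query
assert classifier_confidence.shape == (2,)  # one score per connector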

View File

@@ -16,9 +16,12 @@ INDEXING_MODEL_SERVER_PORT = int(
)
# Danswer custom Deep Learning Models
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512

View File

@@ -7,6 +7,15 @@ from shared_configs.enums import RerankerProvider
Embedding = list[float]
class ConnectorClassificationRequest(BaseModel):
available_connectors: list[str]
query: str
class ConnectorClassificationResponse(BaseModel):
connectors: list[str]
class EmbedRequest(BaseModel):
texts: list[str]
# Can be None for cloud embedding model requests; error handling logic exists for other cases