
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from danswer.utils.logger import setup_logger
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.danswer_torch_model import ConnectorClassifier
from model_server.danswer_torch_model import HybridClassifier
from model_server.utils import simple_log_function_time
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
"distilbert-base-uncased"
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
def get_local_connector_classifier(
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
) -> ConnectorClassifier:
global _CONNECTOR_CLASSIFIER_MODEL
if _CONNECTOR_CLASSIFIER_MODEL is None:
try:
            # Try the local HF cache first; local_files_only=True avoids any network call
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
            logger.warning(f"Failed to load model from local cache: {e}")
try:
# Attempt to download the model snapshot
logger.info(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.error(
f"Failed to load model even after attempted snapshot download: {e}"
)
raise
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> AutoTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
return _INTENT_TOKENIZER
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
try:
            # Try the local HF cache first; local_files_only=True avoids any network call
logger.notice(f"Loading model from local cache: {model_name_or_path}")
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
logger.notice(f"Loaded model from local cache: {local_path}")
except Exception as e:
            logger.warning(f"Failed to load model from local cache: {e}")
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load model even after attempted snapshot download: {e}"
)
raise
return _INTENT_MODEL
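
# Both loaders above follow the same lazy-singleton pattern: resolve the model
# from the local HF cache first (no network call), fall back to downloading the
# snapshot, and keep the loaded model in a module-level global. A hypothetical
# sketch of the resulting call behavior:
#
#     model = get_local_intent_model()  # first call: cache hit or download
#     model = get_local_intent_model()  # later calls: returns the same
#                                       # in-memory singleton immediately
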
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: PreTrainedTokenizer,
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
token and then the user query.
"""
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
for connector in connectors:
connector_token_ids = tokenizer(
connector,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
connector_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([connector_token_end_id], dtype=torch.long),
),
dim=-1,
)
query_token_ids = tokenizer(
query,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
query_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
),
dim=-1,
)
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
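
# Illustrative example (the token strings are hypothetical WordPiece output,
# not taken from a real tokenizer run): for connectors ["GitHub", "Jira"] and
# the query "open pull requests", the constructed prompt looks like
#
#     [CLS] git ##hub [CONN_END] jira [CONN_END] open pull requests [SEP]
#
# where [CONN_END] stands for the model's connector_end_token_id, and both
# returned tensors have shape (1, seq_len).
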
def warm_up_connector_classifier_model() -> None:
    logger.notice(
        f"Warming up connector classifier model: {CONNECTOR_CLASSIFIER_MODEL_TAG}"
    )
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
connector_classifier = get_local_connector_classifier()
input_ids, attention_mask = tokenize_connector_classification_query(
["GitHub"],
"danswer classifier query google doc",
connector_classifier_tokenizer,
connector_classifier.connector_end_token_id,
)
input_ids = input_ids.to(connector_classifier.device)
attention_mask = attention_mask.to(connector_classifier.device)
connector_classifier(input_ids, attention_mask)
def warm_up_intent_model() -> None:
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
intent_tokenizer = get_intent_model_tokenizer()
tokens = intent_tokenizer(
MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
)
intent_model = get_local_intent_model()
device = intent_model.device
intent_model(
query_ids=tokens["input_ids"].to(device),
query_mask=tokens["attention_mask"].to(device),
)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
device = intent_model.device
outputs = intent_model(
query_ids=tokens["input_ids"].to(device),
query_mask=tokens["attention_mask"].to(device),
)
token_logits = outputs["token_logits"]
intent_logits = outputs["intent_logits"]
# Move tensors to CPU before applying softmax and converting to numpy
intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
# Extract the probabilities for the positive class (index 1) for each token
token_positive_probs = token_probabilities[:, 1].tolist()
return intent_probabilities.tolist(), token_positive_probs
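
# Sketch of how the two returned lists are interpreted downstream (the numbers
# are made up for illustration):
#
#     intent_probs, token_probs = run_inference(tokens)
#     # intent_probs, e.g. [0.85, 0.15]: class probabilities for the whole
#     # query; index 0 is compared against the keyword threshold in
#     # run_analysis below
#     # token_probs, e.g. [0.9, 0.2, 0.8]: per-token probability of the
#     # positive "is a keyword" class, one entry per input token
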
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids)
    if len(tokens) != len(is_keyword):
raise ValueError("Length of tokens and keyword predictions must match")
if input_ids[0] == tokenizer.cls_token_id:
tokens = tokens[1:]
is_keyword = is_keyword[1:]
if input_ids[-1] == tokenizer.sep_token_id:
tokens = tokens[:-1]
is_keyword = is_keyword[:-1]
unk_token = tokenizer.unk_token
if unk_token in tokens:
raise ValueError("Unknown token detected in the input")
keywords = []
current_keyword = ""
for ind, token in enumerate(tokens):
if is_keyword[ind]:
if token.startswith("##"):
current_keyword += token[2:]
else:
if current_keyword:
keywords.append(current_keyword)
current_keyword = token
        else:
            # If a continuation subword ("##...") of an in-progress keyword was
            # mispredicted as non-keyword, attach it anyway to complete the
            # word; otherwise flush the pending keyword
            if token.startswith("##") and current_keyword:
                current_keyword += token[2:]
            elif current_keyword:
                keywords.append(current_keyword)
                current_keyword = ""
if current_keyword:
keywords.append(current_keyword)
return keywords
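
# Worked example (hypothetical tokenization): for the query "github issues",
# a WordPiece tokenizer might produce
#
#     tokens     = ["git", "##hub", "issues"]
#     is_keyword = [True,  True,    True]
#
# and map_keywords would merge the "##hub" continuation back into its word
# start, returning ["github", "issues"].
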
def clean_keywords(keywords: list[str]) -> list[str]:
cleaned_words = []
for word in keywords:
word = word[:-2] if word.endswith("'s") else word
word = word.replace("/", " ")
word = word.replace("'", "").replace('"', "")
cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
return cleaned_words
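
# For instance (hypothetical inputs), clean_keywords(["danswer's", "foo/bar",
# '"quoted"']) yields ["danswer", "foo", "bar", "quoted"]: possessives are
# stripped, slashes split words into parts, and quote characters are removed.
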
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
tokenizer = get_connector_classifier_tokenizer()
model = get_local_connector_classifier()
connector_names = req.available_connectors
input_ids, attention_mask = tokenize_connector_classification_query(
connector_names,
req.query,
tokenizer,
model.connector_end_token_id,
)
input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)
global_confidence, classifier_confidence = model(input_ids, attention_mask)
if global_confidence.item() < 0.5:
return []
passed_connectors = []
for i, connector_name in enumerate(connector_names):
if classifier_confidence.view(-1)[i].item() > 0.5:
passed_connectors.append(connector_name)
return passed_connectors
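
# Example of the two-level thresholding above (made-up confidence values): with
# global_confidence 0.9 and per-connector confidences [0.7, 0.3] for
# ["GitHub", "Jira"], only "GitHub" is returned; a global_confidence of 0.4
# short-circuits to an empty list, meaning "apply no connector filter".
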
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
tokenizer = get_intent_model_tokenizer()
model_input = tokenizer(
intent_req.query, return_tensors="pt", truncation=False, padding=False
)
if len(model_input.input_ids[0]) > 512:
# If the user text is too long, assume it is semantic and keep all words
return True, intent_req.query.split()
intent_probs, token_probs = run_inference(model_input)
is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
keyword_preds = [
token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
]
try:
keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
except Exception as e:
logger.error(
f"Failed to extract keywords for query: {intent_req.query} due to {e}"
)
# Fallback to keeping all words
keywords = intent_req.query.split()
cleaned_keywords = clean_keywords(keywords)
return is_keyword_sequence, cleaned_keywords
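
# End-to-end sketch (hypothetical numbers): for the query "github issues" with
# keyword_percent_threshold 0.5, intent_probs of [0.8, 0.2] marks the query as
# keyword-style (0.8 >= 0.5), token_probs of [0.9, 0.7] mark both tokens as
# keywords, and the cleaned keywords come back as ["github", "issues"].
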
@router.post("/connector-classification")
async def process_connector_classification_request(
classification_request: ConnectorClassificationRequest,
) -> ConnectorClassificationResponse:
if INDEXING_ONLY:
raise RuntimeError(
"Indexing model server should not call connector classification endpoint"
)
if len(classification_request.available_connectors) == 0:
return ConnectorClassificationResponse(connectors=[])
connectors = run_connector_classification(classification_request)
return ConnectorClassificationResponse(connectors=connectors)
@router.post("/query-analysis")
async def process_analysis_request(
intent_request: IntentRequest,
) -> IntentResponse:
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
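
# Example requests against these endpoints, assuming the model server is
# reachable at http://localhost:9000 (host and port are deployment-specific):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:9000/custom/query-analysis",
#         json={"query": "github issues", "keyword_percent_threshold": 0.5},
#     )
#     print(resp.json())  # {"is_keyword": ..., "keywords": [...]}
#
#     resp = requests.post(
#         "http://localhost:9000/custom/connector-classification",
#         json={"available_connectors": ["GitHub", "Jira"], "query": "open PRs"},
#     )
#     print(resp.json())  # {"connectors": [...]}
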