Model inference for connector classifier on queries (#2137)
@@ -88,3 +88,6 @@ HARD_DELETE_CHATS = False
 
 # Internet Search
 BING_API_KEY = os.environ.get("BING_API_KEY") or None
+
+# Enable in-house model for detecting connector-based filtering in queries
+ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)

@@ -24,6 +24,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbeddingProvider
 from shared_configs.enums import EmbedTextType
 from shared_configs.enums import RerankerProvider
+from shared_configs.model_server_models import ConnectorClassificationRequest
+from shared_configs.model_server_models import ConnectorClassificationResponse
 from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse

@@ -301,6 +303,37 @@ class QueryAnalysisModel:
         return response_model.is_keyword, response_model.keywords
 
 
+class ConnectorClassificationModel:
+    def __init__(
+        self,
+        model_server_host: str = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ):
+        model_server_url = build_model_server_url(model_server_host, model_server_port)
+        self.connector_classification_endpoint = (
+            model_server_url + "/custom/connector-classification"
+        )
+
+    def predict(
+        self,
+        query: str,
+        available_connectors: list[str],
+    ) -> list[str]:
+        connector_classification_request = ConnectorClassificationRequest(
+            available_connectors=available_connectors,
+            query=query,
+        )
+
+        response = requests.post(
+            self.connector_classification_endpoint,
+            json=connector_classification_request.dict(),
+        )
+        response.raise_for_status()
+
+        response_model = ConnectorClassificationResponse(**response.json())
+
+        return response_model.connectors
+
+
 def warm_up_retry(
     func: Callable[..., Any],
     tries: int = 20,

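Editor's note: a minimal sketch of how this client could be exercised; the host, port, and connector names below are illustrative assumptions, not part of the diff.

```python
# Hedged usage sketch; assumes a Danswer model server is reachable on
# localhost:9000 and that "github" and "web" connectors exist in the deployment.
from danswer.natural_language_processing.search_nlp_models import (
    ConnectorClassificationModel,
)

classifier = ConnectorClassificationModel(
    model_server_host="localhost",
    model_server_port=9000,
)
matching = classifier.predict(
    query="what changed in our github repo last week",
    available_connectors=["github", "web"],
)
print(matching)  # e.g. ["github"] when the model is confident in the match
```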
@@ -3,12 +3,16 @@ import random
 
 from sqlalchemy.orm import Session
 
+from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
+from danswer.natural_language_processing.search_nlp_models import (
+    ConnectorClassificationModel,
+)
 from danswer.prompts.constants import SOURCES_KEY
 from danswer.prompts.filter_extration import FILE_SOURCE_WARNING
 from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT

@@ -42,11 +46,38 @@ def _sample_document_sources(
     return random.sample(valid_sources, num_sample)
 
 
+def _sample_documents_using_custom_connector_classifier(
+    query: str,
+    valid_sources: list[DocumentSource],
+) -> list[DocumentSource] | None:
+    query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
+    available_connectors = list(
+        filter(
+            lambda conn: conn.lower() in query_joined,
+            [item.value for item in valid_sources],
+        )
+    )
+
+    if not available_connectors:
+        return None
+
+    connectors = ConnectorClassificationModel().predict(query, available_connectors)
+
+    return strings_to_document_sources(connectors) if connectors else None
+
+
 def extract_source_filter(
     query: str, llm: LLM, db_session: Session
 ) -> list[DocumentSource] | None:
     """Returns a list of valid sources for search or None if no specific sources were detected"""
+
+    valid_sources = fetch_unique_document_sources(db_session)
+    if not valid_sources:
+        return None
+
+    if ENABLE_CONNECTOR_CLASSIFIER:
+        return _sample_documents_using_custom_connector_classifier(query, valid_sources)
+
     def _get_source_filter_messages(
         query: str,
         valid_sources: list[DocumentSource],

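Editor's note: the pre-filter above is a plain substring check. The query is lowercased and stripped to alphanumerics, and only connector names that appear verbatim inside that string are forwarded to the model. A standalone illustration (the strings are invented):

```python
# Standalone illustration of the alphanumeric substring pre-filter.
query = "What changed in our GitHub repo?"
query_joined = "".join(ch for ch in query.lower() if ch.isalnum())
# query_joined == "whatchangedinourgithubrepo"

candidate_sources = ["github", "web", "slack"]
available_connectors = [c for c in candidate_sources if c.lower() in query_joined]
print(available_connectors)  # ['github'] -- only names literally present survive
```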
@@ -146,10 +177,6 @@ def extract_source_filter(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None
 
-    valid_sources = fetch_unique_document_sources(db_session)
-    if not valid_sources:
-        return None
-
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     model_output = message_to_string(llm.invoke(filled_llm_prompt))

@@ -3,15 +3,21 @@ import torch.nn.functional as F
 from fastapi import APIRouter
 from huggingface_hub import snapshot_download  # type: ignore
 from transformers import AutoTokenizer  # type: ignore
-from transformers import BatchEncoding
+from transformers import BatchEncoding  # type: ignore
+from transformers import PreTrainedTokenizer  # type: ignore
 
 from danswer.utils.logger import setup_logger
 from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.danswer_torch_model import ConnectorClassifier
 from model_server.danswer_torch_model import HybridClassifier
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
+from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import INTENT_MODEL_TAG
 from shared_configs.configs import INTENT_MODEL_VERSION
+from shared_configs.model_server_models import ConnectorClassificationRequest
+from shared_configs.model_server_models import ConnectorClassificationResponse
 from shared_configs.model_server_models import IntentRequest
 from shared_configs.model_server_models import IntentResponse
 

@@ -19,10 +25,55 @@ logger = setup_logger()
 
 router = APIRouter(prefix="/custom")
 
+_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
+_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
+
 _INTENT_TOKENIZER: AutoTokenizer | None = None
 _INTENT_MODEL: HybridClassifier | None = None
 
 
+def get_connector_classifier_tokenizer() -> AutoTokenizer:
+    global _CONNECTOR_CLASSIFIER_TOKENIZER
+    if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
+        # The tokenizer details are not uploaded to the HF hub since it's just the
+        # unmodified distilbert tokenizer.
+        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
+            "distilbert-base-uncased"
+        )
+    return _CONNECTOR_CLASSIFIER_TOKENIZER
+
+
+def get_local_connector_classifier(
+    model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
+    tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
+) -> ConnectorClassifier:
+    global _CONNECTOR_CLASSIFIER_MODEL
+    if _CONNECTOR_CLASSIFIER_MODEL is None:
+        try:
+            # Calculate where the cache should be, then load from local if available
+            local_path = snapshot_download(
+                repo_id=model_name_or_path, revision=tag, local_files_only=True
+            )
+            _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
+                local_path
+            )
+        except Exception as e:
+            logger.warning(f"Failed to load model directly: {e}")
+            try:
+                # Attempt to download the model snapshot
+                logger.info(f"Downloading model snapshot for {model_name_or_path}")
+                local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
+                _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
+                    local_path
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to load model even after attempted snapshot download: {e}"
+                )
+                raise
+    return _CONNECTOR_CLASSIFIER_MODEL
+
+
 def get_intent_model_tokenizer() -> AutoTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:

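Editor's note: the loader above is cache-first; it only hits the network on a cache miss. Stripped to its essentials (a sketch, reusing the repo id and tag that this commit adds to shared_configs):

```python
# Cache-first load sketch: local_files_only=True resolves purely from the local
# Hugging Face cache and raises if the snapshot is absent, which triggers the
# fallback network download.
from huggingface_hub import snapshot_download

try:
    local_path = snapshot_download(
        repo_id="Danswer/filter-extraction-model",
        revision="1.0.0",
        local_files_only=True,
    )
except Exception:
    local_path = snapshot_download(
        repo_id="Danswer/filter-extraction-model", revision="1.0.0"
    )
```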
@@ -61,6 +112,74 @@ def get_local_intent_model(
     return _INTENT_MODEL
 
 
+def tokenize_connector_classification_query(
+    connectors: list[str],
+    query: str,
+    tokenizer: PreTrainedTokenizer,
+    connector_token_end_id: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
+
+    The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
+    token and then the user query.
+    """
+
+    input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
+
+    for connector in connectors:
+        connector_token_ids = tokenizer(
+            connector,
+            add_special_tokens=False,
+            return_tensors="pt",
+        )
+
+        input_ids = torch.cat(
+            (
+                input_ids,
+                connector_token_ids["input_ids"].squeeze(dim=0),
+                torch.tensor([connector_token_end_id], dtype=torch.long),
+            ),
+            dim=-1,
+        )
+
+    query_token_ids = tokenizer(
+        query,
+        add_special_tokens=False,
+        return_tensors="pt",
+    )
+
+    input_ids = torch.cat(
+        (
+            input_ids,
+            query_token_ids["input_ids"].squeeze(dim=0),
+            torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
+        ),
+        dim=-1,
+    )
+
+    attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
+
+    return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
+
+
+def warm_up_connector_classifier_model() -> None:
+    logger.info(
+        f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
+    )
+    connector_classifier_tokenizer = get_connector_classifier_tokenizer()
+    connector_classifier = get_local_connector_classifier()
+
+    input_ids, attention_mask = tokenize_connector_classification_query(
+        ["GitHub"],
+        "danswer classifier query google doc",
+        connector_classifier_tokenizer,
+        connector_classifier.connector_end_token_id,
+    )
+    input_ids = input_ids.to(connector_classifier.device)
+    attention_mask = attention_mask.to(connector_classifier.device)
+
+    connector_classifier(input_ids, attention_mask)
+
+
 def warm_up_intent_model() -> None:
     logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
     intent_tokenizer = get_intent_model_tokenizer()

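Editor's note: as the docstring says, the prompt is [CLS], then each connector name followed by the connector-end token, then the query, then [SEP]. A standalone sketch of the same layout; "[unused0]" stands in for the connector-end token, which in the real model comes from the config's connector_end_token field (an assumption here):

```python
# Prompt-layout sketch using the stock DistilBERT tokenizer; "[unused0]" is an
# assumed stand-in for the connector-end token read from the model config.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
end_id = tokenizer.get_vocab()["[unused0]"]

ids = [tokenizer.cls_token_id]
for connector in ["github", "slack"]:
    ids += tokenizer(connector, add_special_tokens=False)["input_ids"]
    ids.append(end_id)
ids += tokenizer("what prs merged last week", add_special_tokens=False)["input_ids"]
ids.append(tokenizer.sep_token_id)

print(tokenizer.convert_ids_to_tokens(ids))
# ['[CLS]', <github tokens>, '[unused0]', <slack tokens>, '[unused0]', <query tokens>, '[SEP]']
```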
@@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]:
     return cleaned_words
 
 
+def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
+    tokenizer = get_connector_classifier_tokenizer()
+    model = get_local_connector_classifier()
+
+    connector_names = req.available_connectors
+
+    input_ids, attention_mask = tokenize_connector_classification_query(
+        connector_names,
+        req.query,
+        tokenizer,
+        model.connector_end_token_id,
+    )
+    input_ids = input_ids.to(model.device)
+    attention_mask = attention_mask.to(model.device)
+
+    global_confidence, classifier_confidence = model(input_ids, attention_mask)
+
+    if global_confidence.item() < 0.5:
+        return []
+
+    passed_connectors = []
+
+    for i, connector_name in enumerate(connector_names):
+        if classifier_confidence.view(-1)[i].item() > 0.5:
+            passed_connectors.append(connector_name)
+
+    return passed_connectors
+
+
 def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
     tokenizer = get_intent_model_tokenizer()
     model_input = tokenizer(

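Editor's note: inference is gated in two stages. The CLS-based global head must clear 0.5 before the per-connector scores are consulted, and each connector then needs its own confidence above 0.5. A numeric illustration with made-up confidence values:

```python
# Two-stage gating with invented confidences, mirroring the logic above.
import torch

connector_names = ["github", "slack", "web"]
global_confidence = torch.tensor([0.91])                  # whole-query gate
classifier_confidence = torch.tensor([0.84, 0.22, 0.47])  # one score per connector

if global_confidence.item() < 0.5:
    passed_connectors = []
else:
    passed_connectors = [
        name
        for i, name in enumerate(connector_names)
        if classifier_confidence.view(-1)[i].item() > 0.5
    ]
print(passed_connectors)  # ['github']
```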
@@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
     return is_keyword_sequence, cleaned_keywords
 
 
+@router.post("/connector-classification")
+async def process_connector_classification_request(
+    classification_request: ConnectorClassificationRequest,
+) -> ConnectorClassificationResponse:
+    if INDEXING_ONLY:
+        raise RuntimeError(
+            "Indexing model server should not call connector classification endpoint"
+        )
+
+    if len(classification_request.available_connectors) == 0:
+        return ConnectorClassificationResponse(connectors=[])
+
+    connectors = run_connector_classification(classification_request)
+    return ConnectorClassificationResponse(connectors=connectors)
+
+
 @router.post("/query-analysis")
 async def process_analysis_request(
     intent_request: IntentRequest,

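Editor's note: the new route can also be exercised directly over HTTP; a hedged example, assuming the inference model server listens on localhost:9000 (host and port are deployment-specific):

```python
# Direct HTTP call sketch against the endpoint registered above.
import requests

resp = requests.post(
    "http://localhost:9000/custom/connector-classification",
    json={"available_connectors": ["github", "slack"], "query": "open github issues"},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"connectors": ["github"]}
```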
@@ -4,7 +4,8 @@ import os
 import torch
 import torch.nn as nn
 from transformers import DistilBertConfig  # type: ignore
-from transformers import DistilBertModel
+from transformers import DistilBertModel  # type: ignore
+from transformers import DistilBertTokenizer  # type: ignore
 
 
 class HybridClassifier(nn.Module):

@@ -21,7 +22,6 @@ class HybridClassifier(nn.Module):
             self.distilbert.config.dim, self.distilbert.config.dim
         )
         self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
-        self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
 
         self.device = torch.device("cpu")
 

@@ -36,8 +36,7 @@ class HybridClassifier(nn.Module):
         # Intent classification on the CLS token
         cls_token_state = sequence_output[:, 0, :]
         pre_classifier_out = self.pre_classifier(cls_token_state)
-        dropout_out = self.dropout(pre_classifier_out)
-        intent_logits = self.intent_classifier(dropout_out)
+        intent_logits = self.intent_classifier(pre_classifier_out)
 
         # Keyword classification on all tokens
         token_logits = self.keyword_classifier(sequence_output)

@@ -72,3 +71,70 @@ class HybridClassifier(nn.Module):
             param.requires_grad = False
 
         return model
+
+
+class ConnectorClassifier(nn.Module):
+    def __init__(self, config: DistilBertConfig) -> None:
+        super().__init__()
+
+        self.config = config
+        self.distilbert = DistilBertModel(config)
+        self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
+        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
+
+        # Token indicating end of connector name, and on which classifier is used
+        self.connector_end_token_id = self.tokenizer.get_vocab()[
+            self.config.connector_end_token
+        ]
+
+        self.device = torch.device("cpu")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.distilbert(
+            input_ids=input_ids, attention_mask=attention_mask
+        ).last_hidden_state
+
+        cls_hidden_states = hidden_states[
+            :, 0, :
+        ]  # Take leap of faith that first token is always [CLS]
+        global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
+        global_confidence = torch.sigmoid(global_logits).view(-1)
+
+        connector_end_position_ids = input_ids == self.connector_end_token_id
+        connector_end_hidden_states = hidden_states[connector_end_position_ids]
+        classifier_output = self.connector_match_classifier(connector_end_hidden_states)
+        classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
+
+        return global_confidence, classifier_confidence
+
+    @classmethod
+    def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
+        config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
+        device = (
+            torch.device("cuda")
+            if torch.cuda.is_available()
+            else torch.device("mps")
+            if torch.backends.mps.is_available()
+            else torch.device("cpu")
+        )
+        state_dict = torch.load(
+            os.path.join(repo_dir, "pytorch_model.pt"),
+            map_location=device,
+            weights_only=True,
+        )
+
+        model = cls(config)
+        model.load_state_dict(state_dict)
+        model.to(device)
+        model.device = device
+        model.eval()
+
+        for param in model.parameters():
+            param.requires_grad = False
+
+        return model

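Editor's note: forward() reads one hidden state per connector by masking on the connector-end token id. A standalone illustration of that boolean-mask gather (shapes and ids are invented):

```python
# Boolean-mask gather sketch: each end-token position selects one hidden row.
import torch

hidden_states = torch.randn(1, 6, 4)  # (batch, seq_len, dim); dim=4 for brevity
input_ids = torch.tensor([[101, 7, 999, 8, 999, 102]])  # 999 = made-up end-token id

mask = input_ids == 999              # (1, 6) boolean mask, True at two positions
per_connector = hidden_states[mask]  # -> shape (2, 4): one row per connector
print(per_connector.shape)           # torch.Size([2, 4])
```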
@@ -16,9 +16,12 @@ INDEXING_MODEL_SERVER_PORT = int(
 )
 
 # Danswer custom Deep Learning Models
+CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
+CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
 INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
 INTENT_MODEL_TAG = "v1.0.3"
 
 
 # Bi-Encoder, other details
 DOC_EMBEDDING_CONTEXT_SIZE = 512
 

@@ -7,6 +7,15 @@ from shared_configs.enums import RerankerProvider
 Embedding = list[float]
 
 
+class ConnectorClassificationRequest(BaseModel):
+    available_connectors: list[str]
+    query: str
+
+
+class ConnectorClassificationResponse(BaseModel):
+    connectors: list[str]
+
+
 class EmbedRequest(BaseModel):
     texts: list[str]
     # Can be none for cloud embedding model requests, error handling logic exists for other cases