initial working code

This commit is contained in:
joachim-danswer 2025-02-21 10:16:57 -08:00
parent ca803859cc
commit 125877ec65
19 changed files with 272 additions and 27 deletions

View File

@ -32,10 +32,13 @@ AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download(repo_id='sentence-transformers/paraphrase-mpnet-base-v2'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True); \
from setfit import SetFitModel; \
SetFitModel.from_pretrained('sentence-transformers/paraphrase-mpnet-base-v2');"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Onyx, don't overwrite them with the built-in cache folder
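
For reference, the chained python -c invocation above is equivalent to this standalone pre-download script (a minimal sketch covering only the calls visible in this hunk; the Dockerfile remains the source of truth for model IDs and revision pins):

from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer
from setfit import SetFitModel
from transformers import AutoTokenizer

# Populate the Hugging Face cache at image build time so the model
# server needs no network access to load these models on first start.
AutoTokenizer.from_pretrained("distilbert-base-uncased")
AutoTokenizer.from_pretrained("mixedbread-ai/mxbai-rerank-xsmall-v1")
snapshot_download(repo_id="danswer/hybrid-intent-token-classifier", revision="v1.0.3")
snapshot_download(repo_id="sentence-transformers/paraphrase-mpnet-base-v2")
snapshot_download("nomic-ai/nomic-embed-text-v1")
snapshot_download("mixedbread-ai/mxbai-rerank-xsmall-v1")
SentenceTransformer(model_name_or_path="nomic-ai/nomic-embed-text-v1", trust_remote_code=True)
SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")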

View File

@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"

View File

@ -2,10 +2,12 @@ import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
@ -13,6 +15,7 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import CONTENT_MODEL_VERSION
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
@ -21,6 +24,7 @@ from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
logger = setup_logger()
router = APIRouter(prefix="/custom")
@ -31,6 +35,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_CONTENT_MODEL: SetFitModel | None = None

_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0


def get_connector_classifier_tokenizer() -> AutoTokenizer:
    global _CONNECTOR_CLASSIFIER_TOKENIZER
@ -112,6 +120,13 @@ def get_local_intent_model(
    return _INTENT_MODEL


def get_local_content_model() -> SetFitModel:
    global _CONTENT_MODEL
    if _CONTENT_MODEL is None:
        # Load lazily on first use and cache in the module-level singleton
        _CONTENT_MODEL = SetFitModel.from_pretrained(CONTENT_MODEL_VERSION)
    return _CONTENT_MODEL
def tokenize_connector_classification_query(
    connectors: list[str],
    query: str,
@ -195,6 +210,16 @@ def warm_up_intent_model() -> None:
    )
def warm_up_content_model() -> None:
    logger.notice(
        "Warming up Content Model"
    )  # TODO: add version once we have a proper model

    content_model = get_local_content_model()
    content_model(CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
    intent_model = get_local_intent_model()
@ -218,6 +243,29 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
    return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
    text_inputs: list[str],
) -> list[tuple[int, float]]:
    # Ensure the model is loaded and cached even though the outputs below are stubbed
    get_local_content_model()

    # output_classes = list([x.numpy() for x in content_model(text_inputs)])
    # base_output_probabilities = list([x[1].numpy() for x in content_model.predict_proba(text_inputs)])
    # logits = [np.log(p/(1-p)) for p in base_output_probabilities]
    # scaled_logits = [l/_TEMPERATURE_CONTENT_CLASSIFICATION for l in logits]
    # output_probabilities_with_temp = [np.exp(scaled_logit)/(1 + np.exp(scaled_logit)) for scaled_logit in scaled_logits]

    # Placeholder outputs until the real content classifier is wired in
    output_classes = [1] * len(text_inputs)
    output_probabilities_with_temp = [0.9] * len(text_inputs)

    return [
        (predicted_label, predicted_probability)
        for predicted_label, predicted_probability in zip(
            output_classes, output_probabilities_with_temp
        )
    ]
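
The commented-out block above sketches how the final probabilities are meant to be produced: take the positive-class probability from the classifier head, convert it to a logit, divide by _TEMPERATURE_CONTENT_CLASSIFICATION, and map back through the sigmoid. A runnable version of just that calibration step (an illustrative sketch; predict_proba's exact output shape is an assumption about the eventual model):

import numpy as np

_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0


def scale_probabilities(base_probabilities: list[float]) -> list[float]:
    """Soften raw positive-class probabilities with temperature scaling."""
    # Clip away exact 0/1 so the logit transform stays finite
    probs = np.clip(np.asarray(base_probabilities), 1e-6, 1 - 1e-6)
    logits = np.log(probs / (1 - probs))
    # A temperature > 1 flattens the logits, pulling scores toward 0.5
    scaled = logits / _TEMPERATURE_CONTENT_CLASSIFICATION
    return (np.exp(scaled) / (1 + np.exp(scaled))).tolist()


# e.g. scale_probabilities([0.99, 0.5]) -> [~0.76, 0.5]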
def map_keywords(
    input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
@ -362,3 +410,13 @@ async def process_analysis_request(
    is_keyword, keywords = run_analysis(intent_request)
    return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
    content_classification_requests: list[str],
) -> list[tuple[int, float]]:
    content_classification_result = run_content_classification_inference(
        content_classification_requests
    )
    return content_classification_result
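
Once the model server is up, the new route can be exercised directly over HTTP; a quick sketch (assuming the model server listens on localhost:9000, as the indexing code below configures):

import requests

response = requests.post(
    "http://localhost:9000/custom/content-classification",
    json=["click here", "Q3 revenue grew 12% year over year"],
)
response.raise_for_status()
# One (predicted_label, probability) pair per input; with the current
# placeholder inference this is always [[1, 0.9], [1, 0.9]]
print(response.json())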

View File

@ -53,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
    ContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@ -348,6 +351,10 @@ def _run_indexing(
        callback=callback,
    )

    # NOTE: hardcoded to a model server running alongside the indexing worker for now
    content_classification_model = ContentClassificationModel(
        model_server_host="localhost", model_server_port=9000
    )
    document_index = get_default_document_index(
        index_attempt_start.search_settings,
        None,
@ -356,6 +363,7 @@ def _run_indexing(
    indexing_pipeline = build_indexing_pipeline(
        embedder=embedding_model,
        content_classification_model=content_classification_model,
        document_index=document_index,
        ignore_time_skip=(
            ctx.from_beginning

View File

@ -101,6 +101,7 @@ class VespaDocumentFields:
    document_sets: set[str] | None = None
    boost: float | None = None
    hidden: bool | None = None
    aggregated_boost_factor: float | None = None


@dataclass

View File

@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME {
            indexing: summary | attribute
            rank: filter
        }

        # Field to indicate whether a short chunk is a low-content chunk
        field aggregated_boost_factor type float {
            indexing: attribute
        }

        # Needs to have a separate Attribute list for efficient filtering
        field metadata_list type array<string> {
            indexing: summary | attribute
@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME {
            expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
        }

        function inline document_aggregated_boost_factor() {
            # Fall back to a neutral factor of 1.0 when the attribute is unset
            expression: if(isNan(attribute(aggregated_boost_factor)) == 1, 1.0, attribute(aggregated_boost_factor))
        }

        # Document score decays from 1 to 0.75 as age of last updated time increases
        function inline recency_bias() {
            expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME {
                * document_boost
                # Decay factor based on time document was last updated
                * recency_bias
                # Boost based on the aggregated boost calculation
                * document_aggregated_boost_factor
            }
            rerank-count: 1000
        }
@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME {
            closeness(field, embeddings)
            document_boost
            recency_bias
            document_aggregated_boost_factor
            closest(embeddings)
        }
    }
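
To see how the new factor composes: the first-phase score multiplies the match/similarity terms by document_boost, recency_bias, and now document_aggregated_boost_factor. With illustrative numbers, a chunk scoring 0.80 on similarity with document_boost 1.0 and recency_bias 0.90 previously ranked at 0.80 * 1.0 * 0.90 = 0.72; if the content classifier assigns it an aggregated_boost_factor of 0.5, it now ranks at 0.72 * 0.5 = 0.36, while chunks without the attribute keep the neutral 1.0 multiplier from the isNan fallback above.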

View File

@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
    replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_ID
@ -201,6 +202,7 @@ def _index_vespa_chunk(
        DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
        IMAGE_FILE_NAME: chunk.image_file_name,
        BOOST: chunk.boost,
        AGGREGATED_BOOST_FACTOR: chunk.aggregated_boost_factor,
    }

    if multitenant:

View File

@ -72,6 +72,7 @@ METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
BOOST = "boost"
AGGREGATED_BOOST_FACTOR = "aggregated_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
@ -97,6 +98,7 @@ YQL_BASE = (
f"{SECTION_CONTINUATION}, "
f"{IMAGE_FILE_NAME}, "
f"{BOOST}, "
f"{AGGREGATED_BOOST_FACTOR}, "
f"{HIDDEN}, "
f"{DOC_UPDATED_AT}, "
f"{PRIMARY_OWNERS}, "

View File

@ -52,7 +52,11 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.natural_language_processing.search_nlp_models import (
    ContentClassificationModel,
)
from onyx.llm.factory import get_default_llm_with_vision
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
@ -136,6 +140,81 @@ def _upsert_documents_in_db(
    )
def _get_aggregated_boost_factor(
    chunks: list[IndexChunk], content_classification_model: ContentClassificationModel
) -> tuple[list[IndexChunk], list[float], list[ConnectorFailure]]:
    """Calculates the aggregated boost factor for each chunk based on its content."""

    # Only very short chunks (<= 10 words) are classified; everything else
    # keeps a neutral score of 1.0
    short_chunk_content_dict = {
        chunk_num: chunk.content
        for chunk_num, chunk in enumerate(chunks)
        if len(chunk.content.split()) <= 10
    }
    short_chunk_contents = list(short_chunk_content_dict.values())
    short_chunk_keys = list(short_chunk_content_dict.keys())

    try:
        # Classify all short chunks in a single batched model server call
        short_content_classification_predictions = content_classification_model.predict(
            short_chunk_contents
        )
        short_content_classification_results = [
            raw_score for _, raw_score in short_content_classification_predictions
        ]
        short_content_classification_results_dict = {
            short_chunk_keys[i]: short_content_classification_results[i]
            for i in range(len(short_chunk_keys))
        }

        chunk_content_scores = [
            1.0
            if chunk_num not in short_chunk_keys
            else short_content_classification_results_dict[chunk_num]
            for chunk_num in range(len(chunks))
        ]

        return chunks, chunk_content_scores, []

    except Exception as e:
        logger.exception(
            f"Error predicting content classification for chunks: {e}. "
            "Falling back to individual examples."
        )

        # Fall back to classifying one chunk at a time so a single bad input
        # does not fail the whole batch
        chunks_with_scores: list[IndexChunk] = []
        chunk_content_scores = []
        failures: list[ConnectorFailure] = []

        for chunk in chunks:
            if len(chunk.content.split()) <= 10:
                try:
                    chunk_content_scores.append(
                        content_classification_model.predict([chunk.content])[0][1]
                    )
                    chunks_with_scores.append(chunk)
                except Exception as e:
                    logger.exception(
                        f"Error predicting content classification for chunk: {e}. "
                        "Adding to missed content classifications."
                    )
                    failures.append(
                        ConnectorFailure(
                            failed_document=DocumentFailure(
                                document_id=chunk.source_document.id,
                                document_link=(
                                    chunk.source_document.sections[0].link
                                    if chunk.source_document.sections
                                    else None
                                ),
                            ),
                            failure_message=str(e),
                            exception=e,
                        )
                    )
            else:
                chunk_content_scores.append(1.0)
                chunks_with_scores.append(chunk)

        return chunks_with_scores, chunk_content_scores, failures
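
A hypothetical happy-path run, to make the contract concrete (values are illustrative only):

# Suppose three chunks arrive with word counts 3, 2, and 200.
# Only the two short ones are sent to the model server, in one batch:
#   content_classification_model.predict(["Refer to appendix.", "Click here"])
#   -> [(1, 0.92), (0, 0.31)]
# The long chunk keeps the neutral score, so the function returns:
#   (chunks, [0.92, 0.31, 1.0], [])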
def get_doc_ids_to_update(
    documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@ -165,6 +244,7 @@ def index_doc_batch_with_handler(
    *,
    chunker: Chunker,
    embedder: IndexingEmbedder,
    content_classification_model: ContentClassificationModel,
    document_index: DocumentIndex,
    document_batch: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
@ -176,6 +256,7 @@ def index_doc_batch_with_handler(
    index_pipeline_result = index_doc_batch(
        chunker=chunker,
        embedder=embedder,
        content_classification_model=content_classification_model,
        document_index=document_index,
        document_batch=document_batch,
        index_attempt_metadata=index_attempt_metadata,
@ -450,6 +531,7 @@ def index_doc_batch(
    document_batch: list[Document],
    chunker: Chunker,
    embedder: IndexingEmbedder,
    content_classification_model: ContentClassificationModel,
    document_index: DocumentIndex,
    index_attempt_metadata: IndexAttemptMetadata,
    db_session: Session,
@ -526,6 +608,14 @@ def index_doc_batch(
        else ([], [])
    )

    (
        chunks_with_embeddings_scores,
        chunk_content_scores,
        chunk_content_classification_failures,
    ) = _get_aggregated_boost_factor(
        chunks_with_embeddings, content_classification_model
    )

    updatable_ids = [doc.id for doc in ctx.updatable_docs]

    # Acquires a lock on the documents so that no other process can modify them
@ -554,7 +644,7 @@ def index_doc_batch(
        document_id: len(
            [
                chunk
                for chunk in chunks_with_embeddings_scores
                if chunk.source_document.id == document_id
            ]
        )
@ -579,8 +669,9 @@ def index_doc_batch(
                else DEFAULT_BOOST
            ),
            tenant_id=tenant_id,
            aggregated_boost_factor=chunk_content_scores[chunk_num],
        )
        for chunk_num, chunk in enumerate(chunks_with_embeddings_scores)
    ]
    logger.debug(
@ -671,7 +762,9 @@ def index_doc_batch(
        new_docs=len([r for r in insertion_records if r.already_existed is False]),
        total_docs=len(filtered_documents),
        total_chunks=len(access_aware_chunks),
        failures=vector_db_write_failures
        + embedding_failures
        + chunk_content_classification_failures,
    )

    return result
@ -680,6 +773,7 @@ def index_doc_batch(
def build_indexing_pipeline(
    *,
    embedder: IndexingEmbedder,
    content_classification_model: ContentClassificationModel,
    document_index: DocumentIndex,
    db_session: Session,
    tenant_id: str,
@ -703,6 +797,7 @@ def build_indexing_pipeline(
        index_doc_batch_with_handler,
        chunker=chunker,
        embedder=embedder,
        content_classification_model=content_classification_model,
        document_index=document_index,
        ignore_time_skip=ignore_time_skip,
        db_session=db_session,

View File

@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk):
    document_sets: all document sets the source document for this chunk is a part
        of. This is used for filtering / personas.
    boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
        negative -> ranked lower.
        negative -> ranked lower. Not included in the aggregated boost calculation
        for legacy reasons.
    aggregated_boost_factor: represents the non-user-specific aggregated boost calculation
    """

    tenant_id: str
    access: "DocumentAccess"
    document_sets: set[str]
    boost: int
    aggregated_boost_factor: float = 1.0
    @classmethod
    def from_index_chunk(
@ -98,6 +101,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
aggregated_boost_factor: float,
tenant_id: str,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
@ -106,6 +110,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
            access=access,
            document_sets=document_sets,
            boost=boost,
            aggregated_boost_factor=aggregated_boost_factor,
            tenant_id=tenant_id,
        )

View File

@ -36,6 +36,7 @@ from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationResponses
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
@ -377,6 +378,31 @@ class QueryAnalysisModel:
        return response_model.is_keyword, response_model.keywords
class ContentClassificationModel:
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.content_server_endpoint = (
            model_server_url + "/custom/content-classification"
        )

    def predict(
        self,
        queries: list[str],
    ) -> list[tuple[int, float]]:
        response = requests.post(self.content_server_endpoint, json=queries)
        response.raise_for_status()

        response_model = ContentClassificationResponses(
            content_classifications=response.json()
        )

        return response_model.content_classifications
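
A usage sketch for the client (assumes a running model server on the default host/port; the returned values are illustrative):

content_model = ContentClassificationModel()

# POSTs the raw strings to /custom/content-classification and parses the
# response into (predicted_label, probability) pairs, one per input
predictions = content_model.predict(["click here", "The quarterly report is attached below"])
# e.g. [(1, 0.9), (1, 0.9)] with the current placeholder server-side inference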
class ConnectorClassificationModel:
    def __init__(
        self,

View File

@ -19,10 +19,15 @@ from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
    ContentClassificationModel,
)
from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@ -102,8 +107,13 @@ def upsert_ingestion_doc(
        search_settings=search_settings
    )
    content_classification_model = ContentClassificationModel(
        model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT
    )

    indexing_pipeline = build_indexing_pipeline(
        embedder=index_embedding_model,
        content_classification_model=content_classification_model,
        document_index=curr_doc_index,
        ignore_time_skip=True,
        db_session=db_session,
@ -138,6 +148,7 @@ def upsert_ingestion_doc(
        sec_ind_pipeline = build_indexing_pipeline(
            embedder=new_index_embedding_model,
            content_classification_model=content_classification_model,
            document_index=sec_doc_index,
            ignore_time_skip=True,
            db_session=db_session,

View File

@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
httpcore==1.0.5
httpx[http2]==0.27.0
httpx-oauth==0.15.1
huggingface-hub==0.20.1
huggingface-hub==0.29.0
inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
@ -71,6 +71,7 @@ requests==2.32.2
requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
setfit==1.1.1
simple-salesforce==1.12.6
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
@ -78,7 +79,7 @@ starlette==0.36.3
supervisor==4.2.5
tiktoken==0.7.0
timeago==1.0.16
transformers==4.39.2
transformers==4.49.0
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.21.1

View File

@ -8,8 +8,9 @@ pydantic==2.8.2
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
setfit==1.1.1
torch==2.2.0
transformers==4.39.2
transformers==4.49.0
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16

View File

@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/overview",
title=overview_title,
content=overview,
title_embedding=model.encode(f"search_document: {overview_title}"),
content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
title_embedding=list(model.encode(f"search_document: {overview_title}")),
content_embedding=list(
model.encode(f"search_document: {overview_title}\n{overview}")
),
)
enterprise_search_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_1,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
)
),
)
@ -179,9 +183,11 @@ enterprise_search_doc_2 = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_2,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
)
),
chunk_ind=1,
)
@ -190,9 +196,9 @@ ai_platform_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/ai_platform",
title=ai_platform_title,
content=ai_platform,
title_embedding=model.encode(f"search_document: {ai_platform_title}"),
content_embedding=model.encode(
f"search_document: {ai_platform_title}\n{ai_platform}"
title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
content_embedding=list(
model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
),
)
@ -200,9 +206,9 @@ customer_support_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/support",
title=customer_support_title,
content=customer_support,
title_embedding=model.encode(f"search_document: {customer_support_title}"),
content_embedding=model.encode(
f"search_document: {customer_support_title}\n{customer_support}"
title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
content_embedding=list(
model.encode(f"search_document: {customer_support_title}\n{customer_support}")
),
)
@ -210,17 +216,17 @@ sales_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/sales",
title=sales_title,
content=sales,
title_embedding=model.encode(f"search_document: {sales_title}"),
content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
title_embedding=list(model.encode(f"search_document: {sales_title}")),
content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
)
operations_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/operations",
title=operations_title,
content=operations,
title_embedding=model.encode(f"search_document: {operations_title}"),
content_embedding=model.encode(
f"search_document: {operations_title}\n{operations}"
title_embedding=list(model.encode(f"search_document: {operations_title}")),
content_embedding=list(
model.encode(f"search_document: {operations_title}\n{operations}")
),
)

View File

@ -99,6 +99,7 @@ def generate_dummy_chunk(
        ),
        document_sets={document_set for document_set in document_set_names},
        boost=random.randint(-1, 1),
        aggregated_boost_factor=random.random(),
        tenant_id=POSTGRES_DEFAULT_SCHEMA,
    )

View File

@ -25,6 +25,9 @@ CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
CONTENT_MODEL_VERSION = (
"sentence-transformers/paraphrase-mpnet-base-v2" # TODO: replace with Onyx FT model
)
# Bi-Encoder, other details

View File

@ -73,6 +73,14 @@ class IntentResponse(BaseModel):
    keywords: list[str]


class ContentClassificationRequests(BaseModel):
    queries: list[str]


class ContentClassificationResponses(BaseModel):
    content_classifications: list[tuple[int, float]]


class SupportedEmbeddingModel(BaseModel):
    name: str
    dim: int