Mirror of https://github.com/danswer-ai/danswer.git

Commit 125877ec65 (parent ca803859cc): initial working code
@@ -32,10 +32,13 @@ AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
 AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
 from huggingface_hub import snapshot_download; \
 snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
+snapshot_download(repo_id='sentence-transformers/paraphrase-mpnet-base-v2'); \
 snapshot_download('nomic-ai/nomic-embed-text-v1'); \
 snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
 from sentence_transformers import SentenceTransformer; \
-SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
+SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True); \
+from setfit import SetFitModel; \
+SetFitModel.from_pretrained('sentence-transformers/paraphrase-mpnet-base-v2');"

 # In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
 # running Onyx, don't overwrite it with the built in cache folder
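Since these models are baked into the image at build time, one quick way to confirm the cache is complete is to re-resolve each repo offline; a minimal sketch (the repo list mirrors the RUN step above, and local_files_only is standard huggingface_hub behavior):

from huggingface_hub import snapshot_download

# Resolving with local_files_only=True fails loudly if anything is missing
# from the baked-in cache, instead of silently re-downloading at runtime.
for repo, revision in [
    ("danswer/hybrid-intent-token-classifier", "v1.0.3"),
    ("sentence-transformers/paraphrase-mpnet-base-v2", None),
    ("nomic-ai/nomic-embed-text-v1", None),
    ("mixedbread-ai/mxbai-rerank-xsmall-v1", None),
]:
    snapshot_download(repo_id=repo, revision=revision, local_files_only=True)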
@@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType


 MODEL_WARM_UP_STRING = "hi " * 512
+CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
 DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
 DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
 DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
@@ -2,10 +2,12 @@ import torch
 import torch.nn.functional as F
 from fastapi import APIRouter
 from huggingface_hub import snapshot_download # type: ignore
+from setfit import SetFitModel # type: ignore[import]
 from transformers import AutoTokenizer # type: ignore
 from transformers import BatchEncoding # type: ignore
 from transformers import PreTrainedTokenizer # type: ignore

+from model_server.constants import CONTENT_MODEL_WARM_UP_STRING
 from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.onyx_torch_model import ConnectorClassifier
 from model_server.onyx_torch_model import HybridClassifier
@@ -13,6 +15,7 @@ from model_server.utils import simple_log_function_time
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
+from shared_configs.configs import CONTENT_MODEL_VERSION
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import INTENT_MODEL_TAG
 from shared_configs.configs import INTENT_MODEL_VERSION
@@ -21,6 +24,7 @@ from shared_configs.model_server_models import ConnectorClassificationResponse
 from shared_configs.model_server_models import IntentRequest
 from shared_configs.model_server_models import IntentResponse


 logger = setup_logger()

 router = APIRouter(prefix="/custom")
@@ -31,6 +35,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
 _INTENT_TOKENIZER: AutoTokenizer | None = None
 _INTENT_MODEL: HybridClassifier | None = None

+_CONTENT_MODEL: SetFitModel | None = None
+
+_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0
+

 def get_connector_classifier_tokenizer() -> AutoTokenizer:
     global _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -112,6 +120,13 @@ def get_local_intent_model(
     return _INTENT_MODEL


+def get_local_content_model() -> SetFitModel:
+    global _CONTENT_MODEL
+    if _CONTENT_MODEL is None:
+        _CONTENT_MODEL = SetFitModel.from_pretrained(CONTENT_MODEL_VERSION)
+    return _CONTENT_MODEL
+
+
 def tokenize_connector_classification_query(
     connectors: list[str],
     query: str,
@@ -195,6 +210,16 @@ def warm_up_intent_model() -> None:
     )


+def warm_up_content_model() -> None:
+    logger.notice(
+        "Warming up Content Model"
+    ) # TODO: add version once we have proper model
+
+    content_model = get_local_content_model()
+    content_model.device
+    content_model(CONTENT_MODEL_WARM_UP_STRING)
+
+
 @simple_log_function_time()
 def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
     intent_model = get_local_intent_model()
@@ -218,6 +243,29 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
     return intent_probabilities.tolist(), token_positive_probs


+@simple_log_function_time()
+def run_content_classification_inference(
+    text_inputs: list[str],
+) -> list[tuple[int, float]]:
+    get_local_content_model()
+
+    # output_classes = list([x.numpy() for x in content_model(text_inputs)])
+    # base_output_probabilities = list([x[1].numpy() for x in content_model.predict_proba(text_inputs)])
+    # logits = [np.log(p/(1-p)) for p in base_output_probabilities]
+    # scaled_logits = [l/_TEMPERATURE_CONTENT_CLASSIFICATION for l in logits]
+    # output_probabilities_with_temp = [np.exp(scaled_logit)/(1 + np.exp(scaled_logit)) for scaled_logit in scaled_logits]
+
+    output_classes = [1] * len(text_inputs)
+    output_probabilities_with_temp = [0.9] * len(text_inputs)
+
+    return [
+        (predicted_label, predicted_probability)
+        for predicted_label, predicted_probability in zip(
+            output_classes, output_probabilities_with_temp
+        )
+    ]
+
+
 def map_keywords(
     input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
 ) -> list[str]:
@@ -362,3 +410,13 @@ async def process_analysis_request(

     is_keyword, keywords = run_analysis(intent_request)
     return IntentResponse(is_keyword=is_keyword, keywords=keywords)
+
+
+@router.post("/content-classification")
+async def process_content_classification_request(
+    content_classification_requests: list[str],
+) -> list[tuple[int, float]]:
+    content_classification_result = run_content_classification_inference(
+        content_classification_requests
+    )
+    return content_classification_result
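The commented-out block in run_content_classification_inference above sketches temperature-scaled probabilities; until the real model is wired in, the endpoint returns the (1, 0.9) stub for every input. A cleaned-up version of that scaling logic might look like this (a sketch only — it assumes a predict_proba-style source of raw positive-class probabilities):

import numpy as np

def temper_probabilities(raw_probs: list[float], temperature: float = 4.0) -> list[float]:
    # Convert each probability to a logit, soften it by the temperature,
    # then map back through the sigmoid.
    tempered = []
    for p in raw_probs:
        logit = np.log(p / (1 - p))
        scaled = logit / temperature
        tempered.append(float(np.exp(scaled) / (1 + np.exp(scaled))))
    return tempered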
@@ -53,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool
 from onyx.indexing.embedder import DefaultIndexingEmbedder
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.indexing.indexing_pipeline import build_indexing_pipeline
+from onyx.natural_language_processing.search_nlp_models import (
+    ContentClassificationModel,
+)
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
 from onyx.utils.telemetry import create_milestone_and_report
@@ -348,6 +351,10 @@ def _run_indexing(
         callback=callback,
     )

+    content_classification_model = ContentClassificationModel(
+        model_server_host="localhost", model_server_port=9000
+    )
+
     document_index = get_default_document_index(
         index_attempt_start.search_settings,
         None,
@@ -356,6 +363,7 @@

     indexing_pipeline = build_indexing_pipeline(
         embedder=embedding_model,
+        content_classification_model=content_classification_model,
         document_index=document_index,
         ignore_time_skip=(
             ctx.from_beginning
@@ -101,6 +101,7 @@ class VespaDocumentFields:
     document_sets: set[str] | None = None
     boost: float | None = None
     hidden: bool | None = None
+    aggregated_boost_factor: float | None = None


 @dataclass
@@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME {
         indexing: summary | attribute
         rank: filter
     }
+    # Field to indicate whether a short chunk is a low content chunk
+    field aggregated_boost_factor type float {
+        indexing: attribute
+    }
+
     # Needs to have a separate Attribute list for efficient filtering
     field metadata_list type array<string> {
         indexing: summary | attribute
@@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME {
         expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
     }

+    function inline document_aggregated_boost_factor() {
+        # Defaults to 1.0 (no adjustment) if no aggregated boost factor was indexed
+        expression: if(isNan(attribute(aggregated_boost_factor)) == 1, 1.0, attribute(aggregated_boost_factor))
+    }
+
     # Document score decays from 1 to 0.75 as age of last updated time increases
     function inline recency_bias() {
         expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
@@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME {
                 * document_boost
                 # Decay factor based on time document was last updated
                 * recency_bias
+                # Boost based on aggregated boost calculation
+                * document_aggregated_boost_factor
             }
             rerank-count: 1000
         }
@@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME {
             closeness(field, embeddings)
             document_boost
             recency_bias
+            document_aggregated_boost_factor
             closest(embeddings)
         }
     }
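Net effect of the schema change: the first-phase score picks up one extra multiplicative factor. As a rough Python illustration of the ranking expression above (not Vespa syntax; names mirror the rank features):

def first_phase_score(
    base_similarity: float,
    document_boost: float,
    recency_bias: float,
    aggregated_boost_factor: float = 1.0,  # 1.0 when the attribute is unset (the isNan branch)
) -> float:
    # Mirrors the multiplicative ranking expression in the schema hunk above.
    return base_similarity * document_boost * recency_bias * aggregated_boost_factor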
@@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
     replace_invalid_doc_id_characters,
 )
 from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
+from onyx.document_index.vespa_constants import AGGREGATED_BOOST_FACTOR
 from onyx.document_index.vespa_constants import BLURB
 from onyx.document_index.vespa_constants import BOOST
 from onyx.document_index.vespa_constants import CHUNK_ID
@@ -201,6 +202,7 @@ def _index_vespa_chunk(
         DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
         IMAGE_FILE_NAME: chunk.image_file_name,
         BOOST: chunk.boost,
+        AGGREGATED_BOOST_FACTOR: chunk.aggregated_boost_factor,
     }

     if multitenant:
@@ -72,6 +72,7 @@ METADATA = "metadata"
 METADATA_LIST = "metadata_list"
 METADATA_SUFFIX = "metadata_suffix"
 BOOST = "boost"
+AGGREGATED_BOOST_FACTOR = "aggregated_boost_factor"
 DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
 PRIMARY_OWNERS = "primary_owners"
 SECONDARY_OWNERS = "secondary_owners"
@@ -97,6 +98,7 @@ YQL_BASE = (
     f"{SECTION_CONTINUATION}, "
     f"{IMAGE_FILE_NAME}, "
     f"{BOOST}, "
+    f"{AGGREGATED_BOOST_FACTOR}, "
     f"{HIDDEN}, "
     f"{DOC_UPDATED_AT}, "
     f"{PRIMARY_OWNERS}, "
backend/onyx/indexing/content_classification.py (new file, 0 lines)
@@ -52,7 +52,11 @@ from onyx.indexing.embedder import IndexingEmbedder
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.indexing.models import DocAwareChunk
 from onyx.indexing.models import DocMetadataAwareIndexChunk
 from onyx.indexing.models import IndexChunk
 from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
+from onyx.natural_language_processing.search_nlp_models import (
+    ContentClassificationModel,
+)
 from onyx.llm.factory import get_default_llm_with_vision
 from onyx.utils.logger import setup_logger
 from onyx.utils.timing import log_function_time
@@ -136,6 +140,81 @@ def _upsert_documents_in_db(
     )


+def _get_aggregated_boost_factor(
+    chunks: list[IndexChunk], content_classification_model: ContentClassificationModel
+) -> tuple[list[IndexChunk], list[float], list[ConnectorFailure]]:
+    """Calculates the aggregated boost factor for a chunk based on its content."""
+
+    short_chunk_content_dict = {
+        chunk_num: chunk.content
+        for chunk_num, chunk in enumerate(chunks)
+        if len(chunk.content.split()) <= 10
+    }
+    short_chunk_contents = list(short_chunk_content_dict.values())
+    short_chunk_keys = list(short_chunk_content_dict.keys())
+
+    try:
+        short_content_classification_predictions = content_classification_model.predict(
+            short_chunk_contents
+        )
+        short_content_classification_results = [
+            raw_score for _, raw_score in short_content_classification_predictions
+        ]
+        short_content_classification_results_dict = {
+            short_chunk_keys[i]: short_content_classification_results[i]
+            for i in range(len(short_chunk_keys))
+        }
+        chunk_content_scores = [
+            1.0
+            if chunk_num not in short_chunk_keys
+            else short_content_classification_results_dict[chunk_num]
+            for chunk_num in range(len(chunks))
+        ]
+
+        return chunks, chunk_content_scores, []
+
+    except Exception as e:
+        logger.exception(
+            f"Error predicting content classification for chunks: {e}. Falling back to individual examples."
+        )
+
+        chunks_with_scores: list[IndexChunk] = []
+        chunk_content_scores = []
+        failures: list[ConnectorFailure] = []
+
+        for chunk in chunks:
+            if len(chunk.content.split()) <= 10:
+                try:
+                    chunk_content_scores.append(
+                        content_classification_model.predict([chunk.content])[0][1]
+                    )
+                    chunks_with_scores.append(chunk)
+                except Exception as e:
+                    logger.exception(
+                        f"Error predicting content classification for chunk: {e}. Adding to missed content classifications."
+                    )
+                    # chunk_content_scores.append(1.0)
+                    failures.append(
+                        ConnectorFailure(
+                            failed_document=DocumentFailure(
+                                document_id=chunk.source_document.id,
+                                document_link=(
+                                    chunk.source_document.sections[0].link
+                                    if chunk.source_document.sections
+                                    else None
+                                ),
+                            ),
+                            failure_message=str(e),
+                            exception=e,
+                        )
+                    )
+            else:
+                chunk_content_scores.append(1.0)
+                chunks_with_scores.append(chunk)
+
+    return chunks_with_scores, chunk_content_scores, failures


 def get_doc_ids_to_update(
     documents: list[Document], db_docs: list[DBDocument]
 ) -> list[Document]:
@@ -165,6 +244,7 @@ def index_doc_batch_with_handler(
     *,
     chunker: Chunker,
     embedder: IndexingEmbedder,
+    content_classification_model: ContentClassificationModel,
     document_index: DocumentIndex,
     document_batch: list[Document],
     index_attempt_metadata: IndexAttemptMetadata,
@@ -176,6 +256,7 @@
     index_pipeline_result = index_doc_batch(
         chunker=chunker,
         embedder=embedder,
+        content_classification_model=content_classification_model,
         document_index=document_index,
         document_batch=document_batch,
         index_attempt_metadata=index_attempt_metadata,
@@ -450,6 +531,7 @@ def index_doc_batch(
     document_batch: list[Document],
     chunker: Chunker,
     embedder: IndexingEmbedder,
+    content_classification_model: ContentClassificationModel,
     document_index: DocumentIndex,
     index_attempt_metadata: IndexAttemptMetadata,
     db_session: Session,
@@ -526,6 +608,14 @@ def index_doc_batch(
         else ([], [])
     )

+    (
+        chunks_with_embeddings_scores,
+        chunk_content_scores,
+        chunk_content_classification_failures,
+    ) = _get_aggregated_boost_factor(
+        chunks_with_embeddings, content_classification_model
+    )
+
     updatable_ids = [doc.id for doc in ctx.updatable_docs]

     # Acquires a lock on the documents so that no other process can modify them
@@ -554,7 +644,7 @@
         document_id: len(
             [
                 chunk
-                for chunk in chunks_with_embeddings
+                for chunk in chunks_with_embeddings_scores
                 if chunk.source_document.id == document_id
             ]
         )
@@ -579,8 +669,9 @@
                 else DEFAULT_BOOST
             ),
             tenant_id=tenant_id,
+            aggregated_boost_factor=chunk_content_scores[chunk_num],
         )
-        for chunk in chunks_with_embeddings
+        for chunk_num, chunk in enumerate(chunks_with_embeddings_scores)
     ]

     logger.debug(
@@ -671,7 +762,9 @@
         new_docs=len([r for r in insertion_records if r.already_existed is False]),
         total_docs=len(filtered_documents),
         total_chunks=len(access_aware_chunks),
-        failures=vector_db_write_failures + embedding_failures,
+        failures=vector_db_write_failures
+        + embedding_failures
+        + chunk_content_classification_failures,
     )

     return result
@@ -680,6 +773,7 @@
 def build_indexing_pipeline(
     *,
     embedder: IndexingEmbedder,
+    content_classification_model: ContentClassificationModel,
     document_index: DocumentIndex,
     db_session: Session,
     tenant_id: str,
@@ -703,6 +797,7 @@
         index_doc_batch_with_handler,
         chunker=chunker,
         embedder=embedder,
+        content_classification_model=content_classification_model,
         document_index=document_index,
         ignore_time_skip=ignore_time_skip,
         db_session=db_session,
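One aside on the happy path of _get_aggregated_boost_factor above: `chunk_num not in short_chunk_keys` scans a Python list for every chunk, so the scores comprehension is quadratic in the worst case. A dict lookup with a default does the same work in linear time — a possible simplification (a sketch, reusing the names from the function above):

# Equivalent to the 1.0 / looked-up-score comprehension above,
# but without the repeated list membership scans.
chunk_content_scores = [
    short_content_classification_results_dict.get(chunk_num, 1.0)
    for chunk_num in range(len(chunks))
]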
@@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk):
     document_sets: all document sets the source document for this chunk is a part
         of. This is used for filtering / personas.
     boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
-        negative -> ranked lower.
+        negative -> ranked lower. Not included in aggregated boost calculation
+        for legacy reasons.
+    aggregated_boost_factor: represents non-user-specific aggregated boost calculation
     """

     tenant_id: str
     access: "DocumentAccess"
     document_sets: set[str]
     boost: int
+    aggregated_boost_factor: float = 1.0

     @classmethod
     def from_index_chunk(
@@ -98,6 +101,7 @@
         access: "DocumentAccess",
         document_sets: set[str],
         boost: int,
+        aggregated_boost_factor: float,
         tenant_id: str,
     ) -> "DocMetadataAwareIndexChunk":
         index_chunk_data = index_chunk.model_dump()
@@ -106,6 +110,7 @@
             access=access,
             document_sets=document_sets,
             boost=boost,
+            aggregated_boost_factor=aggregated_boost_factor,
             tenant_id=tenant_id,
         )
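For orientation, a hypothetical call showing where the new field enters the model (variable names are illustrative, not from this diff):

from onyx.indexing.models import DocMetadataAwareIndexChunk

chunk = DocMetadataAwareIndexChunk.from_index_chunk(
    index_chunk=index_chunk,  # an existing IndexChunk, assumed in scope
    access=access,            # a DocumentAccess, assumed in scope
    document_sets={"engineering-docs"},
    boost=0,
    aggregated_boost_factor=0.85,  # < 1.0 down-ranks a short, low-content chunk
    tenant_id=tenant_id,
)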
@@ -36,6 +36,7 @@ from shared_configs.enums import EmbedTextType
 from shared_configs.enums import RerankerProvider
 from shared_configs.model_server_models import ConnectorClassificationRequest
 from shared_configs.model_server_models import ConnectorClassificationResponse
+from shared_configs.model_server_models import ContentClassificationResponses
 from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
@@ -377,6 +378,31 @@ class QueryAnalysisModel:
         return response_model.is_keyword, response_model.keywords


+class ContentClassificationModel:
+    def __init__(
+        self,
+        model_server_host: str = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        model_server_url = build_model_server_url(model_server_host, model_server_port)
+        self.content_server_endpoint = (
+            model_server_url + "/custom/content-classification"
+        )
+
+    def predict(
+        self,
+        queries: list[str],
+    ) -> list[tuple[int, float]]:
+        response = requests.post(self.content_server_endpoint, json=queries)
+        response.raise_for_status()
+
+        response_model = ContentClassificationResponses(
+            content_classifications=response.json()
+        )
+
+        return response_model.content_classifications
+
+
 class ConnectorClassificationModel:
     def __init__(
         self,
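A minimal usage sketch of the new client (host/port values are illustrative and mirror the hardcoded wiring in _run_indexing; predict() posts a raw list of strings to the endpoint shown earlier):

from onyx.natural_language_processing.search_nlp_models import ContentClassificationModel

model = ContentClassificationModel(model_server_host="localhost", model_server_port=9000)

# Each prediction is a (label, probability) pair; with the current stub,
# every input comes back as (1, 0.9).
for label, probability in model.predict(["ok", "thanks", "an information-dense chunk"]):
    print(label, probability)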
@@ -19,10 +19,15 @@ from onyx.db.search_settings import get_secondary_search_settings
 from onyx.document_index.factory import get_default_document_index
 from onyx.indexing.embedder import DefaultIndexingEmbedder
 from onyx.indexing.indexing_pipeline import build_indexing_pipeline
+from onyx.natural_language_processing.search_nlp_models import (
+    ContentClassificationModel,
+)
 from onyx.server.onyx_api.models import DocMinimalInfo
 from onyx.server.onyx_api.models import IngestionDocument
 from onyx.server.onyx_api.models import IngestionResult
 from onyx.utils.logger import setup_logger
+from shared_configs.configs import MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.contextvars import get_current_tenant_id

 logger = setup_logger()
@@ -102,8 +107,13 @@ def upsert_ingestion_doc(
         search_settings=search_settings
     )

+    content_classification_model = ContentClassificationModel(
+        model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT
+    )
+
     indexing_pipeline = build_indexing_pipeline(
         embedder=index_embedding_model,
+        content_classification_model=content_classification_model,
         document_index=curr_doc_index,
         ignore_time_skip=True,
         db_session=db_session,
@@ -138,6 +148,7 @@

     sec_ind_pipeline = build_indexing_pipeline(
         embedder=new_index_embedding_model,
+        content_classification_model=content_classification_model,
         document_index=sec_doc_index,
         ignore_time_skip=True,
         db_session=db_session,
@@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
 httpcore==1.0.5
 httpx[http2]==0.27.0
 httpx-oauth==0.15.1
-huggingface-hub==0.20.1
+huggingface-hub==0.29.0
 inflection==0.5.1
 jira==3.5.1
 jsonref==1.1.0
@@ -71,6 +71,7 @@ requests==2.32.2
 requests-oauthlib==1.3.1
 retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
 rfc3986==1.5.0
+setfit==1.1.1
 simple-salesforce==1.12.6
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.15
@@ -78,7 +79,7 @@ starlette==0.36.3
 supervisor==4.2.5
 tiktoken==0.7.0
 timeago==1.0.16
-transformers==4.39.2
+transformers==4.49.0
 unstructured==0.15.1
 unstructured-client==0.25.4
 uvicorn==0.21.1
@@ -8,8 +8,9 @@ pydantic==2.8.2
 retry==0.9.2
 safetensors==0.4.2
 sentence-transformers==2.6.1
+setfit==1.1.1
 torch==2.2.0
-transformers==4.39.2
+transformers==4.49.0
 uvicorn==0.21.1
 voyageai==0.2.3
 litellm==1.61.16
@@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/overview",
     title=overview_title,
     content=overview,
-    title_embedding=model.encode(f"search_document: {overview_title}"),
-    content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
+    title_embedding=list(model.encode(f"search_document: {overview_title}")),
+    content_embedding=list(
+        model.encode(f"search_document: {overview_title}\n{overview}")
+    ),
 )

 enterprise_search_doc = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/enterprise_search",
     title=enterprise_search_title,
     content=enterprise_search_1,
-    title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
-    content_embedding=model.encode(
-        f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
+    title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
+    content_embedding=list(
+        model.encode(
+            f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
+        )
     ),
 )
@@ -179,9 +183,11 @@ enterprise_search_doc_2 = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/enterprise_search",
     title=enterprise_search_title,
     content=enterprise_search_2,
-    title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
-    content_embedding=model.encode(
-        f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
+    title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
+    content_embedding=list(
+        model.encode(
+            f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
+        )
     ),
     chunk_ind=1,
 )
@@ -190,9 +196,9 @@ ai_platform_doc = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/ai_platform",
     title=ai_platform_title,
     content=ai_platform,
-    title_embedding=model.encode(f"search_document: {ai_platform_title}"),
-    content_embedding=model.encode(
-        f"search_document: {ai_platform_title}\n{ai_platform}"
+    title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
+    content_embedding=list(
+        model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
     ),
 )
@@ -200,9 +206,9 @@ customer_support_doc = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/support",
     title=customer_support_title,
     content=customer_support,
-    title_embedding=model.encode(f"search_document: {customer_support_title}"),
-    content_embedding=model.encode(
-        f"search_document: {customer_support_title}\n{customer_support}"
+    title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
+    content_embedding=list(
+        model.encode(f"search_document: {customer_support_title}\n{customer_support}")
     ),
 )
@@ -210,17 +216,17 @@ sales_doc = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/sales",
     title=sales_title,
     content=sales,
-    title_embedding=model.encode(f"search_document: {sales_title}"),
-    content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
+    title_embedding=list(model.encode(f"search_document: {sales_title}")),
+    content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
 )

 operations_doc = SeedPresaveDocument(
     url="https://docs.onyx.app/more/use_cases/operations",
     title=operations_title,
     content=operations,
-    title_embedding=model.encode(f"search_document: {operations_title}"),
-    content_embedding=model.encode(
-        f"search_document: {operations_title}\n{operations}"
+    title_embedding=list(model.encode(f"search_document: {operations_title}")),
+    content_embedding=list(
+        model.encode(f"search_document: {operations_title}\n{operations}")
     ),
 )
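These edits wrap every model.encode(...) in list(...): sentence-transformers returns a numpy ndarray, and wrapping it yields a plain Python list that serializes cleanly into the pydantic seed documents. A one-line illustration (model name assumed from context):

embedding = model.encode("search_document: example")  # numpy.ndarray
serializable = list(embedding)  # Python list (elements are numpy scalars;
                                # embedding.tolist() would give builtin floats)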
@@ -99,6 +99,7 @@ def generate_dummy_chunk(
         ),
         document_sets={document_set for document_set in document_set_names},
         boost=random.randint(-1, 1),
+        aggregated_boost_factor=random.random(),
         tenant_id=POSTGRES_DEFAULT_SCHEMA,
     )
@@ -25,6 +25,9 @@ CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
 CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
 INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
 INTENT_MODEL_TAG = "v1.0.3"
+CONTENT_MODEL_VERSION = (
+    "sentence-transformers/paraphrase-mpnet-base-v2"  # TODO: replace with Onyx FT model
+)


 # Bi-Encoder, other details
@@ -73,6 +73,14 @@ class IntentResponse(BaseModel):
     keywords: list[str]


+class ContentClassificationRequests(BaseModel):
+    queries: list[str]
+
+
+class ContentClassificationResponses(BaseModel):
+    content_classifications: list[tuple[int, float]]
+
+
 class SupportedEmbeddingModel(BaseModel):
     name: str
     dim: int
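For reference, a sketch of how the endpoint's JSON body maps onto the response model (values match the current stub output; note the endpoint takes a bare list[str], so ContentClassificationRequests is defined here but not yet consumed by it):

from shared_configs.model_server_models import ContentClassificationResponses

# requests.post(..., json=["chunk one", "chunk two"]).json() -> [[1, 0.9], [1, 0.9]]
resp = ContentClassificationResponses(content_classifications=[[1, 0.9], [1, 0.9]])
assert resp.content_classifications == [(1, 0.9), (1, 0.9)]  # pydantic coerces lists to tuples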