mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-05 17:30:26 +02:00
small updates
This commit is contained in:
parent
9f37ca23e8
commit
b8f64d10a2
@ -121,7 +121,9 @@ def get_local_intent_model(
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(f"Loading model from local cache: {model_name_or_path}")
|
||||
local_path = _create_local_path(model_name_or_path, tag, True)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
logger.notice(f"Loaded model from local cache: {local_path}")
|
||||
except Exception as e:
|
||||
@ -129,7 +131,9 @@ def get_local_intent_model(
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = _create_local_path(model_name_or_path, tag, False)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@ -150,7 +154,9 @@ def get_local_information_content_model(
|
||||
logger.notice(
|
||||
f"Loading content information model from local cache: {model_name_or_path}"
|
||||
)
|
||||
local_path = _create_local_path(model_name_or_path, tag, True)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
logger.notice(
|
||||
f"Loaded content information model from local cache: {local_path}"
|
||||
@ -162,7 +168,9 @@ def get_local_information_content_model(
|
||||
logger.notice(
|
||||
f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
)
|
||||
local_path = _create_local_path(model_name_or_path, tag, False)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@ -260,7 +268,6 @@ def warm_up_information_content_model() -> None:
|
||||
logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
information_content_model = get_local_information_content_model()
|
||||
information_content_model.device
|
||||
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
@ -331,9 +338,7 @@ def run_content_classification_inference(
|
||||
ContentClassificationPrediction(
|
||||
predicted_label=predicted_label, content_boost_factor=output_score
|
||||
)
|
||||
for predicted_label, output_score in list(
|
||||
zip(output_classes, prediction_scores)
|
||||
)
|
||||
for predicted_label, output_score in zip(output_classes, prediction_scores)
|
||||
]
|
||||
|
||||
return content_classification_predictions
|
||||
|
@ -285,7 +285,7 @@ INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
|
||||
)
|
||||
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.8
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
|
||||
)
|
||||
# Temperature for the information content classification model
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
|
||||
|
Loading…
x
Reference in New Issue
Block a user