small updates

This commit is contained in:
joachim-danswer 2025-03-06 10:45:55 -08:00
parent 9f37ca23e8
commit b8f64d10a2
2 changed files with 14 additions and 9 deletions

View File

@ -121,7 +121,9 @@ def get_local_intent_model(
try:
# Calculate where the cache should be, then load from local if available
logger.notice(f"Loading model from local cache: {model_name_or_path}")
local_path = _create_local_path(model_name_or_path, tag, True)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
logger.notice(f"Loaded model from local cache: {local_path}")
except Exception as e:
@ -129,7 +131,9 @@ def get_local_intent_model(
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = _create_local_path(model_name_or_path, tag, False)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
@ -150,7 +154,9 @@ def get_local_information_content_model(
logger.notice(
f"Loading content information model from local cache: {model_name_or_path}"
)
local_path = _create_local_path(model_name_or_path, tag, True)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
logger.notice(
f"Loaded content information model from local cache: {local_path}"
@ -162,7 +168,9 @@ def get_local_information_content_model(
logger.notice(
f"Downloading content information model snapshot for {model_name_or_path}"
)
local_path = _create_local_path(model_name_or_path, tag, False)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
except Exception as e:
logger.error(
@ -260,7 +268,6 @@ def warm_up_information_content_model() -> None:
logger.notice("Warming up Content Model") # TODO: add version if needed
information_content_model = get_local_information_content_model()
information_content_model.device
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@ -331,9 +338,7 @@ def run_content_classification_inference(
ContentClassificationPrediction(
predicted_label=predicted_label, content_boost_factor=output_score
)
for predicted_label, output_score in list(
zip(output_classes, prediction_scores)
)
for predicted_label, output_score in zip(output_classes, prediction_scores)
]
return content_classification_predictions

View File

@ -285,7 +285,7 @@ INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
)
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.8
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
)
# Temperature for the information content classification model
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(