diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 73782d391..73d482f35 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -338,9 +338,11 @@ def get_llm_max_tokens( return GEN_AI_MAX_TOKENS try: - model_obj = model_map.get(f"{model_provider}/{model_name}") - if not model_obj: - model_obj = model_map[model_name] + model_obj = ( + model_map.get(f"{model_provider}/{model_name}") + or model_map.get(model_name) + or model_map[model_name.split("/")[1]] + ) if "max_input_tokens" in model_obj: return model_obj["max_input_tokens"]