Mirror of https://github.com/danswer-ai/danswer.git
More accurate input token count for LLM (#1267)
commit fd69203be8
parent 9757fbee90
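In short: the token-limit helpers now take litellm's model_cost dict as an explicit model_map argument, a new get_llm_max_input_tokens() derives the usable input window from max_input_tokens/max_output_tokens when a model's entry provides them, get_max_input_tokens() falls back to the previous max_tokens - output_tokens estimate when it does not, and litellm is bumped from 1.27.10 to 1.34.8.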
@@ -55,7 +55,7 @@ def create_chat_chain(
     id_to_msg = {msg.id: msg for msg in all_chat_messages}
 
     if not all_chat_messages:
-        raise ValueError("No messages in Chat Session")
+        raise RuntimeError("No messages in Chat Session")
 
     root_message = all_chat_messages[0]
     if root_message.parent_message is not None:
@@ -241,6 +241,7 @@ def test_llm(llm: LLM) -> str | None:
 
 
 def get_llm_max_tokens(
+    model_map: dict,
     model_name: str | None = GEN_AI_MODEL_VERSION,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
 ) -> int:
@@ -250,18 +251,12 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     model_name = model_name or get_default_llm_version()[0]
-    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
-    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
-    # and there is no other interface to get what we want. This should be okay though, since the
-    # `model_cost` dict is a named public interface:
-    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
-    litellm_model_map = litellm.model_cost
 
     try:
         if model_provider == "openai":
-            model_obj = litellm_model_map[model_name]
+            model_obj = model_map[model_name]
         else:
-            model_obj = litellm_model_map[f"{model_provider}/{model_name}"]
+            model_obj = model_map[f"{model_provider}/{model_name}"]
         if "max_tokens" in model_obj:
             return model_obj["max_tokens"]
         elif "max_input_tokens" in model_obj and "max_output_tokens" in model_obj:
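For reference, `litellm.model_cost` is just a dict keyed by model name, and the lookup above uses the bare name for OpenAI models and "provider/model" otherwise. A minimal sketch of exercising the refactored signature; the import path and the stub values are assumptions rather than anything shown in this diff:

    import litellm

    from danswer.llm.utils import get_llm_max_tokens  # module path assumed, not shown in the diff

    # Production callers inject the real litellm map:
    print(get_llm_max_tokens(model_map=litellm.model_cost, model_name="gpt-4", model_provider="openai"))

    # Tests can inject a small stub with the same shape (limit value is made up):
    stub_map = {"gpt-4": {"max_tokens": 8192}}
    print(get_llm_max_tokens(model_map=stub_map, model_name="gpt-4", model_provider="openai"))  # -> 8192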
@@ -275,17 +270,73 @@ def get_llm_max_tokens(
         return 4096
 
 
+def get_llm_max_input_tokens(
+    output_tokens: int,
+    model_map: dict,
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int | None:
+    try:
+        if model_provider == "openai":
+            model_obj = model_map[model_name]
+        else:
+            model_obj = model_map[f"{model_provider}/{model_name}"]
+
+        max_in = model_obj.get("max_input_tokens")
+        max_out = model_obj.get("max_output_tokens")
+        if max_in is None or max_out is None:
+            # Can't calculate precisely, use the fallback method
+            return None
+
+        # Some APIs may not actually work like this, but it's a safer approach generally speaking
+        # since worst case we remove some extra tokens from the input space
+        output_token_debt = 0
+        if output_tokens > max_out:
+            logger.warning(
+                "More output tokens requested than model is likely able to handle"
+            )
+            output_token_debt = output_tokens - max_out
+
+        remaining_max_input_tokens = max_in - output_token_debt
+        return remaining_max_input_tokens
+
+    except Exception:
+        # We can try the less accurate approach if this fails
+        return None
+
+
 def get_max_input_tokens(
     model_name: str | None = GEN_AI_MODEL_VERSION,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
+    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
+    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
+    # and there is no other interface to get what we want. This should be okay though, since the
+    # `model_cost` dict is a named public interface:
+    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
+    # model_map is litellm.model_cost
+    litellm_model_map = litellm.model_cost
+
     model_name = model_name or get_default_llm_version()[0]
-    input_toks = (
-        get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
-        - output_tokens
-    )
+    input_toks = get_llm_max_input_tokens(
+        output_tokens=output_tokens,
+        model_map=litellm_model_map,
+        model_name=model_name,
+        model_provider=model_provider,
+    )
+
+    if input_toks is None:
+        input_toks = (
+            get_llm_max_tokens(
+                model_name=model_name,
+                model_provider=model_provider,
+                model_map=litellm_model_map,
+            )
+            - output_tokens
+        )
+
     if input_toks <= 0:
         raise RuntimeError("No tokens for input for the LLM given settings")
 
 
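The budget math in the two functions above reduces to this: when separate input and output limits are known, the input window is max_input_tokens minus whatever portion of the requested output exceeds max_output_tokens; otherwise the older max_tokens - output_tokens estimate is used. A standalone sketch of that arithmetic with made-up limits (it mirrors the logic rather than calling the real helpers):

    def input_budget(max_in: int | None, max_out: int | None, max_tokens: int, requested_output: int) -> int:
        if max_in is not None and max_out is not None:
            # Only output requested beyond the model's output cap eats into the input window.
            output_token_debt = max(0, requested_output - max_out)
            return max_in - output_token_debt
        # Fallback: one combined window shared by input and output.
        return max_tokens - requested_output

    print(input_budget(128000, 4096, 0, 1024))   # 128000: output fits, full input window remains
    print(input_budget(128000, 4096, 0, 8000))   # 124096: 3904 tokens of output "debt" are subtracted
    print(input_budget(None, None, 8192, 1024))  # 7168: old-style estimate when only max_tokens is known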
@@ -24,7 +24,7 @@ httpx-oauth==0.11.2
 huggingface-hub==0.20.1
 jira==3.5.1
 langchain==0.1.9
-litellm==1.27.10
+litellm==1.34.8
 llama-index==0.9.45
 Mako==1.2.4
 msal==1.26.0