More accurate input token count for LLM (#1267)

Yuhong Sun 2024-03-28 11:11:37 -07:00 committed by GitHub
parent 9757fbee90
commit fd69203be8
3 changed files with 64 additions and 13 deletions


@@ -55,7 +55,7 @@ def create_chat_chain(
     id_to_msg = {msg.id: msg for msg in all_chat_messages}

     if not all_chat_messages:
-        raise ValueError("No messages in Chat Session")
+        raise RuntimeError("No messages in Chat Session")

     root_message = all_chat_messages[0]
     if root_message.parent_message is not None:


@@ -241,6 +241,7 @@ def test_llm(llm: LLM) -> str | None:
 def get_llm_max_tokens(
+    model_map: dict,
     model_name: str | None = GEN_AI_MODEL_VERSION,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
 ) -> int:
@@ -250,18 +251,12 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS

     model_name = model_name or get_default_llm_version()[0]

-    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
-    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
-    # and there is no other interface to get what we want. This should be okay though, since the
-    # `model_cost` dict is a named public interface:
-    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
-    litellm_model_map = litellm.model_cost
     try:
         if model_provider == "openai":
-            model_obj = litellm_model_map[model_name]
+            model_obj = model_map[model_name]
         else:
-            model_obj = litellm_model_map[f"{model_provider}/{model_name}"]
+            model_obj = model_map[f"{model_provider}/{model_name}"]
         if "max_tokens" in model_obj:
             return model_obj["max_tokens"]
         elif "max_input_tokens" in model_obj and "max_output_tokens" in model_obj:
@@ -275,17 +270,73 @@ def get_llm_max_tokens(
         return 4096


+def get_llm_max_input_tokens(
+    output_tokens: int,
+    model_map: dict,
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int | None:
+    try:
+        if model_provider == "openai":
+            model_obj = model_map[model_name]
+        else:
+            model_obj = model_map[f"{model_provider}/{model_name}"]
+
+        max_in = model_obj.get("max_input_tokens")
+        max_out = model_obj.get("max_output_tokens")
+        if max_in is None or max_out is None:
+            # Can't calculate precisely, use the fallback method
+            return None
+
+        # Some APIs may not actually work like this, but it's a safer approach generally speaking
+        # since worst case we remove some extra tokens from the input space
+        output_token_debt = 0
+        if output_tokens > max_out:
+            logger.warning(
+                "More output tokens requested than model is likely able to handle"
+            )
+            output_token_debt = output_tokens - max_out
+
+        remaining_max_input_tokens = max_in - output_token_debt
+        return remaining_max_input_tokens
+    except Exception:
+        # We can try the less accurate approach if this fails
+        return None
+
+
 def get_max_input_tokens(
     model_name: str | None = GEN_AI_MODEL_VERSION,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
+    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
+    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
+    # and there is no other interface to get what we want. This should be okay though, since the
+    # `model_cost` dict is a named public interface:
+    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
+    # model_map is litellm.model_cost
+    litellm_model_map = litellm.model_cost
+
     model_name = model_name or get_default_llm_version()[0]
-    input_toks = (
-        get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
-        - output_tokens
-    )
+    input_toks = get_llm_max_input_tokens(
+        output_tokens=output_tokens,
+        model_map=litellm_model_map,
+        model_name=model_name,
+        model_provider=model_provider,
+    )
+    if input_toks is None:
+        input_toks = (
+            get_llm_max_tokens(
+                model_name=model_name,
+                model_provider=model_provider,
+                model_map=litellm_model_map,
+            )
+            - output_tokens
+        )

     if input_toks <= 0:
         raise RuntimeError("No tokens for input for the LLM given settings")
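
For reference, a minimal sketch of how the two paths above combine. The example_model_map below is a hand-written stand-in for litellm.model_cost (real entries and limits vary by litellm version), and budget_input_tokens is a hypothetical helper written for this illustration only, not part of the commit:

# Hypothetical stand-in for litellm.model_cost; model names and limits are illustrative only.
example_model_map = {
    "gpt-4": {"max_input_tokens": 8192, "max_output_tokens": 4096, "max_tokens": 4096},
    "azure/gpt-35-turbo": {"max_tokens": 4096},
}


def budget_input_tokens(
    model_map: dict, model_name: str, model_provider: str, output_tokens: int
) -> int:
    # Mirrors the fallback order above: try the precise max_input_tokens path first,
    # then fall back to max_tokens minus the output reservation.
    key = model_name if model_provider == "openai" else f"{model_provider}/{model_name}"
    model_obj = model_map.get(key, {})
    max_in = model_obj.get("max_input_tokens")
    max_out = model_obj.get("max_output_tokens")
    if max_in is not None and max_out is not None:
        # Only charge the input budget for output requested beyond the model's output cap.
        return max_in - max(0, output_tokens - max_out)
    # Less accurate path: treat max_tokens as the full window and subtract the output reservation.
    return model_obj.get("max_tokens", 4096) - output_tokens


print(budget_input_tokens(example_model_map, "gpt-4", "openai", 1024))        # -> 8192
print(budget_input_tokens(example_model_map, "gpt-35-turbo", "azure", 1024))  # -> 3072

The effect of the change: when max_input_tokens and max_output_tokens are known, the full input window stays usable instead of shrinking by the requested output size, which is what makes the input token count more accurate.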


@@ -24,7 +24,7 @@ httpx-oauth==0.11.2
 huggingface-hub==0.20.1
 jira==3.5.1
 langchain==0.1.9
-litellm==1.27.10
+litellm==1.34.8
 llama-index==0.9.45
 Mako==1.2.4
 msal==1.26.0