diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/check_sub_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/check_sub_answer.py
index 2374e9624..ab6bdddb3 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/check_sub_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/check_sub_answer.py
@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
 from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
 from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
 from onyx.llm.chat_llm import LLMRateLimitError
@@ -92,6 +93,7 @@ def check_sub_answer(
         fast_llm.invoke,
         prompt=msg,
         timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
+        max_tokens=AGENT_MAX_TOKENS_VALIDATION,
     )

     quality_str: str = cast(str, response.content)
diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
index b5bc378bd..b321b962f 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
 from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
 from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
 from onyx.llm.chat_llm import LLMRateLimitError
@@ -119,6 +120,7 @@ def generate_sub_answer(
         for message in fast_llm.stream(
             prompt=msg,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content
diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
index 7a7c8ffc2..297e2a426 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_section_list,
 )
+from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -279,6 +281,9 @@ def generate_initial_answer(
         for message in model.stream(
             msg,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
+            if _should_restrict_tokens(model.config)
+            else None,
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content
diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py
index fbe231d13..33e8577db 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_sub_answers/nodes/decompose_orig_question.py
@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
 from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
 )
@@ -141,6 +142,7 @@ def decompose_orig_question(
         model.stream(
             msg,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
         ),
         dispatch_subquestion(0, writer),
         sep_callback=dispatch_subquestion_sep(0, writer),
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py
index 438f09301..0624a6447 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/compare_answers.py
@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import RefinedAnswerImprovement
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
 from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
 from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
 from onyx.llm.chat_llm import LLMRateLimitError
@@ -112,6 +113,7 @@ def compare_answers(
             model.invoke,
             prompt=msg,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
+            max_tokens=AGENT_MAX_TOKENS_VALIDATION,
         )

     except (LLMTimeoutError, TimeoutError):
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py
index a63c7fa53..d9e15b189 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/create_refined_sub_questions.py
@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import StreamingError
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
 )
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
         model.stream(
             msg,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
         ),
         dispatch_subquestion(1, writer),
         sep_callback=dispatch_subquestion_sep(1, writer),
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py
index c56bf211b..731db32fd 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/decide_refinement_need.py
@@ -50,13 +50,7 @@ def decide_refinement_need(
         )
     ]

-    if graph_config.behavior.allow_refinement:
-        return RequireRefinemenEvalUpdate(
-            require_refined_answer_eval=decision,
-            log_messages=log_messages,
-        )
-    else:
-        return RequireRefinemenEvalUpdate(
-            require_refined_answer_eval=False,
-            log_messages=log_messages,
-        )
+    return RequireRefinemenEvalUpdate(
+        require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
+        log_messages=log_messages,
+    )
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py
index bccf390cf..58451f819 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py
@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
 )
@@ -96,6 +97,7 @@ def extract_entities_terms(
         fast_llm.invoke,
         prompt=msg,
         timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
+        max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
     )

     cleaned_response = (
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
index b17c39a6d..856f816f3 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_section_list,
 )
+from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -303,7 +306,11 @@ def generate_validate_refined_answer(

     def stream_refined_answer() -> list[str]:
         for message in model.stream(
-            msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
+            msg,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
+            if _should_restrict_tokens(model.config)
+            else None,
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content
@@ -410,6 +417,7 @@ def generate_validate_refined_answer(
         validation_model.invoke,
         prompt=msg,
         timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
+        max_tokens=AGENT_MAX_TOKENS_VALIDATION,
     )
     refined_answer_quality = binary_string_test_after_answer_separator(
         text=cast(str, validation_response.content),
diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/expand_queries.py
index e325efc7d..3b7898138 100644
--- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/expand_queries.py
+++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/expand_queries.py
@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
 )
@@ -96,6 +97,7 @@ def expand_queries(
         model.stream(
             prompt=msg,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
         ),
         dispatch_subquery(level, question_num, writer),
     )
diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/verify_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/verify_documents.py
index f9f23d868..077ac6e17 100644
--- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/verify_documents.py
+++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/verify_documents.py
@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
 from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
 from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
 from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,6 +94,7 @@ def verify_documents(
         fast_llm.invoke,
         prompt=msg,
         timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
+        max_tokens=AGENT_MAX_TOKENS_VALIDATION,
     )

     assert isinstance(response.content, str)
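Note (illustrative, not part of the diff): the binary checks touched above — sub-answer check, answer comparison, refined-answer validation, and document verification — all reuse AGENT_MAX_TOKENS_VALIDATION, which defaults to 4 tokens, just enough for a short verdict. A minimal sketch of the intent, using a hypothetical helper rather than the repo's actual prompt plumbing:

    # Hypothetical sketch only: cap a yes/no verification call at a few tokens.
    from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
    from onyx.llm.interfaces import LLM


    def is_acceptable(fast_llm: LLM, msg) -> bool:  # `msg` is whatever prompt the node built
        response = fast_llm.invoke(
            prompt=msg,
            max_tokens=AGENT_MAX_TOKENS_VALIDATION,  # verdict only, no room for explanations
        )
        return "yes" in str(response.content).strip().lower()
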
diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
index 672706d18..7b8df025a 100644
--- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
+++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
 from onyx.configs.agent_configs import (
     AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
 )
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.llm.interfaces import LLM
+from onyx.llm.interfaces import LLMConfig
 from onyx.prompts.agent_search import (
     ASSISTANT_SYSTEM_PROMPT_DEFAULT,
 )
@@ -402,6 +404,7 @@ def summarize_history(
             llm.invoke,
             history_context_prompt,
             timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
+            max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
         )
     except (LLMTimeoutError, TimeoutError):
         logger.error("LLM Timeout Error - summarize history")
@@ -505,3 +508,9 @@ def get_deduplicated_structured_subquestion_documents(
         cited_documents=dedup_inference_section_list(cited_docs),
         context_documents=dedup_inference_section_list(context_docs),
     )
+
+
+def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
+    return not (
+        llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
+    )
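Note (illustrative, not part of the diff): _should_restrict_tokens skips the explicit cap for OpenAI model names starting with "o" (o1, o3-mini, ...), presumably because those reasoning models spend part of the output budget on hidden reasoning tokens, so a small max_tokens could truncate or empty the visible answer. Ordinary OpenAI chat models such as gpt-4o keep the cap, since their names do not start with "o". A plain-value illustration of how the gate resolves (model names here are examples, not taken from the repo):

    # Illustrative only: mirrors the logic of _should_restrict_tokens with plain strings.
    def should_restrict(model_provider: str, model_name: str) -> bool:
        return not (model_provider == "openai" and model_name.startswith("o"))


    assert should_restrict("openai", "gpt-4o") is True  # regular chat model: cap applies
    assert should_restrict("openai", "o3-mini") is False  # "o"-series reasoning model: no cap
    assert should_restrict("anthropic", "claude-3-5-sonnet") is True  # non-OpenAI: cap applies
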
diff --git a/backend/onyx/configs/agent_configs.py b/backend/onyx/configs/agent_configs.py
index 0bb8dcf5c..791e12441 100644
--- a/backend/onyx/configs/agent_configs.py
+++ b/backend/onyx/configs/agent_configs.py
@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
 )

-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4  # in seconds
+AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6  # in seconds
 AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
     os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
     or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
 )

-AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30  # in seconds
+AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40  # in seconds
 AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
     os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
     or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
 )

-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5  # in seconds
+AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10  # in seconds
 AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
     os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
     or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
 )

-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5  # in seconds
+AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15  # in seconds
 AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
     os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
     or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
 )

-AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30  # in seconds
+AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45  # in seconds
 AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
     os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
     or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
@@ -333,4 +333,45 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
     or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
 )
+AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
+AGENT_MAX_TOKENS_VALIDATION = int(
+    os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
+)
+
+AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
+AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
+    os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
+    or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
+)
+
+AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
+AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
+    os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
+    or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
+)
+
+AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
+AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
+    os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
+    or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
+)
+
+AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
+AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
+    os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
+    or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
+)
+
+AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
+AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
+    os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
+    or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
+)
+
+AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
+AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
+    os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
+    or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
+)
+

 GRAPH_VERSION_NAME: str = "a"
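Note (illustrative, not part of the diff): the new AGENT_MAX_TOKENS_* settings reuse the same override idiom as the existing timeouts. The value is read once at import time, so overrides must be in the environment before the process starts, and both an unset variable and an empty string fall back to the default because of the `or`. A small standalone demonstration of that idiom:

    # Illustrative only: semantics of the `int(os.environ.get(NAME) or DEFAULT)` idiom used above.
    import os

    os.environ["AGENT_MAX_TOKENS_VALIDATION"] = "8"
    assert int(os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or 4) == 8

    os.environ["AGENT_MAX_TOKENS_VALIDATION"] = ""  # empty string is falsy ...
    assert int(os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or 4) == 4  # ... so the default wins

    del os.environ["AGENT_MAX_TOKENS_VALIDATION"]  # unset behaves the same way
    assert int(os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or 4) == 4
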
diff --git a/backend/onyx/llm/chat_llm.py b/backend/onyx/llm/chat_llm.py
index 8a16333ac..2e9496856 100644
--- a/backend/onyx/llm/chat_llm.py
+++ b/backend/onyx/llm/chat_llm.py
@@ -167,7 +167,7 @@ def _convert_delta_to_message_chunk(
     stop_reason: str | None = None,
 ) -> BaseMessageChunk:
     """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
-    role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
+    role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
     content = _dict.get("content") or ""
     additional_kwargs = {}
     if _dict.get("function_call"):
@@ -402,6 +402,7 @@ class DefaultMultiLLM(LLM):
         stream: bool,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
         # litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
         # to a dict representation
@@ -429,6 +430,7 @@ class DefaultMultiLLM(LLM):
             # model params
             temperature=0,
             timeout=timeout_override or self._timeout,
+            max_tokens=max_tokens,
             # For now, we don't support parallel tool calls
             # NOTE: we can't pass this in if tools are not specified
             # or else OpenAI throws an error
@@ -484,6 +486,7 @@ class DefaultMultiLLM(LLM):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> BaseMessage:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
@@ -497,6 +500,7 @@ class DefaultMultiLLM(LLM):
                 stream=False,
                 structured_response_format=structured_response_format,
                 timeout_override=timeout_override,
+                max_tokens=max_tokens,
             ),
         )
         choice = response.choices[0]
@@ -515,6 +519,7 @@ class DefaultMultiLLM(LLM):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> Iterator[BaseMessage]:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
@@ -539,6 +544,7 @@ class DefaultMultiLLM(LLM):
                 stream=True,
                 structured_response_format=structured_response_format,
                 timeout_override=timeout_override,
+                max_tokens=max_tokens,
             ),
         )
         try:
diff --git a/backend/onyx/llm/custom_llm.py b/backend/onyx/llm/custom_llm.py
index 690a84e0c..ce5af5621 100644
--- a/backend/onyx/llm/custom_llm.py
+++ b/backend/onyx/llm/custom_llm.py
@@ -82,6 +82,7 @@ class CustomModelServer(LLM):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> BaseMessage:
         return self._execute(prompt)

@@ -92,5 +93,6 @@ class CustomModelServer(LLM):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> Iterator[BaseMessage]:
         yield self._execute(prompt)
diff --git a/backend/onyx/llm/interfaces.py b/backend/onyx/llm/interfaces.py
index 52c502de3..bfc3a5862 100644
--- a/backend/onyx/llm/interfaces.py
+++ b/backend/onyx/llm/interfaces.py
@@ -91,12 +91,18 @@ class LLM(abc.ABC):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> BaseMessage:
         self._precall(prompt)
         # TODO add a postcall to log model outputs independent of concrete class
         # implementation
         return self._invoke_implementation(
-            prompt, tools, tool_choice, structured_response_format, timeout_override
+            prompt,
+            tools,
+            tool_choice,
+            structured_response_format,
+            timeout_override,
+            max_tokens,
         )

     @abc.abstractmethod
@@ -107,6 +113,7 @@ class LLM(abc.ABC):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> BaseMessage:
         raise NotImplementedError

@@ -117,12 +124,18 @@ class LLM(abc.ABC):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> Iterator[BaseMessage]:
         self._precall(prompt)
         # TODO add a postcall to log model outputs independent of concrete class
         # implementation
         messages = self._stream_implementation(
-            prompt, tools, tool_choice, structured_response_format, timeout_override
+            prompt,
+            tools,
+            tool_choice,
+            structured_response_format,
+            timeout_override,
+            max_tokens,
         )

         tokens = []
@@ -142,5 +155,6 @@ class LLM(abc.ABC):
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
         timeout_override: int | None = None,
+        max_tokens: int | None = None,
     ) -> Iterator[BaseMessage]:
         raise NotImplementedError
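Note (illustrative, not part of the diff): because LLM.invoke and LLM.stream now forward max_tokens positionally to _invoke_implementation / _stream_implementation, every subclass has to accept the parameter even if it ignores it, as CustomModelServer does above. A minimal sketch of the shape an out-of-tree subclass would need (other abstract members omitted, types simplified):

    # Sketch only: the new keyword must exist in subclass signatures, even if unused.
    from collections.abc import Iterator

    from langchain_core.messages import BaseMessage

    from onyx.llm.interfaces import LLM


    class MyBackendLLM(LLM):
        def _invoke_implementation(
            self,
            prompt,
            tools=None,
            tool_choice=None,
            structured_response_format=None,
            timeout_override=None,
            max_tokens=None,  # new parameter; this backend simply ignores it
        ) -> BaseMessage:
            ...

        def _stream_implementation(
            self,
            prompt,
            tools=None,
            tool_choice=None,
            structured_response_format=None,
            timeout_override=None,
            max_tokens=None,  # new parameter; this backend simply ignores it
        ) -> Iterator[BaseMessage]:
            ...
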
diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py
index 153fedc9e..366f50814 100644
--- a/backend/tests/integration/common_utils/reset.py
+++ b/backend/tests/integration/common_utils/reset.py
@@ -25,7 +25,7 @@ from onyx.indexing.models import IndexingSetting
 from onyx.setup import setup_postgres
 from onyx.setup import setup_vespa
 from onyx.utils.logger import setup_logger
-from tests.integration.common_utils.timeout import run_with_timeout
+from tests.integration.common_utils.timeout import run_with_timeout_multiproc

 logger = setup_logger()

@@ -161,7 +161,7 @@ def reset_postgres(
     for _ in range(NUM_TRIES):
         logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
         try:
-            run_with_timeout(
+            run_with_timeout_multiproc(
                 downgrade_postgres,
                 TIMEOUT,
                 kwargs={
diff --git a/backend/tests/integration/common_utils/timeout.py b/backend/tests/integration/common_utils/timeout.py
index 64dacecaf..52c5ac0a0 100644
--- a/backend/tests/integration/common_utils/timeout.py
+++ b/backend/tests/integration/common_utils/timeout.py
@@ -6,7 +6,9 @@ from typing import TypeVar
 T = TypeVar("T")


-def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
+def run_with_timeout_multiproc(
+    task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
+) -> T:
     # Use multiprocessing to prevent a thread from blocking the main thread
     with multiprocessing.Pool(processes=1) as pool:
         async_result = pool.apply_async(task, kwds=kwargs)
diff --git a/backend/tests/unit/onyx/llm/test_chat_llm.py b/backend/tests/unit/onyx/llm/test_chat_llm.py
index 0e3620292..b69b3b7de 100644
--- a/backend/tests/unit/onyx/llm/test_chat_llm.py
+++ b/backend/tests/unit/onyx/llm/test_chat_llm.py
@@ -145,6 +145,7 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
         timeout=30,
         parallel_tool_calls=False,
         mock_response=MOCK_LLM_RESPONSE,
+        max_tokens=None,
     )

@@ -290,4 +291,5 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
         timeout=30,
         parallel_tool_calls=False,
         mock_response=MOCK_LLM_RESPONSE,
+        max_tokens=None,
     )
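Note (illustrative, not part of the diff): the rename to run_with_timeout_multiproc makes it explicit that the task runs in a separate process (a one-worker multiprocessing.Pool), so a hung Postgres call cannot block the test runner itself. A usage sketch under the assumption that the helper surfaces an overrun as multiprocessing.TimeoutError via AsyncResult.get — the tail of the function is not shown in this diff:

    # Usage sketch only; the task and the timeout behaviour are assumptions, not repo code.
    import multiprocessing
    import time

    from tests.integration.common_utils.timeout import run_with_timeout_multiproc


    def slow_task(sleep_s: int) -> str:  # hypothetical stand-in for downgrade_postgres
        time.sleep(sleep_s)
        return "done"


    try:
        run_with_timeout_multiproc(slow_task, 2, kwargs={"sleep_s": 10})
    except multiprocessing.TimeoutError:
        print("task exceeded its 2s budget and was abandoned")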