diff --git a/backend/onyx/llm/chat_llm.py b/backend/onyx/llm/chat_llm.py
index 50dc1ab770..002287fae1 100644
--- a/backend/onyx/llm/chat_llm.py
+++ b/backend/onyx/llm/chat_llm.py
@@ -51,6 +51,7 @@ litellm.drop_params = True
 litellm.telemetry = False
 
 _LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
+VERTEX_CREDENTIALS_KWARG = "vertex_credentials"
 
 
 class LLMTimeoutError(Exception):
@@ -303,11 +304,10 @@ class DefaultMultiLLM(LLM):
         # Specifically pass in "vertex_credentials" / "vertex_location" as a
         # model_kwarg to the completion call for vertex AI. More details here:
         # https://docs.litellm.ai/docs/providers/vertex
-        vertex_credentials_key = "vertex_credentials"
         vertex_location_key = "vertex_location"
         for k, v in custom_config.items():
             if model_provider == "vertex_ai":
-                if k == vertex_credentials_key:
+                if k == VERTEX_CREDENTIALS_KWARG:
                     model_kwargs[k] = v
                     continue
                 elif k == vertex_location_key:
@@ -412,6 +412,13 @@ class DefaultMultiLLM(LLM):
         processed_prompt = _prompt_to_dict(prompt)
         self._record_call(processed_prompt)
 
+        final_model_kwargs = {**self._model_kwargs}
+        if (
+            VERTEX_CREDENTIALS_KWARG not in final_model_kwargs
+            and self.config.credentials_file
+        ):
+            final_model_kwargs[VERTEX_CREDENTIALS_KWARG] = self.config.credentials_file
+
         try:
             return litellm.completion(
                 mock_response=MOCK_LLM_RESPONSE,
@@ -457,14 +464,7 @@ class DefaultMultiLLM(LLM):
                     if structured_response_format
                     else {}
                 ),
-                **(
-                    {
-                        "vertex_credentials": self.config.credentials_file,
-                    }
-                    if self.config.model_provider == "vertex_ai"
-                    else {}
-                ),
-                **self._model_kwargs,
+                **final_model_kwargs,
             )
         except Exception as e:
             self._record_error(processed_prompt, e)
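
In effect, the diff turns the Vertex AI credential handling into a two-step fallback: an explicit "vertex_credentials" entry coming through the custom config (already copied into self._model_kwargs) takes precedence, and config.credentials_file is only used when that key is absent. A minimal standalone sketch of that precedence is below; the build_model_kwargs helper and FakeLLMConfig dataclass are hypothetical stand-ins for illustration and are not part of chat_llm.py.

# Hypothetical sketch of the fallback introduced in this diff; helper and
# dataclass names are illustrative only.
from dataclasses import dataclass

VERTEX_CREDENTIALS_KWARG = "vertex_credentials"


@dataclass
class FakeLLMConfig:
    credentials_file: str | None = None


def build_model_kwargs(model_kwargs: dict, config: FakeLLMConfig) -> dict:
    final_model_kwargs = {**model_kwargs}
    # Fall back to the credentials file only when no explicit vertex
    # credentials were supplied via custom config / model kwargs.
    if VERTEX_CREDENTIALS_KWARG not in final_model_kwargs and config.credentials_file:
        final_model_kwargs[VERTEX_CREDENTIALS_KWARG] = config.credentials_file
    return final_model_kwargs


# Explicit custom-config credentials win over the credentials file:
assert build_model_kwargs(
    {VERTEX_CREDENTIALS_KWARG: '{"type": "service_account"}'},
    FakeLLMConfig(credentials_file="/path/creds.json"),
)[VERTEX_CREDENTIALS_KWARG] == '{"type": "service_account"}'

# With no explicit credentials, the file path is used as the fallback:
assert build_model_kwargs(
    {}, FakeLLMConfig(credentials_file="/path/creds.json")
)[VERTEX_CREDENTIALS_KWARG] == "/path/creds.json"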