Custom OpenAI Model Server (#782)

This commit is contained in:
Yuhong Sun 2023-11-29 01:41:56 -08:00 committed by GitHub
parent 37daf4f3e4
commit c2727a3f19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 5 deletions

View File

@@ -90,6 +90,8 @@ GEN_AI_API_KEY = (
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)

View File

@@ -9,6 +9,7 @@ from langchain.schema.language_model import LanguageModelInput
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -42,8 +43,10 @@ class LangChainChatLLM(LLM, abc.ABC):
logger.debug(f"Prompt:\n{prompt}")
def log_model_configs(self) -> None:
llm_dict = {k: v for k, v in self.llm.__dict__.items() if v}
llm_dict.pop("client")
logger.info(
f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {llm_dict}"
)
def invoke(self, prompt: LanguageModelInput) -> str:
@@ -105,21 +108,25 @@ class DefaultMultiLLM(LangChainChatLLM):
self,
api_key: str | None,
timeout: int,
model_provider: str | None = GEN_AI_MODEL_PROVIDER,
model_version: str | None = GEN_AI_MODEL_VERSION,
model_provider: str = GEN_AI_MODEL_PROVIDER,
model_version: str = GEN_AI_MODEL_VERSION,
api_base: str | None = GEN_AI_API_ENDPOINT,
api_version: str | None = GEN_AI_API_VERSION,
custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
):
# Litellm Langchain integration currently doesn't take in the api key param
# Can place this in the call below once integration is in
litellm.api_key = api_key
litellm.api_key = api_key or "dummy-key"
litellm.api_version = api_version
self._llm = ChatLiteLLM( # type: ignore
model=_get_model_str(model_provider, model_version),
model=model_version
if custom_llm_provider
else _get_model_str(model_provider, model_version),
api_base=api_base,
custom_llm_provider=custom_llm_provider,
max_tokens=max_output_tokens,
temperature=temperature,
request_timeout=timeout,

View File

@@ -22,6 +22,7 @@ services:
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
@@ -87,6 +88,7 @@ services:
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
- POSTGRES_HOST=relational_db