Update LLM related Libs (#771)

This commit is contained in:
Yuhong Sun 2023-11-26 19:54:16 -08:00 committed by GitHub
parent 39d09a162a
commit 05c2b7d34e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 27 deletions

View File

@@ -220,7 +220,7 @@ class QABlock(QAModel):
prompt_tokens = sum(
[
check_number_of_tokens(
text=p.content, encode_fn=get_default_llm_token_encode()
text=str(p.content), encode_fn=get_default_llm_token_encode()
)
for p in prompt
]

View File

@@ -43,7 +43,7 @@ class LangChainChatLLM(LLM, abc.ABC):
def log_model_configs(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
def invoke(self, prompt: LanguageModelInput) -> str:
@@ -54,6 +54,12 @@ class LangChainChatLLM(LLM, abc.ABC):
if LOG_ALL_MODEL_INTERACTIONS:
logger.debug(f"Raw Model Output:\n{model_raw}")
if not isinstance(model_raw, str):
raise RuntimeError(
"Model output inconsistent with expected type, "
"is this related to a library upgrade?"
)
return model_raw
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:

View File

@@ -137,6 +137,9 @@ def message_generator_to_string_generator(
messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
for message in messages:
if not isinstance(message.content, str):
raise RuntimeError("LLM message not in expected format.")
yield message.content

View File

@@ -1,3 +1,4 @@
import logging
import os
import numpy as np
@@ -30,6 +31,8 @@ from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
logger = setup_logger()
# Remove useless info about layer initialization
logging.getLogger("transformers").setLevel(logging.ERROR)
_TOKENIZER: None | AutoTokenizer = None

View File

@@ -1,7 +1,7 @@
alembic==1.10.4
asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.0
beautifulsoup4==4.12.2
celery==5.3.4
dask==2023.8.1
distributed==2023.8.1
@@ -21,13 +21,13 @@ httpx==0.23.3
httpx-oauth==0.11.2
huggingface-hub==0.16.4
jira==3.5.1
langchain==0.0.325
litellm==0.12.5
llama-index==0.8.27
langchain==0.0.340
litellm==1.7.5
llama-index==0.9.8
Mako==1.2.4
nltk==3.8.1
docx2txt==0.8
openai==0.27.6
openai==1.3.5
oauthlib==3.2.2
playwright==1.37.0
psutil==5.9.5

View File

@@ -1,9 +1,22 @@
from typing import cast
import openai
from openai import OpenAI
VALID_MODEL_LIST = ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
VALID_MODEL_LIST = [
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
]
if __name__ == "__main__":
@@ -12,29 +25,28 @@ if __name__ == "__main__":
model_version = input("Please provide an OpenAI model version to test: ")
if model_version not in VALID_MODEL_LIST:
print(f"Model must be from valid list: {', '.join(VALID_MODEL_LIST)}")
assert model_version
api_key = input("Please provide an OpenAI API Key to test: ")
openai.api_key = api_key
client = OpenAI(
api_key=api_key,
)
prompt = "The boy went to the "
print(f"Asking OpenAI to finish the sentence using {model_version}")
print(prompt)
try:
if model_version == "text-davinci-003":
response = openai.Completion.create(
model=model_version, prompt=prompt, max_tokens=5, temperature=2
)
print(cast(str, response["choices"][0]["text"]).strip())
else:
messages = [
{"role": "system", "content": "Finish the sentence"},
{"role": "user", "content": prompt},
]
response = openai.ChatCompletion.create(
model=model_version, messages=messages, max_tokens=5, temperature=2
)
print(cast(str, response["choices"][0]["message"]["content"]).strip())
messages = [
{"role": "system", "content": "Finish the sentence"},
{"role": "user", "content": prompt},
]
response = client.chat.completions.create(
model=model_version,
messages=messages, # type:ignore
max_tokens=5,
temperature=2,
)
print(response.choices[0].message.content)
print("Success! Feel free to use this API key for Danswer.")
except Exception:
print(