Add Request Model Class for Google Colab Demo (#273)

Need to add the blog links later
Author: Yuhong Sun
Date: 2023-08-08 00:09:11 -07:00
Committed by: GitHub
Parent: ca72027b28
Commit: 02c3139bc9
3 changed files with 75 additions and 6 deletions


@@ -52,4 +52,6 @@ class ModelHostType(str, Enum):
     # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
     HUGGINGFACE = "huggingface"  # HuggingFace text-generation Inference API
+    # TODO link blog here
+    COLAB_DEMO = "colab_demo"
     # TODO support for Azure, AWS, GCP GenAI model hosting
 
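Because ModelHostType subclasses both str and Enum, its members compare equal to their plain-string values, which is what the model_host_type == ModelHostType.COLAB_DEMO.value checks later in this commit rely on. A minimal self-contained sketch of that behavior (other enum members elided):

    from enum import Enum

    class ModelHostType(str, Enum):
        # trimmed copy of the enum this hunk extends
        HUGGINGFACE = "huggingface"
        COLAB_DEMO = "colab_demo"

    # str-Enum members compare equal to raw strings and round-trip from config values
    assert ModelHostType.COLAB_DEMO == "colab_demo"
    assert ModelHostType("colab_demo") is ModelHostType.COLAB_DEMO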


@@ -84,7 +84,10 @@ def get_default_backend_qa_model(
             raise ValueError(
                 "Request based GenAI model requires an endpoint and host type"
             )
-        if model_host_type == ModelHostType.HUGGINGFACE.value:
+        if (
+            model_host_type == ModelHostType.HUGGINGFACE.value
+            or model_host_type == ModelHostType.COLAB_DEMO.value
+        ):
             # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
             # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
             # This is the sum of input and output tokens, so cannot take in full Danswer context
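The comment above describes a shared budget: prompt and completion tokens both draw from the same 1512-token limit. The implied arithmetic (the 400-token reservation is a hypothetical figure for illustration, not a value from this commit):

    MODEL_TOKEN_LIMIT = 1512   # prompt + completion budget for the 7B Llama2 Chat demo
    max_output_tokens = 400    # hypothetical reservation for the generated answer
    max_input_tokens = MODEL_TOKEN_LIMIT - max_output_tokens  # 1112 tokens left for context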


@@ -4,6 +4,7 @@ from collections.abc import Generator
 from typing import Any
 
 import requests
+from requests.exceptions import Timeout
 from requests.models import Response
 
 from danswer.chunking.models import InferenceChunk
@@ -32,6 +33,12 @@ class HostSpecificRequestModel(abc.ABC):
     hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
     """
 
+    @property
+    def requires_api_key(self) -> bool:
+        """Is this model protected by security features
+        Does it need an api key to access the model for inference"""
+        return True
+
     @staticmethod
     @abc.abstractmethod
     def send_model_request(
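The base class answers True by default, and hosts that sit behind an open endpoint override it. A hypothetical caller-side guard built on this flag (the helper below is illustrative, not part of the commit):

    def ensure_api_key(model: HostSpecificRequestModel, api_key: str | None) -> None:
        # hosts like the Colab demo override requires_api_key to return False,
        # so only key-protected hosts raise here
        if model.requires_api_key and api_key is None:
            raise ValueError("An API key is required for this model host")
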
@@ -92,8 +99,8 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
         }
         try:
             return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
-        except TimeoutError as error:
-            raise TimeoutError(f"Model inference to {endpoint} timed out") from error
+        except Timeout as error:
+            raise Timeout(f"Model inference to {endpoint} timed out") from error
 
     @staticmethod
     def _hf_extract_model_output(
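This is a genuine bug fix: requests signals an expired timeout= deadline with requests.exceptions.Timeout, not the builtin TimeoutError, so the old handler never matched. A small standalone demonstration (the unroutable address is only a trick to force a connect timeout):

    import requests
    from requests.exceptions import Timeout

    try:
        # 10.255.255.1 is typically unroutable, so the connection attempt hangs
        # until the 0.01s deadline and requests raises ConnectTimeout, a Timeout subclass
        requests.post("http://10.255.255.1", timeout=0.01)
    except Timeout as error:
        print(f"caught: {error!r}")  # `except TimeoutError` would NOT catch this
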
@@ -121,14 +128,71 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
         yield from simulate_streaming_response(model_out)
 
 
+class ColabDemoRequestModel(HostSpecificRequestModel):
+    """Guide found at:
+    TODO place guide here
+    """
+
+    @property
+    def requires_api_key(self) -> bool:
+        return False
+
+    @staticmethod
+    def send_model_request(
+        filled_prompt: str,
+        endpoint: str,
+        api_key: str | None,  # ngrok basic setup doesn't require this
+        max_output_tokens: int,
+        stream: bool,
+        timeout: int | None,
+    ) -> Response:
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "input": filled_prompt,
+            "parameters": {
+                "temperature": 0.0,
+                "max_tokens": max_output_tokens,
+            },
+        }
+        try:
+            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
+        except Timeout as error:
+            raise Timeout(f"Model inference to {endpoint} timed out") from error
+
+    @staticmethod
+    def _colab_demo_extract_model_output(
+        response: Response,
+    ) -> str:
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return json.loads(response.content).get("generated_text", "")
+
+    @staticmethod
+    def extract_model_output_from_response(
+        response: Response,
+    ) -> str:
+        return ColabDemoRequestModel._colab_demo_extract_model_output(response)
+
+    @staticmethod
+    def generate_model_tokens_from_response(
+        response: Response,
+    ) -> Generator[str, None, None]:
+        model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
+        yield from simulate_streaming_response(model_out)
+
+
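For reference, the client above posts a JSON body of the form {"input": ..., "parameters": {"temperature": ..., "max_tokens": ...}} and expects a reply containing "generated_text". A minimal sketch of a matching server endpoint, as one might run it in the Colab notebook behind ngrok (Flask is an assumption here; the commit only references the notebook via the TODO guide link):

    from flask import Flask, jsonify, request

    app = Flask(__name__)

    def run_inference(prompt: str, max_tokens: int) -> str:
        # stand-in for the notebook's actual model call (e.g. a Llama2 pipeline)
        return f"echo (max {max_tokens} tokens): {prompt[:50]}"

    @app.route("/", methods=["POST"])
    def generate():
        payload = request.get_json()
        # mirror the client payload: {"input": ..., "parameters": {...}}
        text = run_inference(
            payload["input"],
            max_tokens=payload.get("parameters", {}).get("max_tokens", 256),
        )
        # the client reads json.loads(response.content).get("generated_text", "")
        return jsonify({"generated_text": text})

    if __name__ == "__main__":
        app.run(port=5000)  # ngrok then exposes http://localhost:5000 publicly
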
 def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
     if model_host_type == ModelHostType.HUGGINGFACE.value:
         return HuggingFaceRequestModel()
+    if model_host_type == ModelHostType.COLAB_DEMO.value:
+        return ColabDemoRequestModel()
     else:
         # TODO support Azure, GCP, AWS
-        raise ValueError(
-            "Invalid model hosting service selected, currently supports only huggingface"
-        )
+        raise ValueError("Invalid model hosting service selected")
 
 
 class RequestCompletionQA(QAModel):
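
Putting the pieces together, client-side usage of the new host type would look roughly like this (the endpoint URL is a placeholder for whatever ngrok prints in the notebook):

    model = get_host_specific_model_class(ModelHostType.COLAB_DEMO.value)

    response = model.send_model_request(
        filled_prompt="What does Danswer do?",
        endpoint="https://example-subdomain.ngrok.io",  # placeholder ngrok URL
        api_key=None,       # the Colab demo host reports requires_api_key == False
        max_output_tokens=400,
        stream=False,
        timeout=60,
    )
    print(model.extract_model_output_from_response(response))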