Add Request Model Class for Google Colab Demo (#273)

Need to add the blog links later
Yuhong Sun
2023-08-08 00:09:11 -07:00
committed by GitHub
parent ca72027b28
commit 02c3139bc9
3 changed files with 75 additions and 6 deletions

@@ -52,4 +52,6 @@ class ModelHostType(str, Enum):
    # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
    HUGGINGFACE = "huggingface"  # HuggingFace text-generation Inference API
    # TODO link blog here
    COLAB_DEMO = "colab_demo"
    # TODO support for Azure, AWS, GCP GenAI model hosting

@@ -84,7 +84,10 @@ def get_default_backend_qa_model(
        raise ValueError(
            "Request based GenAI model requires an endpoint and host type"
        )
    if (
        model_host_type == ModelHostType.HUGGINGFACE.value
        or model_host_type == ModelHostType.COLAB_DEMO.value
    ):
        # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
        # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
        # This is the sum of input and output tokens, so cannot take in full Danswer context
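To make the token-limit comment concrete: with a 1512-token combined budget, whatever is reserved for the answer is subtracted from what the prompt (question plus retrieved context) may use. A rough sketch, where the 512-token output reservation is an illustrative assumption rather than a value from this commit:

# Combined input + output budget noted for the 7B Llama2 Chat demo
TOTAL_TOKEN_LIMIT = 1512

# Hypothetical split; Danswer's actual reservation may differ
max_output_tokens = 512
max_prompt_tokens = TOTAL_TOKEN_LIMIT - max_output_tokens  # 1000 tokens for question + context

# Retrieved chunks must be trimmed to fit max_prompt_tokens, which is why
# the full Danswer context cannot be sent to these small self-hosted models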

@@ -4,6 +4,7 @@ from collections.abc import Generator
from typing import Any

import requests
from requests.exceptions import Timeout
from requests.models import Response

from danswer.chunking.models import InferenceChunk
@@ -32,6 +33,12 @@ class HostSpecificRequestModel(abc.ABC):
    hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
    """

    @property
    def requires_api_key(self) -> bool:
        """Is this model protected by security features
        Does it need an api key to access the model for inference"""
        return True

    @staticmethod
    @abc.abstractmethod
    def send_model_request(
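The new requires_api_key property defaults to True on the base class, so existing hosts keep their key checks while subclasses such as the Colab demo can opt out. A hypothetical caller-side check built on it (this helper is not part of the commit; get_host_specific_model_class is the factory shown further below):

def build_request_model(model_host_type: str, api_key: str | None) -> HostSpecificRequestModel:
    # Illustrative sketch only; the real call sites may validate differently
    model = get_host_specific_model_class(model_host_type)
    if model.requires_api_key and not api_key:
        raise ValueError(f"Host type '{model_host_type}' requires an API key for inference")
    return model
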
@@ -92,8 +99,8 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except Timeout as error:
            raise Timeout(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _hf_extract_model_output(
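The except-clause change in this hunk matters because a timed-out requests.post raises requests.exceptions.Timeout, a RequestException subclass that the built-in TimeoutError named in the old code never catches. A small sketch of the difference (the endpoint is a placeholder):

import requests
from requests.exceptions import Timeout

try:
    # Placeholder endpoint; the tiny timeout forces the failure path quickly
    requests.post("http://example.invalid/generate", json={"input": "hi"}, timeout=0.001)
except TimeoutError:
    # Never reached for requests timeouts: Timeout does not inherit from the built-in TimeoutError
    print("built-in TimeoutError caught")
except Timeout as error:
    print(f"requests timeout caught: {error}")
except requests.exceptions.RequestException as error:
    # DNS failure on the placeholder host lands here instead of timing out
    print(f"other requests failure: {error}")
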
@@ -121,14 +128,71 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
        yield from simulate_streaming_response(model_out)


class ColabDemoRequestModel(HostSpecificRequestModel):
    """Guide found at:
    TODO place guide here
    """

    @property
    def requires_api_key(self) -> bool:
        return False

    @staticmethod
    def send_model_request(
        filled_prompt: str,
        endpoint: str,
        api_key: str | None,  # ngrok basic setup doesn't require this
        max_output_tokens: int,
        stream: bool,
        timeout: int | None,
    ) -> Response:
        headers = {
            "Content-Type": "application/json",
        }

        data = {
            "input": filled_prompt,
            "parameters": {
                "temperature": 0.0,
                "max_tokens": max_output_tokens,
            },
        }
        try:
            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
        except Timeout as error:
            raise Timeout(f"Model inference to {endpoint} timed out") from error

    @staticmethod
    def _colab_demo_extract_model_output(
        response: Response,
    ) -> str:
        if response.status_code != 200:
            response.raise_for_status()

        return json.loads(response.content).get("generated_text", "")

    @staticmethod
    def extract_model_output_from_response(
        response: Response,
    ) -> str:
        return ColabDemoRequestModel._colab_demo_extract_model_output(response)

    @staticmethod
    def generate_model_tokens_from_response(
        response: Response,
    ) -> Generator[str, None, None]:
        model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
        yield from simulate_streaming_response(model_out)


def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
    if model_host_type == ModelHostType.HUGGINGFACE.value:
        return HuggingFaceRequestModel()
    if model_host_type == ModelHostType.COLAB_DEMO.value:
        return ColabDemoRequestModel()
    else:
        # TODO support Azure, GCP, AWS
        raise ValueError("Invalid model hosting service selected")


class RequestCompletionQA(QAModel):
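
For reference, ColabDemoRequestModel implies a simple request/response contract: it POSTs JSON with "input" and "parameters" keys and reads "generated_text" back from the JSON reply. A minimal sketch of the kind of server the Colab notebook could expose through ngrok, assuming FastAPI and a stand-in model function (neither is confirmed by this commit):

from typing import Any

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class InferenceRequest(BaseModel):
    # Mirrors the payload ColabDemoRequestModel.send_model_request posts
    input: str
    parameters: dict[str, Any] = {}


def run_llm(prompt: str, max_tokens: int, temperature: float) -> str:
    # Stand-in for whatever model the notebook actually loads (e.g. Llama2 7B Chat)
    return f"Echo: {prompt[:50]}"


@app.post("/")
def generate(request: InferenceRequest) -> dict[str, str]:
    text = run_llm(
        request.input,
        max_tokens=request.parameters.get("max_tokens", 256),
        temperature=request.parameters.get("temperature", 0.0),
    )
    # Key name must match what _colab_demo_extract_model_output reads
    return {"generated_text": text}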