Mirror of https://github.com/danswer-ai/danswer.git
Add Request Model Class for Google Colab Demo (#273)
Need to add the blog links later
@@ -52,4 +52,6 @@ class ModelHostType(str, Enum):
     # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
     HUGGINGFACE = "huggingface"  # HuggingFace text-generation Inference API
+    # TODO link blog here
+    COLAB_DEMO = "colab_demo"
     # TODO support for Azure, AWS, GCP GenAI model hosting
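The new COLAB_DEMO member is matched by string value elsewhere in this commit. A minimal sketch of that pattern (the model_host_type variable below simply stands in for whatever the deployment config supplies):

from enum import Enum


class ModelHostType(str, Enum):
    HUGGINGFACE = "huggingface"  # HuggingFace text-generation Inference API
    COLAB_DEMO = "colab_demo"    # small self-hosted demo, e.g. served from a Colab notebook


model_host_type = "colab_demo"  # stand-in for the configured value

if model_host_type == ModelHostType.COLAB_DEMO.value:
    print("Using the Colab demo request model")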
@@ -84,7 +84,10 @@ def get_default_backend_qa_model(
             raise ValueError(
                 "Request based GenAI model requires an endpoint and host type"
             )
-        if model_host_type == ModelHostType.HUGGINGFACE.value:
+        if (
+            model_host_type == ModelHostType.HUGGINGFACE.value
+            or model_host_type == ModelHostType.COLAB_DEMO.value
+        ):
             # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
             # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
             # This is the sum of input and output tokens, so cannot take in full Danswer context
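The comments above cap the 7B Llama2 Chat demo at 1512 tokens for prompt and completion combined, so any tokens reserved for the answer come directly out of the prompt budget. A rough illustration of that arithmetic, with names invented for this sketch rather than taken from Danswer settings:

# Invented names for illustration only.
COMBINED_TOKEN_LIMIT = 1512   # prompt + completion, per the 7B Llama2 Chat demo
max_output_tokens = 400       # tokens reserved for the generated answer

# Whatever is left over must hold the question, instructions, and retrieved
# context chunks, which is why the full Danswer context cannot be passed through.
max_prompt_tokens = COMBINED_TOKEN_LIMIT - max_output_tokens
print(max_prompt_tokens)  # 1112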
@@ -4,6 +4,7 @@ from collections.abc import Generator
 from typing import Any
 
 import requests
+from requests.exceptions import Timeout
 from requests.models import Response
 
 from danswer.chunking.models import InferenceChunk
@@ -32,6 +33,12 @@ class HostSpecificRequestModel(abc.ABC):
     hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
     """
 
+    @property
+    def requires_api_key(self) -> bool:
+        """Is this model protected by security features
+        Does it need an api key to access the model for inference"""
+        return True
+
     @staticmethod
     @abc.abstractmethod
     def send_model_request(
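requires_api_key defaults to True on the abstract base class; hosts that sit behind no authentication, such as the ngrok-exposed Colab demo added later in this commit, override it to False. A caller could consult the property roughly as follows; check_api_key is a hypothetical helper, not part of this diff:

def check_api_key(model: "HostSpecificRequestModel", api_key: str | None) -> None:
    # Hypothetical guard: only insist on a key when the host actually needs one.
    if model.requires_api_key and api_key is None:
        raise ValueError("This model host requires an API key, but none was configured")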
@@ -92,8 +99,8 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
         }
         try:
             return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
-        except TimeoutError as error:
-            raise TimeoutError(f"Model inference to {endpoint} timed out") from error
+        except Timeout as error:
+            raise Timeout(f"Model inference to {endpoint} timed out") from error
 
     @staticmethod
     def _hf_extract_model_output(
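This hunk swaps the builtin TimeoutError for requests.exceptions.Timeout. When requests.post exceeds its timeout argument it raises the library's own Timeout subclass (ConnectTimeout or ReadTimeout), which an except TimeoutError clause never catches. A small sketch of the behavior, using a placeholder URL:

import requests
from requests.exceptions import Timeout

try:
    # example.com is a placeholder endpoint for illustration only
    requests.post("https://example.com/generate", json={"input": "hi"}, timeout=0.001)
except Timeout:
    print("requests raises its own Timeout type, not the builtin TimeoutError")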
@@ -121,14 +128,71 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
         yield from simulate_streaming_response(model_out)
 
 
+class ColabDemoRequestModel(HostSpecificRequestModel):
+    """Guide found at:
+    TODO place guide here
+    """
+
+    @property
+    def requires_api_key(self) -> bool:
+        return False
+
+    @staticmethod
+    def send_model_request(
+        filled_prompt: str,
+        endpoint: str,
+        api_key: str | None,  # ngrok basic setup doesn't require this
+        max_output_tokens: int,
+        stream: bool,
+        timeout: int | None,
+    ) -> Response:
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "input": filled_prompt,
+            "parameters": {
+                "temperature": 0.0,
+                "max_tokens": max_output_tokens,
+            },
+        }
+        try:
+            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
+        except Timeout as error:
+            raise Timeout(f"Model inference to {endpoint} timed out") from error
+
+    @staticmethod
+    def _colab_demo_extract_model_output(
+        response: Response,
+    ) -> str:
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return json.loads(response.content).get("generated_text", "")
+
+    @staticmethod
+    def extract_model_output_from_response(
+        response: Response,
+    ) -> str:
+        return ColabDemoRequestModel._colab_demo_extract_model_output(response)
+
+    @staticmethod
+    def generate_model_tokens_from_response(
+        response: Response,
+    ) -> Generator[str, None, None]:
+        model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
+        yield from simulate_streaming_response(model_out)
+
+
 def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
     if model_host_type == ModelHostType.HUGGINGFACE.value:
         return HuggingFaceRequestModel()
+    if model_host_type == ModelHostType.COLAB_DEMO.value:
+        return ColabDemoRequestModel()
     else:
         # TODO support Azure, GCP, AWS
-        raise ValueError(
-            "Invalid model hosting service selected, currently supports only huggingface"
-        )
+        raise ValueError("Invalid model hosting service selected")
 
 
 class RequestCompletionQA(QAModel):
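For orientation, ColabDemoRequestModel above posts a JSON body of the form {"input": ..., "parameters": {"temperature": ..., "max_tokens": ...}} and reads generated_text out of the JSON response. The guide is still marked TODO, so the snippet below is only a guessed-at sketch of the notebook-side contract, using Flask and a placeholder run_llm function rather than the real demo code:

from flask import Flask, jsonify, request

app = Flask(__name__)


def run_llm(prompt: str, temperature: float, max_tokens: int) -> str:
    # Placeholder for whatever model the notebook actually loads (e.g. 7B Llama2 Chat).
    return "stub completion for: " + prompt[:40]


@app.post("/generate")
def generate():
    body = request.get_json()
    params = body.get("parameters", {})
    text = run_llm(
        prompt=body["input"],
        temperature=params.get("temperature", 0.0),
        max_tokens=params.get("max_tokens", 256),
    )
    # ColabDemoRequestModel reads this key from the response body.
    return jsonify({"generated_text": text})


# app.run(port=5000)  # then expose the port publicly, e.g. with ngrok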
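Putting the pieces together, a caller resolves the request model from the configured host type and drives it through the shared interface. A minimal usage sketch under the assumptions of this diff (the endpoint URL is a placeholder for whatever ngrok prints when the notebook starts):

model = get_host_specific_model_class(ModelHostType.COLAB_DEMO.value)

response = model.send_model_request(
    filled_prompt="Answer the question using the provided context...",
    endpoint="https://your-ngrok-subdomain.ngrok.io/generate",  # placeholder URL
    api_key=None,  # the Colab demo does not require one
    max_output_tokens=400,
    stream=False,
    timeout=60,
)
answer = model.extract_model_output_from_response(response)
print(answer)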