Add Request Model Class for Google Colab Demo (#273)
Need to add the blog links later
@@ -52,4 +52,6 @@ class ModelHostType(str, Enum):
     # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
     HUGGINGFACE = "huggingface"  # HuggingFace test-generation Inference API
+    # TODO link blog here
+    COLAB_DEMO = "colab_demo"
     # TODO support for Azure, AWS, GCP GenAI model hosting
@@ -84,7 +84,10 @@ def get_default_backend_qa_model(
             raise ValueError(
                 "Request based GenAI model requires an endpoint and host type"
             )
-        if model_host_type == ModelHostType.HUGGINGFACE.value:
+        if (
+            model_host_type == ModelHostType.HUGGINGFACE.value
+            or model_host_type == ModelHostType.COLAB_DEMO.value
+        ):
             # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
             # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
             # This is the sum of input and output tokens, so cannot take in full Danswer context
@@ -4,6 +4,7 @@ from collections.abc import Generator
 from typing import Any
 
 import requests
+from requests.exceptions import Timeout
 from requests.models import Response
 
 from danswer.chunking.models import InferenceChunk
@@ -32,6 +33,12 @@ class HostSpecificRequestModel(abc.ABC):
     hosted behind REST APIs. Calling class abstracts away all Danswer internal specifics
     """
 
+    @property
+    def requires_api_key(self) -> bool:
+        """Is this model protected by security features
+        Does it need an api key to access the model for inference"""
+        return True
+
     @staticmethod
     @abc.abstractmethod
     def send_model_request(
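The new requires_api_key flag defaults to True on the abstract base and is overridden to False by the Colab demo model added further down. A hypothetical caller-side sketch (the helper name and plumbing are assumptions, not part of this commit) of how such a flag could gate whether credentials are attached to the outgoing request:

def build_request_headers(model: "HostSpecificRequestModel", api_key: str | None) -> dict[str, str]:
    # Hypothetical helper, not from the diff: only attach credentials when the
    # host type actually needs them (e.g. HuggingFace yes, ngrok Colab demo no).
    headers = {"Content-Type": "application/json"}
    if model.requires_api_key and api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return headers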
@@ -92,8 +99,8 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
         }
         try:
             return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
-        except TimeoutError as error:
-            raise TimeoutError(f"Model inference to {endpoint} timed out") from error
+        except Timeout as error:
+            raise Timeout(f"Model inference to {endpoint} timed out") from error
 
     @staticmethod
     def _hf_extract_model_output(
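The except clause above is corrected because requests signals request timeouts with requests.exceptions.Timeout (and its ConnectTimeout/ReadTimeout subclasses), not the built-in TimeoutError, so the old handler would never fire. A small standalone sketch; the non-routable address is only there to force a timeout:

import requests
from requests.exceptions import Timeout

try:
    # 10.255.255.1 is typically non-routable, so the connect attempt times out quickly.
    requests.post("http://10.255.255.1/", json={"input": "hi"}, timeout=0.5)
except Timeout as error:
    print(f"Caught requests timeout: {error}")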
@@ -121,14 +128,71 @@ class HuggingFaceRequestModel(HostSpecificRequestModel):
         yield from simulate_streaming_response(model_out)
 
 
+class ColabDemoRequestModel(HostSpecificRequestModel):
+    """Guide found at:
+    TODO place guide here
+    """
+
+    @property
+    def requires_api_key(self) -> bool:
+        return False
+
+    @staticmethod
+    def send_model_request(
+        filled_prompt: str,
+        endpoint: str,
+        api_key: str | None,  # ngrok basic setup doesn't require this
+        max_output_tokens: int,
+        stream: bool,
+        timeout: int | None,
+    ) -> Response:
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "input": filled_prompt,
+            "parameters": {
+                "temperature": 0.0,
+                "max_tokens": max_output_tokens,
+            },
+        }
+        try:
+            return requests.post(endpoint, headers=headers, json=data, timeout=timeout)
+        except Timeout as error:
+            raise Timeout(f"Model inference to {endpoint} timed out") from error
+
+    @staticmethod
+    def _colab_demo_extract_model_output(
+        response: Response,
+    ) -> str:
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return json.loads(response.content).get("generated_text", "")
+
+    @staticmethod
+    def extract_model_output_from_response(
+        response: Response,
+    ) -> str:
+        return ColabDemoRequestModel._colab_demo_extract_model_output(response)
+
+    @staticmethod
+    def generate_model_tokens_from_response(
+        response: Response,
+    ) -> Generator[str, None, None]:
+        model_out = ColabDemoRequestModel._colab_demo_extract_model_output(response)
+        yield from simulate_streaming_response(model_out)
+
+
 def get_host_specific_model_class(model_host_type: str) -> HostSpecificRequestModel:
     if model_host_type == ModelHostType.HUGGINGFACE.value:
         return HuggingFaceRequestModel()
+    if model_host_type == ModelHostType.COLAB_DEMO.value:
+        return ColabDemoRequestModel()
     else:
         # TODO support Azure, GCP, AWS
-        raise ValueError(
-            "Invalid model hosting service selected, currently supports only huggingface"
-        )
+        raise ValueError("Invalid model hosting service selected")
 
 
 class RequestCompletionQA(QAModel):
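For context on what ColabDemoRequestModel talks to: this commit does not include the notebook side (the blog/guide links are still TODO), but the request and response shapes above imply a JSON endpoint that accepts {"input": ..., "parameters": {"temperature": ..., "max_tokens": ...}} and replies with {"generated_text": ...}. A minimal, hypothetical Colab-side sketch using Flask behind an ngrok tunnel; the framework choice and stubbed generation are assumptions, not taken from this commit:

# Hypothetical notebook-side server, NOT part of this commit: a minimal Flask app
# that matches the payload ColabDemoRequestModel sends and the field it reads back.
from flask import Flask, jsonify, request

app = Flask(__name__)


@app.route("/", methods=["POST"])
def generate():
    payload = request.get_json(force=True)
    prompt = payload["input"]
    params = payload.get("parameters", {})
    max_tokens = params.get("max_tokens", 256)
    temperature = params.get("temperature", 0.0)

    # Stubbed output; a real notebook would run the hosted LLM (e.g. a 7B Llama2
    # chat model) here and return its completion instead.
    completion = f"[stub: {max_tokens} max tokens, temperature {temperature}] {prompt}"

    return jsonify({"generated_text": completion})


# Expose the app publicly, e.g. with pyngrok:
#   from pyngrok import ngrok
#   public_url = ngrok.connect(5000)  # configure this URL as the model endpoint in Danswer
app.run(port=5000)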