This commit is contained in:
Timothy J. Baek 2024-07-11 15:20:56 -07:00
parent 4dd77b785a
commit 8dcb3d78dc
2 changed files with 30 additions and 6 deletions

View File

@ -47,6 +47,8 @@ from config import (
OAUTH_PICTURE_CLAIM,
)
from apps.socket.main import get_event_call, get_event_emitter
import inspect
import uuid
import time
@ -197,8 +199,21 @@ async def generate_function_chat_completion(form_data, user):
metadata = form_data["metadata"]
del form_data["metadata"]
__event_emitter__ = None
__event_call__ = None
__task__ = None
if metadata:
print(metadata)
if (
metadata.get("session_id")
and metadata.get("chat_id")
and metadata.get("message_id")
):
__event_emitter__ = await get_event_emitter(metadata)
__event_call__ = await get_event_call(metadata)
if metadata.get("task"):
__task__ = metadata.get("task")
if model_info:
if model_info.base_model_id:
@ -314,6 +329,15 @@ async def generate_function_chat_completion(form_data, user):
params = {**params, "__user__": __user__}
if "__event_emitter__" in sig.parameters:
params = {**params, "__event_emitter__": __event_emitter__}
if "__event_call__" in sig.parameters:
params = {**params, "__event_call__": __event_call__}
if "__task__" in sig.parameters:
params = {**params, "__task__": __task__}
if form_data["stream"]:
async def stream_content():

View File

@ -95,8 +95,8 @@ class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
FUNCTION_CALLING = "function_calling"