From 8dcb3d78dccc85cb88847dccfae6b196dd9330dc Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 11 Jul 2024 15:20:56 -0700 Subject: [PATCH] refac --- backend/apps/webui/main.py | 26 +++++++++++++++++++++++++- backend/constants.py | 10 +++++----- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 7a0be2d22..325370482 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -47,6 +47,8 @@ from config import ( OAUTH_PICTURE_CLAIM, ) +from apps.socket.main import get_event_call, get_event_emitter + import inspect import uuid import time @@ -197,8 +199,21 @@ async def generate_function_chat_completion(form_data, user): metadata = form_data["metadata"] del form_data["metadata"] + __event_emitter__ = None + __event_call__ = None + __task__ = None + if metadata: - print(metadata) + if ( + metadata.get("session_id") + and metadata.get("chat_id") + and metadata.get("message_id") + ): + __event_emitter__ = await get_event_emitter(metadata) + __event_call__ = await get_event_call(metadata) + + if metadata.get("task"): + __task__ = metadata.get("task") if model_info: if model_info.base_model_id: @@ -314,6 +329,15 @@ async def generate_function_chat_completion(form_data, user): params = {**params, "__user__": __user__} + if "__event_emitter__" in sig.parameters: + params = {**params, "__event_emitter__": __event_emitter__} + + if "__event_call__" in sig.parameters: + params = {**params, "__event_call__": __event_call__} + + if "__task__" in sig.parameters: + params = {**params, "__task__": __task__} + if form_data["stream"]: async def stream_content(): diff --git a/backend/constants.py b/backend/constants.py index 7c366c222..b9c7fc430 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -95,8 +95,8 @@ class TASKS(str, Enum): def __str__(self) -> str: return super().__str__() - DEFAULT = lambda task="": f"{task if task else 'default'}" - TITLE_GENERATION = "Title Generation" - EMOJI_GENERATION = "Emoji Generation" - QUERY_GENERATION = "Query Generation" - FUNCTION_CALLING = "Function Calling" + DEFAULT = lambda task="": f"{task if task else 'generation'}" + TITLE_GENERATION = "title_generation" + EMOJI_GENERATION = "emoji_generation" + QUERY_GENERATION = "query_generation" + FUNCTION_CALLING = "function_calling"