more solid schema context passing

This commit is contained in:
pablodanswer 2024-09-25 14:15:16 -07:00
parent b4ee066424
commit 5775aec498
3 changed files with 156 additions and 24 deletions

View File

@ -7,6 +7,7 @@ Create Date: 2024-09-25 12:47:44.877589
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
# revision identifiers, used by Alembic.
@ -16,7 +17,135 @@ branch_labels = None
depends_on = None
def upgrade() -> None:
# -------- Insert Tools --------
# Ensure 'ImageGenerationTool' exists in the tool table
image_gen_tool_name = 'ImageGenerationTool'
existing_tool = session.execute(
sa.select(tool_table.c.id).where(tool_table.c.name == image_gen_tool_name)
).fetchone()
if not existing_tool:
result = session.execute(
tool_table.insert().values(
name=image_gen_tool_name,
display_name='Image Generator',
description='Generates images based on descriptions',
builtin_tool=True,
is_public=True,
)
)
image_gen_tool_id = result.inserted_primary_key[0]
else:
image_gen_tool_id = existing_tool[0]
# -------- Insert Personas --------
personas = personas_data.get('personas', [])
for persona in personas:
persona_id = persona.get('id')
# Check if persona already exists
existing_persona = session.execute(
sa.select(persona_table.c.id).where(persona_table.c.id == persona_id)
).fetchone()
persona_values = {
'id': persona_id,
'name': persona['name'],
'description': persona.get('description', '').strip(),
'num_chunks': persona.get('num_chunks'),
'llm_relevance_filter': persona.get('llm_relevance_filter', False),
'llm_filter_extraction': persona.get('llm_filter_extraction', False),
'recency_bias': persona.get('recency_bias'),
'icon_shape': persona.get('icon_shape'),
'icon_color': persona.get('icon_color'),
'display_priority': persona.get('display_priority'),
'is_visible': persona.get('is_visible', True),
'builtin_persona': True,
'is_public': True,
'image_generation': persona.get('image_generation', False),
'llm_model_provider_override': persona.get('llm_model_provider_override'),
'llm_model_version_override': persona.get('llm_model_version_override'),
}
if not existing_persona:
# Insert new persona
session.execute(
persona_table.insert().values(**persona_values)
)
else:
# Update existing persona
session.execute(
persona_table.update()
.where(persona_table.c.id == persona_id)
.values(**persona_values)
)
# -------- Associate Personas with Tools --------
tool_ids = []
if persona.get('image_generation'):
tool_ids.append(image_gen_tool_id)
# Associate persona with tools
for tool_id in tool_ids:
# Check if association already exists
existing_association = session.execute(
sa.select(persona_tool_association_table.c.persona_id)
.where(
(persona_tool_association_table.c.persona_id == persona_id) &
(persona_tool_association_table.c.tool_id == tool_id)
)
).fetchone()
if not existing_association:
session.execute(
persona_tool_association_table.insert().values(
persona_id=persona_id,
tool_id=tool_id,
)
)
# -------- Insert Input Prompts --------
input_prompts = input_prompts_data.get('input_prompts', [])
for input_prompt in input_prompts:
input_prompt_id = input_prompt.get('id')
# Check if input prompt already exists
existing_input_prompt = session.execute(
sa.select(input_prompt_table.c.id).where(input_prompt_table.c.id == input_prompt_id)
).fetchone()
input_prompt_values = {
'id': input_prompt_id,
'prompt': input_prompt['prompt'],
'content': input_prompt['content'],
'is_public': input_prompt.get('is_public', True),
'active': input_prompt.get('active', True),
}
if not existing_input_prompt:
# Insert new input prompt
session.execute(
input_prompt_table.insert().values(**input_prompt_values)
)
else:
# Update existing input prompt
session.execute(
input_prompt_table.update()
.where(input_prompt_table.c.id == input_prompt_id)
.values(**input_prompt_values)
)
# Commit the session
session.commit()
def downgrade():
# Optional: Implement logic to remove the inserted data if necessary
pass

View File

@ -1,3 +1,4 @@
import contextvars
from contextvars import ContextVar
from fastapi import Depends
@ -192,30 +193,29 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
global_tenant_id = "650a1472-4101-497c-b5f1-5dfe1b067730"
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default="650a1472-4101-497c-b5f1-5dfe1b067730"
)
def get_session_context_manager() -> ContextManager[Session]:
global global_tenant_id
return contextlib.contextmanager(lambda: get_session(override_tenant_id=global_tenant_id))()
tenant_id = current_tenant_id.get()
return contextlib.contextmanager(lambda: get_session(override_tenant_id=tenant_id))()
def get_current_tenant_id(request: Request) -> str | None:
if not MULTI_TENANT:
return DEFAULT_SCHEMA
token = request.cookies.get("tenant_details")
global global_tenant_id
if not token:
logger.warning("zzzztoken found in cookies")
log_stack_trace()
print('returning', global_tenant_id)
return "650a1472-4101-497c-b5f1-5dfe1b067730"
# raise HTTPException(status_code=401, detail="Authentication required")
logger.warning("No token found in cookies")
tenant_id = current_tenant_id.get()
logger.info(f"Returning default tenant_id: {tenant_id}")
return tenant_id
try:
logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security
logger.info("Attempting to decode token")
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
logger.info(f"Decoded payload: {payload}")
tenant_id = payload.get("tenant_id")
@ -224,26 +224,29 @@ def get_current_tenant_id(request: Request) -> str | None:
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
logger.info(f"Valid tenant_id found: {tenant_id}")
current_tenant_id.set(tenant_id)
global_tenant_id = tenant_id
return tenant_id
except DecodeError as e:
except (DecodeError, InvalidTokenError) as e:
logger.error(f"JWT decode error: {str(e)}")
raise HTTPException(status_code=401, detail="Invalid token format")
except InvalidTokenError as e:
logger.error(f"Invalid token error: {str(e)}")
raise HTTPException(status_code=401, detail="Invalid token")
except Exception as e:
logger.exception(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
def get_session(tenant_id: str | None= Depends(get_current_tenant_id), override_tenant_id: str | None = None) -> Generator[Session, None, None]:
def get_session(tenant_id: str | None = None, override_tenant_id: str | None = None) -> Generator[Session, None, None]:
if override_tenant_id:
tenant_id = override_tenant_id
else:
tenant_id = current_tenant_id.get()
with Session(get_sqlalchemy_engine(schema=override_tenant_id or tenant_id), expire_on_commit=False) as session:
with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
yield session
# finally:
# current_tenant_id.reset(tenant_id)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def get_async_session(tenant_id: str | None = None, override_tenant_id: str | None = None) -> AsyncGenerator[AsyncSession, None]:
if override_tenant_id:
tenant_id = override_tenant_id
else:
tenant_id = current_tenant_id.get()
async with AsyncSession(
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:

View File

@ -9,11 +9,11 @@ export default async function AdminLayout({
return (
<div className="flex h-screen">
<div className="mx-auto my-auto text-lg font-bold text-red-500">
This funcitonality is only available in the Enterprise Edition :(
This functionality is only available in the Enterprise Edition :(
</div>
</div>
);
}
return children;
}
}