mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-10 21:09:51 +02:00
more solid schema context passing
This commit is contained in:
parent
b4ee066424
commit
5775aec498
@ -7,6 +7,7 @@ Create Date: 2024-09-25 12:47:44.877589
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@ -16,7 +17,135 @@ branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# -------- Insert Tools --------
|
||||
|
||||
# Ensure 'ImageGenerationTool' exists in the tool table
|
||||
image_gen_tool_name = 'ImageGenerationTool'
|
||||
existing_tool = session.execute(
|
||||
sa.select(tool_table.c.id).where(tool_table.c.name == image_gen_tool_name)
|
||||
).fetchone()
|
||||
|
||||
if not existing_tool:
|
||||
result = session.execute(
|
||||
tool_table.insert().values(
|
||||
name=image_gen_tool_name,
|
||||
display_name='Image Generator',
|
||||
description='Generates images based on descriptions',
|
||||
builtin_tool=True,
|
||||
is_public=True,
|
||||
)
|
||||
)
|
||||
image_gen_tool_id = result.inserted_primary_key[0]
|
||||
else:
|
||||
image_gen_tool_id = existing_tool[0]
|
||||
|
||||
# -------- Insert Personas --------
|
||||
|
||||
personas = personas_data.get('personas', [])
|
||||
for persona in personas:
|
||||
persona_id = persona.get('id')
|
||||
# Check if persona already exists
|
||||
existing_persona = session.execute(
|
||||
sa.select(persona_table.c.id).where(persona_table.c.id == persona_id)
|
||||
).fetchone()
|
||||
|
||||
persona_values = {
|
||||
'id': persona_id,
|
||||
'name': persona['name'],
|
||||
'description': persona.get('description', '').strip(),
|
||||
'num_chunks': persona.get('num_chunks'),
|
||||
'llm_relevance_filter': persona.get('llm_relevance_filter', False),
|
||||
'llm_filter_extraction': persona.get('llm_filter_extraction', False),
|
||||
'recency_bias': persona.get('recency_bias'),
|
||||
'icon_shape': persona.get('icon_shape'),
|
||||
'icon_color': persona.get('icon_color'),
|
||||
'display_priority': persona.get('display_priority'),
|
||||
'is_visible': persona.get('is_visible', True),
|
||||
'builtin_persona': True,
|
||||
'is_public': True,
|
||||
'image_generation': persona.get('image_generation', False),
|
||||
'llm_model_provider_override': persona.get('llm_model_provider_override'),
|
||||
'llm_model_version_override': persona.get('llm_model_version_override'),
|
||||
}
|
||||
|
||||
if not existing_persona:
|
||||
# Insert new persona
|
||||
session.execute(
|
||||
persona_table.insert().values(**persona_values)
|
||||
)
|
||||
else:
|
||||
# Update existing persona
|
||||
session.execute(
|
||||
persona_table.update()
|
||||
.where(persona_table.c.id == persona_id)
|
||||
.values(**persona_values)
|
||||
)
|
||||
|
||||
# -------- Associate Personas with Tools --------
|
||||
|
||||
tool_ids = []
|
||||
if persona.get('image_generation'):
|
||||
tool_ids.append(image_gen_tool_id)
|
||||
|
||||
# Associate persona with tools
|
||||
for tool_id in tool_ids:
|
||||
# Check if association already exists
|
||||
existing_association = session.execute(
|
||||
sa.select(persona_tool_association_table.c.persona_id)
|
||||
.where(
|
||||
(persona_tool_association_table.c.persona_id == persona_id) &
|
||||
(persona_tool_association_table.c.tool_id == tool_id)
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if not existing_association:
|
||||
session.execute(
|
||||
persona_tool_association_table.insert().values(
|
||||
persona_id=persona_id,
|
||||
tool_id=tool_id,
|
||||
)
|
||||
)
|
||||
|
||||
# -------- Insert Input Prompts --------
|
||||
|
||||
input_prompts = input_prompts_data.get('input_prompts', [])
|
||||
for input_prompt in input_prompts:
|
||||
input_prompt_id = input_prompt.get('id')
|
||||
# Check if input prompt already exists
|
||||
existing_input_prompt = session.execute(
|
||||
sa.select(input_prompt_table.c.id).where(input_prompt_table.c.id == input_prompt_id)
|
||||
).fetchone()
|
||||
|
||||
input_prompt_values = {
|
||||
'id': input_prompt_id,
|
||||
'prompt': input_prompt['prompt'],
|
||||
'content': input_prompt['content'],
|
||||
'is_public': input_prompt.get('is_public', True),
|
||||
'active': input_prompt.get('active', True),
|
||||
}
|
||||
|
||||
if not existing_input_prompt:
|
||||
# Insert new input prompt
|
||||
session.execute(
|
||||
input_prompt_table.insert().values(**input_prompt_values)
|
||||
)
|
||||
else:
|
||||
# Update existing input prompt
|
||||
session.execute(
|
||||
input_prompt_table.update()
|
||||
.where(input_prompt_table.c.id == input_prompt_id)
|
||||
.values(**input_prompt_values)
|
||||
)
|
||||
|
||||
# Commit the session
|
||||
session.commit()
|
||||
|
||||
def downgrade():
|
||||
# Optional: Implement logic to remove the inserted data if necessary
|
||||
pass
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import contextvars
|
||||
from contextvars import ContextVar
|
||||
|
||||
from fastapi import Depends
|
||||
@ -192,30 +193,29 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
global_tenant_id = "650a1472-4101-497c-b5f1-5dfe1b067730"
|
||||
current_tenant_id = contextvars.ContextVar(
|
||||
"current_tenant_id", default="650a1472-4101-497c-b5f1-5dfe1b067730"
|
||||
)
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
global global_tenant_id
|
||||
return contextlib.contextmanager(lambda: get_session(override_tenant_id=global_tenant_id))()
|
||||
|
||||
tenant_id = current_tenant_id.get()
|
||||
return contextlib.contextmanager(lambda: get_session(override_tenant_id=tenant_id))()
|
||||
|
||||
def get_current_tenant_id(request: Request) -> str | None:
|
||||
if not MULTI_TENANT:
|
||||
return DEFAULT_SCHEMA
|
||||
|
||||
token = request.cookies.get("tenant_details")
|
||||
global global_tenant_id
|
||||
if not token:
|
||||
logger.warning("zzzztoken found in cookies")
|
||||
log_stack_trace()
|
||||
|
||||
print('returning', global_tenant_id)
|
||||
return "650a1472-4101-497c-b5f1-5dfe1b067730"
|
||||
# raise HTTPException(status_code=401, detail="Authentication required")
|
||||
logger.warning("No token found in cookies")
|
||||
tenant_id = current_tenant_id.get()
|
||||
logger.info(f"Returning default tenant_id: {tenant_id}")
|
||||
return tenant_id
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security
|
||||
logger.info("Attempting to decode token")
|
||||
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
|
||||
logger.info(f"Decoded payload: {payload}")
|
||||
tenant_id = payload.get("tenant_id")
|
||||
@ -224,26 +224,29 @@ def get_current_tenant_id(request: Request) -> str | None:
|
||||
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
|
||||
logger.info(f"Valid tenant_id found: {tenant_id}")
|
||||
current_tenant_id.set(tenant_id)
|
||||
global_tenant_id = tenant_id
|
||||
return tenant_id
|
||||
except DecodeError as e:
|
||||
except (DecodeError, InvalidTokenError) as e:
|
||||
logger.error(f"JWT decode error: {str(e)}")
|
||||
raise HTTPException(status_code=401, detail="Invalid token format")
|
||||
except InvalidTokenError as e:
|
||||
logger.error(f"Invalid token error: {str(e)}")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
def get_session(tenant_id: str | None= Depends(get_current_tenant_id), override_tenant_id: str | None = None) -> Generator[Session, None, None]:
|
||||
def get_session(tenant_id: str | None = None, override_tenant_id: str | None = None) -> Generator[Session, None, None]:
|
||||
if override_tenant_id:
|
||||
tenant_id = override_tenant_id
|
||||
else:
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
with Session(get_sqlalchemy_engine(schema=override_tenant_id or tenant_id), expire_on_commit=False) as session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
|
||||
yield session
|
||||
# finally:
|
||||
# current_tenant_id.reset(tenant_id)
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async def get_async_session(tenant_id: str | None = None, override_tenant_id: str | None = None) -> AsyncGenerator[AsyncSession, None]:
|
||||
if override_tenant_id:
|
||||
tenant_id = override_tenant_id
|
||||
else:
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
async with AsyncSession(
|
||||
get_sqlalchemy_async_engine(), expire_on_commit=False
|
||||
) as async_session:
|
||||
|
@ -9,11 +9,11 @@ export default async function AdminLayout({
|
||||
return (
|
||||
<div className="flex h-screen">
|
||||
<div className="mx-auto my-auto text-lg font-bold text-red-500">
|
||||
This funcitonality is only available in the Enterprise Edition :(
|
||||
This functionality is only available in the Enterprise Edition :(
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return children;
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user