This commit is contained in:
pablodanswer
2024-09-22 11:24:48 -07:00
parent d1641652a2
commit fe3f6d451d
6 changed files with 17 additions and 36 deletions

View File

@@ -10,17 +10,6 @@ from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
import asyncio
from logging.config import fileConfig
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
# Alembic Config object
config = context.config

View File

@@ -19,7 +19,7 @@ depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text('SELECT id, chosen_assistants FROM "user"')
)

View File

@@ -1,3 +1,4 @@
from datetime import timedelta
import contextlib
import smtplib
import uuid
@@ -240,7 +241,6 @@ async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User:
created_user: User = await user_db.create(new_user)
return created_user
from datetime import timedelta
async def create_user_session(user: User, tenant_id: str) -> str:
# Create a payload with user information and tenant_id
@@ -250,11 +250,11 @@ async def create_user_session(user: User, tenant_id: str) -> str:
"tenant_id": tenant_id,
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
}
# Encode the token
token = jwt.encode(payload, "JWT_SECRET_KEY", algorithm="HS256")
return token

View File

@@ -378,4 +378,4 @@ STRIPE_WEBHOOK_SECRET = (
"whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
)
DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")
DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")

View File

@@ -1,4 +1,4 @@
from fastapi import Request, Depends, HTTPException
from fastapi import Request, HTTPException
import contextlib
import time
from collections.abc import AsyncGenerator
@@ -31,18 +31,11 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from typing import Generator
from sqlalchemy.orm import Session
from sqlalchemy import text
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
import jwt
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
DEFAULT_SCHEMA = "public"
logger = setup_logger()
@@ -157,7 +150,7 @@ def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA) -> Engine:
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
# NOTE: Should be unnecessary
@@ -199,7 +192,7 @@ def get_session_context_manager() -> ContextManager[Session]:
def get_current_tenant_id(request: Request) -> str:
if AUTH_TYPE == AuthType.DISABLED:
return DEFAULT_SCHEMA
token = request.cookies.get("fastapiusersauth")
if not token:
raise HTTPException(status_code=401, detail="Authentication required")

View File

@@ -5,15 +5,12 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.security import OAuth2PasswordBearer
from fastapi_users.authentication import Strategy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from danswer.auth.users import create_user_session
from danswer.auth.users import optional_user
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_database_strategy
from danswer.auth.users import get_user_manager
from danswer.auth.users import is_user_admin
from danswer.auth.users import UserManager
@@ -40,17 +37,19 @@ from danswer.server.settings.store import store_settings
from danswer.utils.logger import setup_logger
from fastapi.responses import JSONResponse
from fastapi.responses import Response
from danswer.db.engine import get_async_session
import subprocess
import contextlib
from sqlalchemy import text
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
admin_router = APIRouter(prefix="/admin/settings")
basic_router = APIRouter(prefix="/settings")
from danswer.db.engine import get_async_session
import subprocess
logger = setup_logger()
from sqlalchemy import text
import contextlib
logger = setup_logger()
def run_alembic_migrations(schema_name: str) -> None:
# alembic -x "schema=tenant1,create_schema=true" upgrade head
@@ -74,7 +73,7 @@ async def check_schema_exists(tenant_id: str) -> bool:
logger.info(f"Checking if schema exists for tenant: {tenant_id}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
)
async with get_async_session_context() as session:
result = await session.execute(
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
@@ -90,7 +89,7 @@ async def create_tenant_schema(tenant_id: str) -> None:
# Create the schema
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
)
async with get_async_session_context() as session: