add default api keys for cloud users (#3044)

* add default api keys for cloud users

* add cohere as well

* naming
This commit is contained in:
pablodanswer
2024-11-04 11:11:12 -08:00
committed by GitHub
parent 2cd1e6be00
commit 2cb33b1fb4
4 changed files with 39 additions and 2 deletions

View File

@ -25,3 +25,7 @@ NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")

View File

@ -27,13 +27,13 @@ from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
@ -68,6 +68,8 @@ def create_tenant(
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
return {

View File

@ -10,9 +10,17 @@ from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import UserTenantMapping
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
@ -52,6 +60,29 @@ def run_alembic_migrations(schema_name: str) -> None:
raise
def configure_default_api_keys(db_session: Session) -> None:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider="OpenAI",
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
)
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider="Anthropic",
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20240620",
)
upsert_llm_provider(open_provider, db_session)
upsert_llm_provider(anthropic_provider, db_session)
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
def ensure_schema_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():