From bd9f15854f6fb11af65d3e04bcd9c1c3c47580e6 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 21 Nov 2024 14:08:16 -0800 Subject: [PATCH] provider fix (#3187) * clean horizontal scrollbar * provider fix * ensure proper migration * k * update migration * Revert "clean horizontal scrollbar" This reverts commit fa592a1b7a69897110a928a222b19eaef3b7267a. --- .../177de57c21c9_display_custom_llm_models.py | 59 +++++++++++++++++++ backend/danswer/db/models.py | 2 +- .../llm/CustomLLMProviderUpdateForm.tsx | 2 + 3 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 backend/alembic/versions/177de57c21c9_display_custom_llm_models.py diff --git a/backend/alembic/versions/177de57c21c9_display_custom_llm_models.py b/backend/alembic/versions/177de57c21c9_display_custom_llm_models.py new file mode 100644 index 000000000..d622f55b2 --- /dev/null +++ b/backend/alembic/versions/177de57c21c9_display_custom_llm_models.py @@ -0,0 +1,59 @@ +"""display custom llm models + +Revision ID: 177de57c21c9 +Revises: 4ee1287bd26a +Create Date: 2024-11-21 11:49:04.488677 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy import and_ + +revision = "177de57c21c9" +down_revision = "4ee1287bd26a" +branch_labels = None +depends_on = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + llm_provider = sa.table( + "llm_provider", + sa.column("id", sa.Integer), + sa.column("provider", sa.String), + sa.column("model_names", postgresql.ARRAY(sa.String)), + sa.column("display_model_names", postgresql.ARRAY(sa.String)), + ) + + excluded_providers = ["openai", "bedrock", "anthropic", "azure"] + + providers_to_update = sa.select( + llm_provider.c.id, + llm_provider.c.model_names, + llm_provider.c.display_model_names, + ).where( + and_( + ~llm_provider.c.provider.in_(excluded_providers), + llm_provider.c.model_names.isnot(None), + ) + ) + + results = conn.execute(providers_to_update).fetchall() + + for provider_id, model_names, display_model_names in results: + if display_model_names is None: + display_model_names = [] + + combined_model_names = list(set(display_model_names + model_names)) + update_stmt = ( + llm_provider.update() + .where(llm_provider.c.id == provider_id) + .values(display_model_names=combined_model_names) + ) + conn.execute(update_stmt) + + +def downgrade() -> None: + pass diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index b19753942..1513fc5fc 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1181,7 +1181,7 @@ class LLMProvider(Base): default_model_name: Mapped[str] = mapped_column(String) fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True) - # Models to actually disp;aly to users + # Models to actually display to users # If nulled out, we assume in the application logic we should present all display_model_names: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True diff --git a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx index 4301c34d3..9011b2cdf 100644 --- a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx @@ -142,6 +142,8 @@ export function CustomLLMProviderUpdateForm({ }, body: JSON.stringify({ ...values, + // For custom llm providers, all model names are displayed + display_model_names: values.model_names, custom_config: customConfigProcessing(values.custom_config_list), }), });