Remove All Enums from Postgres (#1247)

Yuhong Sun 2024-03-22 23:01:05 -07:00 committed by GitHub
parent 89e72783a7
commit aaa7b26a4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 93 additions and 10 deletions
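
The change itself is mechanical: every SQLAlchemy `Enum(...)` column gains `native_enum=False`, so values are persisted as plain VARCHAR rather than native Postgres `CREATE TYPE ... AS ENUM` types, and a new migration converts the existing columns and drops the old types. A minimal sketch of what the flag does, using an illustrative `Color` enum and `widget` table that are not part of the Danswer codebase:

    import enum

    from sqlalchemy import Enum
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from sqlalchemy.schema import CreateTable


    class Color(enum.Enum):
        RED = "red"
        BLUE = "blue"


    class Base(DeclarativeBase):
        pass


    class Widget(Base):
        __tablename__ = "widget"

        id: Mapped[int] = mapped_column(primary_key=True)
        # native_enum=False: stored as VARCHAR, no CREATE TYPE is emitted
        # (pass create_constraint=True if a CHECK constraint on the values is wanted)
        color: Mapped[Color] = mapped_column(Enum(Color, native_enum=False))


    # Prints the Postgres DDL; the color column comes out as VARCHAR, not an enum type
    print(CreateTable(Widget.__table__).compile(dialect=postgresql.dialect()))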


@@ -0,0 +1,71 @@
"""Remove Remaining Enums

Revision ID: 776b3bbe9092
Revises: 4738e4b3bae1
Create Date: 2024-03-22 21:34:27.629444

"""
from alembic import op
import sqlalchemy as sa

from danswer.db.models import IndexModelStatus
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import SearchType

# revision identifiers, used by Alembic.
revision = "776b3bbe9092"
down_revision = "4738e4b3bae1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column(
        "persona",
        "search_type",
        type_=sa.String,
        existing_type=sa.Enum(SearchType, native_enum=False),
        existing_nullable=False,
    )
    op.alter_column(
        "persona",
        "recency_bias",
        type_=sa.String,
        existing_type=sa.Enum(RecencyBiasSetting, native_enum=False),
        existing_nullable=False,
    )

    # Because the indexmodelstatus enum does not have a mapping to a string type
    # we need this workaround instead of directly changing the type
    op.add_column("embedding_model", sa.Column("temp_status", sa.String))
    op.execute("UPDATE embedding_model SET temp_status = status::text")
    op.drop_column("embedding_model", "status")
    op.alter_column("embedding_model", "temp_status", new_column_name="status")

    op.execute("DROP TYPE IF EXISTS searchtype")
    op.execute("DROP TYPE IF EXISTS recencybiassetting")
    op.execute("DROP TYPE IF EXISTS indexmodelstatus")


def downgrade() -> None:
    op.alter_column(
        "persona",
        "search_type",
        type_=sa.Enum(SearchType, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
    op.alter_column(
        "persona",
        "recency_bias",
        type_=sa.Enum(RecencyBiasSetting, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
    op.alter_column(
        "embedding_model",
        "status",
        type_=sa.Enum(IndexModelStatus, native_enum=False),
        existing_type=sa.String(length=50),
        existing_nullable=False,
    )
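
After the revision is applied (e.g. via `alembic upgrade head`), the three dropped types should no longer exist in the database. A hypothetical check, not part of this commit; the connection URL is an assumption and should point at the Danswer Postgres instance:

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://postgres:password@localhost:5432/danswer")

    with engine.connect() as conn:
        # pg_type rows with typtype = 'e' are native enum types
        leftover = conn.execute(
            text("SELECT typname FROM pg_type WHERE typtype = 'e'")
        ).scalars().all()

    # Expected to be empty once searchtype, recencybiassetting and indexmodelstatus are gone
    print(leftover)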


@@ -242,7 +242,7 @@ class ConnectorCredentialPair(Base):
         DateTime(timezone=True), default=None
     )
     last_attempt_status: Mapped[IndexingStatus | None] = mapped_column(
-        Enum(IndexingStatus)
+        Enum(IndexingStatus, native_enum=False)
     )
     total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
@@ -309,7 +309,9 @@ class Tag(Base):
     id: Mapped[int] = mapped_column(primary_key=True)
     tag_key: Mapped[str] = mapped_column(String)
     tag_value: Mapped[str] = mapped_column(String)
-    source: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource))
+    source: Mapped[DocumentSource] = mapped_column(
+        Enum(DocumentSource, native_enum=False)
+    )
     documents = relationship(
         "Document",
@@ -396,7 +398,9 @@ class EmbeddingModel(Base):
     normalize: Mapped[bool] = mapped_column(Boolean)
     query_prefix: Mapped[str] = mapped_column(String)
     passage_prefix: Mapped[str] = mapped_column(String)
-    status: Mapped[IndexModelStatus] = mapped_column(Enum(IndexModelStatus))
+    status: Mapped[IndexModelStatus] = mapped_column(
+        Enum(IndexModelStatus, native_enum=False)
+    )
     index_name: Mapped[str] = mapped_column(String)
     index_attempts: Mapped[List["IndexAttempt"]] = relationship(
@@ -441,7 +445,9 @@ class IndexAttempt(Base):
     # This is only for attempts that are explicitly marked as from the start via
     # the run once API
     from_beginning: Mapped[bool] = mapped_column(Boolean)
-    status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
+    status: Mapped[IndexingStatus] = mapped_column(
+        Enum(IndexingStatus, native_enum=False)
+    )
     # The two below may be slightly out of sync if user switches Embedding Model
     new_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
     total_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
@@ -544,7 +550,9 @@ class SearchDoc(Base):
     link: Mapped[str | None] = mapped_column(String, nullable=True)
     blurb: Mapped[str] = mapped_column(String)
     boost: Mapped[int] = mapped_column(Integer)
-    source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource))
+    source_type: Mapped[DocumentSource] = mapped_column(
+        Enum(DocumentSource, native_enum=False)
+    )
     hidden: Mapped[bool] = mapped_column(Boolean)
     doc_metadata: Mapped[dict[str, str | list[str]]] = mapped_column(postgresql.JSONB())
     score: Mapped[float] = mapped_column(Float)
@@ -617,7 +625,9 @@ class ChatMessage(Base):
     # If prompt is None, then token_count is 0 as this message won't be passed into
     # the LLM's context (not included in the history of messages)
     token_count: Mapped[int] = mapped_column(Integer)
-    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
+    message_type: Mapped[MessageType] = mapped_column(
+        Enum(MessageType, native_enum=False)
+    )
     # Maps the citation numbers to a SearchDoc id
     citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True)
     # Only applies for LLM
@@ -656,7 +666,7 @@ class DocumentRetrievalFeedback(Base):
     document_rank: Mapped[int] = mapped_column(Integer)
     clicked: Mapped[bool] = mapped_column(Boolean, default=False)
     feedback: Mapped[SearchFeedbackType | None] = mapped_column(
-        Enum(SearchFeedbackType), nullable=True
+        Enum(SearchFeedbackType, native_enum=False), nullable=True
     )
     chat_message: Mapped[ChatMessage] = relationship(
@@ -768,7 +778,7 @@ class Persona(Base):
     description: Mapped[str] = mapped_column(String)
     # Currently stored but unused, all flows use hybrid
     search_type: Mapped[SearchType] = mapped_column(
-        Enum(SearchType), default=SearchType.HYBRID
+        Enum(SearchType, native_enum=False), default=SearchType.HYBRID
     )
     # Number of chunks to pass to the LLM for generation.
     num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
@@ -778,7 +788,9 @@ class Persona(Base):
     # Enables using LLM to extract time and source type filters
     # Can also be admin disabled globally
     llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
-    recency_bias: Mapped[RecencyBiasSetting] = mapped_column(Enum(RecencyBiasSetting))
+    recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
+        Enum(RecencyBiasSetting, native_enum=False)
+    )
     # Allows the Persona to specify a different LLM version than is controlled
     # globablly via env variables. For flexibility, validity is not currently enforced
     # NOTE: only is applied on the actual response generation - is not used for things like
@@ -891,7 +903,7 @@ class TaskQueueState(Base):
     # For any job type, this would be the same
     task_name: Mapped[str] = mapped_column(String)
     # Note that if the task dies, this won't necessarily be marked FAILED correctly
-    status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus))
+    status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus, native_enum=False))
     start_time: Mapped[datetime.datetime | None] = mapped_column(
         DateTime(timezone=True)
     )