Merge branch 'main' into add-teams-connector

hagen-danswer 2024-06-04 12:09:52 -04:00 committed by GitHub
commit a9834853ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
546 changed files with 36747 additions and 10466 deletions

View File

@ -14,16 +14,16 @@ jobs:
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v2
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
@ -38,5 +38,7 @@ jobs:
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@ -14,16 +14,16 @@ jobs:
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v2
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server

View File

@ -5,38 +5,114 @@ on:
tags:
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build-and-push:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=danswer/danswer-web-server:${{ github.ref_name }}
type=raw,value=danswer/danswer-web-server:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: true
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: digests-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
merge:
runs-on: ubuntu-latest
needs:
- build
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
- name: Login to Docker Hub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Web Image Docker Build and Push
uses: docker/build-push-action@v2
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-web-server:${{ github.ref_name }}
danswer/danswer-web-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@ -20,10 +20,12 @@ jobs:
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
- name: Run MyPy
run: |

View File

@ -11,62 +11,6 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "API Server",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"env": {
"LOG_LEVEL": "DEBUG",
"DISABLE_AUTH": "True",
"TYPESENSE_API_KEY": "typesense_api_key",
"DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage"
},
"args": [
"danswer.main:app",
"--reload",
"--port",
"8080"
]
},
{
"name": "Indexer",
"type": "python",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONPATH": ".",
"TYPESENSE_API_KEY": "typesense_api_key",
"DYNAMIC_CONFIG_DIR_PATH": "./dynamic_config_storage"
}
},
{
"name": "Temp File Deletion",
"type": "python",
"request": "launch",
"program": "danswer/background/file_deletion.py",
"cwd": "${workspaceFolder}/backend",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONPATH": "${workspaceFolder}/backend"
}
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in the .env file located in the root of the project
{
"name": "Slack Bot Listener",
"type": "python",
"request": "launch",
"program": "danswer/listeners/slack_listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG"
}
},
{
"name": "Web Server",
"type": "node",
@ -77,6 +21,85 @@
"run", "dev"
],
"console": "integratedTerminal"
},
{
"name": "Model Server",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"model_server.main:app",
"--reload",
"--port",
"9000"
]
},
{
"name": "API Server",
"type": "python",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"env": {
"LOG_ALL_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"danswer.main:app",
"--reload",
"--port",
"8080"
]
},
{
"name": "Indexing",
"type": "python",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"env": {
"ENABLE_MINI_CHUNK": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"type": "python",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"--no-indexing"
]
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in the .env file located in the root of the project
{
"name": "Slack Bot",
"type": "python",
"request": "launch",
"program": "danswer/danswerbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
}
]
}

View File

@ -72,15 +72,20 @@ For convenience here's a command for it:
python -m venv .venv
source .venv/bin/activate
```
_For Windows activate via:_
_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate
```
If using PowerShell, the command differs slightly:
```powershell
.venv\Scripts\Activate.ps1
```
Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/model_server.txt
```
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
@ -108,26 +113,24 @@ docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational
(index refers to Vespa and relational_db refers to Postgres)
#### Running Danswer
Setup a folder to store config. Navigate to `danswer/backend` and run:
```bash
mkdir dynamic_config_storage
```
To start the frontend, navigate to `danswer/web` and run:
```bash
npm run dev
```
Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally.
Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
Next, start the model server which runs the local NLP models.
Navigate to `danswer/backend` and run:
```bash
zip -r ../vespa-app.zip .
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
```
- Note: If you don't have the `zip` utility, you will need to install it prior to running the above
The first time running Danswer, you will also need to run the DB migrations for Postgres.
The first time running Danswer, you will need to run the DB migrations for Postgres.
After the first time, this is no longer required unless the DB models change.
Navigate to `danswer/backend` and with the venv active, run:
@ -145,17 +148,12 @@ python ./scripts/dev_run_background_jobs.py
To run the backend API server, navigate back to `danswer/backend` and run:
```bash
AUTH_TYPE=disabled \
DYNAMIC_CONFIG_DIR_PATH=./dynamic_config_storage \
VESPA_DEPLOYMENT_ZIP=./danswer/document_index/vespa/vespa-app.zip \
uvicorn danswer.main:app --reload --port 8080
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
$env:AUTH_TYPE='disabled'
$env:DYNAMIC_CONFIG_DIR_PATH='./dynamic_config_storage'
$env:VESPA_DEPLOYMENT_ZIP='./danswer/document_index/vespa/vespa-app.zip'
uvicorn danswer.main:app --reload --port 8080
"
```
@ -174,20 +172,16 @@ pre-commit install
Additionally, we use `mypy` for static type checking.
Danswer is fully type-annotated, and we would like to keep it that way!
There is currently no automated type checking (coming soon), but we ask you to manually run it before
creating a pull request with `python -m mypy .` from the `danswer/backend` directory.
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
#### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `danswer/web` directory.
To run the formatter, use `npx prettier --write .` from the `danswer/web` directory.
Like `mypy`, we have no automated formatting yet (coming soon), but we request that, for now,
you run this manually before creating a pull request.
Please double check that prettier passes before creating a pull request.
### Release Process
Danswer follows the semver versioning standard.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
As pre-1.0 software, even patch releases may contain breaking or non-backwards-compatible changes.

View File

@ -5,7 +5,7 @@
</h2>
<p align="center">
<p align="center">Open Source Unified Search and Gen-AI Chat with your Docs.</p>
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
<p align="center">
<a href="https://docs.danswer.dev/" target="_blank">
@ -22,16 +22,17 @@
</a>
</p>
<strong>[Danswer](https://www.danswer.ai/)</strong> lets you ask questions in natural language and get back
answers based on your team-specific documents. Think ChatGPT if it had access to your team's unique
knowledge. Connects to all common workplace tools such as Slack, Google Drive, Confluence, etc.
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and at any
scale - on a laptop, on-premise, or in the cloud. Since you own the deployment, your user data and chats are fully in your
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring Personas (AI Assistants) and their Prompts.
Teams have used Danswer to:
- Speed up customer support and escalation turnaround time.
- Improve Engineering efficiency by making documentation and code changelogs easy to find.
- Let sales teams get fuller context faster when preparing for calls.
- Track customer requests and priorities for Product teams.
- Help teams self-serve IT, Onboarding, HR, etc.
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team-specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Usage</h3>
@ -57,19 +58,27 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
## 💃 Main Features
* Chat UI with the ability to select documents to chat with.
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Connect Danswer with the LLM of your choice (self-host for a fully airgapped solution).
* Document Search + AI Answers for natural language queries.
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
* Chat support (think ChatGPT but it has access to your private knowledge sources).
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Slack integration to get answers and search results directly in Slack.
## 🚧 Roadmap
* Chat/Prompt sharing with specific teammates and user groups.
* Multi-modal model support, chat with images, video, etc.
* Choosing between LLMs and parameters during a chat session.
* Tool calling and agent configuration options.
* Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Danswer
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.
* Custom deep learning models + learn from user feedback.
* Connect Danswer with LLM of your choice for a fully airgapped solution.
* Easy deployment and ability to host Danswer anywhere of your choosing.
@ -96,11 +105,5 @@ Efficiently pulls the latest changes from:
* Websites
* And more ...
## 🚧 Roadmap
* Organizational understanding.
* Ability to locate and suggest experts from your team.
* Code Search
* Structured Query Languages (SQL, Excel formulas, etc.)
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

backend/.gitignore
View File

@ -8,4 +8,4 @@ api_keys.py
.env
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule
celerybeat-schedule*

backend/.trivyignore (new file)
View File

@ -0,0 +1,46 @@
# https://github.com/madler/zlib/issues/868
# Pulled in with base Debian image, it's part of the contrib folder but unused
# zlib1g is fine
# Will be gone with Debian image upgrade
# No impact in our settings
CVE-2023-45853
# krb5 related, worst case is denial of service by resource exhaustion
# Accept the risk
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
# Specific to Firefox which we do not use
# No impact in our settings
CVE-2024-0743
# bind9 related, worst case is denial of service by CPU resource exhaustion
# Accept the risk
CVE-2023-50387
CVE-2023-50868
CVE-2023-50387
CVE-2023-50868
# libexpat1, XML parsing resource exhaustion
# We don't parse any user provided XMLs
# No impact in our settings
CVE-2023-52425
CVE-2024-28757
# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104
# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193

View File

@ -1,5 +1,10 @@
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
more details, visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
@ -12,7 +17,9 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# zip for Vespa step further down
# ca-certificates for HTTPS
RUN apt-get update && \
apt-get install -y cmake curl zip ca-certificates && \
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
@ -29,15 +36,25 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
libldap-2.5-0 libldap-2.5-0 && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('wordnet', quiet=True); \
nltk.download('punkt', quiet=True);"
# Set up application files
WORKDIR /app
COPY ./danswer /app/danswer
COPY ./shared_models /app/shared_models
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf

View File

@ -1,5 +1,11 @@
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is for the Danswer model server which runs all of the \
AI models for Danswer. This container and all the code is MIT Licensed and free for all to use. \
You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For more details, \
visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
@ -11,25 +17,26 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
WORKDIR /app
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
from huggingface_hub import snapshot_download; \
AutoTokenizer.from_pretrained('danswer/intent-model'); \
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download('danswer/intent-model'); \
snapshot_download('intfloat/e5-base-v2'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
# Needed for model configs and defaults
COPY ./danswer/configs /app/danswer/configs
COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs
WORKDIR /app
# Utils used by model server
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
# Place to fetch version information
COPY ./danswer/__init__.py /app/danswer/__init__.py
# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
# Request/Response models
COPY ./shared_models /app/shared_models
# Shared between Danswer Backend and Model Server
COPY ./shared_configs /app/shared_configs
# Model Server main code
COPY ./model_server /app/model_server

View File

@ -0,0 +1,31 @@
"""Add starter prompts
Revision ID: 0a2b51deb0b8
Revises: 5f4b8568a221
Create Date: 2024-03-02 23:23:49.960309
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "0a2b51deb0b8"
down_revision = "5f4b8568a221"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column(
"starter_messages",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("persona", "starter_messages")

View File

@ -0,0 +1,113 @@
"""Enable Encrypted Fields
Revision ID: 0a98909f2757
Revises: 570282d33c49
Create Date: 2024-05-05 19:30:34.317972
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.dialects import postgresql
import json
from danswer.utils.encryption import encrypt_string_to_bytes
# revision identifiers, used by Alembic.
revision = "0a98909f2757"
down_revision = "570282d33c49"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
connection = op.get_bind()
op.alter_column("key_value_store", "value", nullable=True)
op.add_column(
"key_value_store",
sa.Column(
"encrypted_value",
sa.LargeBinary,
nullable=True,
),
)
# Need a temporary column to translate the JSONB to binary
op.add_column("credential", sa.Column("temp_column", sa.LargeBinary()))
creds_table = table(
"credential",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"credential_json",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column(
"temp_column",
sa.LargeBinary(),
nullable=False,
),
)
results = connection.execute(sa.select(creds_table))
# This uses the MIT encrypt, which does not actually encrypt the credentials
# In other words, this upgrade does not apply the encryption. Porting existing sensitive data
# and key rotation are not currently supported and will come in a future release
for row_id, creds, _ in results:
creds_binary = encrypt_string_to_bytes(json.dumps(creds))
connection.execute(
creds_table.update()
.where(creds_table.c.id == row_id)
.values(temp_column=creds_binary)
)
op.drop_column("credential", "credential_json")
op.alter_column("credential", "temp_column", new_column_name="credential_json")
op.add_column("llm_provider", sa.Column("temp_column", sa.LargeBinary()))
llm_table = table(
"llm_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"api_key",
sa.String(),
nullable=False,
),
sa.Column(
"temp_column",
sa.LargeBinary(),
nullable=False,
),
)
results = connection.execute(sa.select(llm_table))
for row_id, api_key, _ in results:
llm_key = encrypt_string_to_bytes(api_key)
connection.execute(
llm_table.update()
.where(llm_table.c.id == row_id)
.values(temp_column=llm_key)
)
op.drop_column("llm_provider", "api_key")
op.alter_column("llm_provider", "temp_column", new_column_name="api_key")
def downgrade() -> None:
# Some information loss but this is ok. Should not allow decryption via downgrade.
op.drop_column("credential", "credential_json")
op.drop_column("llm_provider", "api_key")
op.add_column("llm_provider", sa.Column("api_key", sa.String()))
op.add_column(
"credential",
sa.Column("credential_json", postgresql.JSONB(astext_type=sa.Text())),
)
op.execute("DELETE FROM key_value_store WHERE value IS NULL")
op.alter_column("key_value_store", "value", nullable=False)
op.drop_column("key_value_store", "encrypted_value")

View File

@ -13,8 +13,8 @@ from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "15326fcec57e"
down_revision = "77d07dffae64"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,29 @@
"""Port Config Store
Revision ID: 173cae5bba26
Revises: e50154680a5c
Create Date: 2024-03-19 15:30:44.425436
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "173cae5bba26"
down_revision = "e50154680a5c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"key_value_store",
sa.Column("key", sa.String(), nullable=False),
sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.PrimaryKeyConstraint("key"),
)
def downgrade() -> None:
op.drop_table("key_value_store")

View File

@ -13,8 +13,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "2666d766cb9b"
down_revision = "6d387b3196c2"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "27c6ecc08586"
down_revision = "2666d766cb9b"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "30c1d5744104"
down_revision = "7f99be1cb9f5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,45 @@
"""Add tool table
Revision ID: 3879338f8ba1
Revises: f1c6478c3fd8
Create Date: 2024-05-11 16:11:23.718084
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3879338f8ba1"
down_revision = "f1c6478c3fd8"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"tool",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("in_code_tool_id", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"persona__tool",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("tool_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["tool_id"],
["tool.id"],
),
sa.PrimaryKeyConstraint("persona_id", "tool_id"),
)
def downgrade() -> None:
op.drop_table("persona__tool")
op.drop_table("tool")

View File

@ -0,0 +1,41 @@
"""Add chat session sharing
Revision ID: 38eda64af7fe
Revises: 776b3bbe9092
Create Date: 2024-03-27 19:41:29.073594
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "38eda64af7fe"
down_revision = "776b3bbe9092"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_session",
sa.Column(
"shared_status",
sa.Enum(
"PUBLIC",
"PRIVATE",
name="chatsessionsharedstatus",
native_enum=False,
),
nullable=True,
),
)
op.execute("UPDATE chat_session SET shared_status='PRIVATE'")
op.alter_column(
"chat_session",
"shared_status",
nullable=False,
)
def downgrade() -> None:
op.drop_column("chat_session", "shared_status")

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3b25685ff73c"
down_revision = "e0a68a81d434"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "3c5e35aa9af0"
down_revision = "27c6ecc08586"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,49 @@
"""Add tables for UI-based LLM configuration
Revision ID: 401c1ac29467
Revises: 703313b75876
Create Date: 2024-04-13 18:07:29.153817
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "401c1ac29467"
down_revision = "703313b75876"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("api_base", sa.String(), nullable=True),
sa.Column("api_version", sa.String(), nullable=True),
sa.Column(
"custom_config",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
sa.Column("default_model_name", sa.String(), nullable=False),
sa.Column("fast_default_model_name", sa.String(), nullable=True),
sa.Column("is_default_provider", sa.Boolean(), unique=True, nullable=True),
sa.Column("model_names", postgresql.ARRAY(sa.String()), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.add_column(
"persona",
sa.Column("llm_model_provider_override", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "llm_model_provider_override")
op.drop_table("llm_provider")

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "465f78d9b7f9"
down_revision = "3c5e35aa9af0"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ from sqlalchemy import String
# revision identifiers, used by Alembic.
revision = "46625e4745d4"
down_revision = "9d97fecfab7f"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,28 @@
"""PG File Store
Revision ID: 4738e4b3bae1
Revises: e91df4e935ef
Create Date: 2024-03-20 18:53:32.461518
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4738e4b3bae1"
down_revision = "e91df4e935ef"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"file_store",
sa.Column("file_name", sa.String(), nullable=False),
sa.Column("lobj_oid", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("file_name"),
)
def downgrade() -> None:
op.drop_table("file_store")

View File

@ -11,9 +11,9 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "47433d30de82"
down_revision = None
branch_labels = None
depends_on = None
down_revision: None = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,23 @@
"""Add name to api_key
Revision ID: 475fcefe8826
Revises: ecab2b3f1a3b
Create Date: 2024-04-11 11:05:18.414438
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "475fcefe8826"
down_revision = "ecab2b3f1a3b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("api_key", sa.Column("name", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("api_key", "name")

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "50b683a8295c"
down_revision = "7da0ae5ad583"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,27 @@
"""Track Danswerbot Explicitly
Revision ID: 570282d33c49
Revises: 7547d982db8f
Create Date: 2024-05-04 17:49:28.568109
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "570282d33c49"
down_revision = "7547d982db8f"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("danswerbot_flow", sa.Boolean(), nullable=True)
)
op.execute("UPDATE chat_session SET danswerbot_flow = one_shot")
op.alter_column("chat_session", "danswerbot_flow", nullable=False)
def downgrade() -> None:
op.drop_column("chat_session", "danswerbot_flow")

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "57b53544726e"
down_revision = "800f48024ae9"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5809c0787398"
down_revision = "d929f0c1c6af"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5e84129c8be3"
down_revision = "e6a4bbc13fe4"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,27 @@
"""add removed documents to index_attempt
Revision ID: 5f4b8568a221
Revises: dbaa756c2ccf
Create Date: 2024-02-16 15:02:03.319907
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5f4b8568a221"
down_revision = "8987770549c0"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"index_attempt",
sa.Column("docs_removed_from_index", sa.Integer()),
)
op.execute("UPDATE index_attempt SET docs_removed_from_index = 0")
def downgrade() -> None:
op.drop_column("index_attempt", "docs_removed_from_index")

View File

@ -0,0 +1,45 @@
"""Add user-configured names to LLMProvider
Revision ID: 643a84a42a33
Revises: 0a98909f2757
Create Date: 2024-05-07 14:54:55.493100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "643a84a42a33"
down_revision = "0a98909f2757"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("llm_provider", sa.Column("provider", sa.String(), nullable=True))
# move "name" -> "provider" to match the new schema
op.execute("UPDATE llm_provider SET provider = name")
# pretty up display name
op.execute("UPDATE llm_provider SET name = 'OpenAI' WHERE name = 'openai'")
op.execute("UPDATE llm_provider SET name = 'Anthropic' WHERE name = 'anthropic'")
op.execute("UPDATE llm_provider SET name = 'Azure OpenAI' WHERE name = 'azure'")
op.execute("UPDATE llm_provider SET name = 'AWS Bedrock' WHERE name = 'bedrock'")
# update personas to use the new provider names
op.execute(
"UPDATE persona SET llm_model_provider_override = 'OpenAI' WHERE llm_model_provider_override = 'openai'"
)
op.execute(
"UPDATE persona SET llm_model_provider_override = 'Anthropic' WHERE llm_model_provider_override = 'anthropic'"
)
op.execute(
"UPDATE persona SET llm_model_provider_override = 'Azure OpenAI' WHERE llm_model_provider_override = 'azure'"
)
op.execute(
"UPDATE persona SET llm_model_provider_override = 'AWS Bedrock' WHERE llm_model_provider_override = 'bedrock'"
)
def downgrade() -> None:
op.execute("UPDATE llm_provider SET name = provider")
op.drop_column("llm_provider", "provider")

View File

@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "6d387b3196c2"
down_revision = "47433d30de82"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,83 @@
"""Add TokenRateLimit Tables
Revision ID: 703313b75876
Revises: fad14119fb92
Create Date: 2024-04-15 01:36:02.952809
"""
import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
down_revision = "fad14119fb92"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"token_rate_limit",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False),
sa.Column("token_budget", sa.Integer(), nullable=False),
sa.Column("period_hours", sa.Integer(), nullable=False),
sa.Column(
"scope",
sa.String(length=10),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"token_rate_limit__user_group",
sa.Column("rate_limit_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["rate_limit_id"],
["token_rate_limit.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
)
try:
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
token_budget = settings.get("token_budget", -1)
period_hours = settings.get("period_hours", -1)
if is_enabled and token_budget > 0 and period_hours > 0:
op.execute(
f"INSERT INTO token_rate_limit \
(enabled, token_budget, period_hours, scope) VALUES \
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
)
# Delete the dynamic config
get_dynamic_config_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found
pass
def downgrade() -> None:
op.drop_table("token_rate_limit__user_group")
op.drop_table("token_rate_limit")

View File

@ -0,0 +1,68 @@
"""More Descriptive Filestore
Revision ID: 70f00c45c0f2
Revises: 3879338f8ba1
Create Date: 2024-05-17 17:51:41.926893
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "70f00c45c0f2"
down_revision = "3879338f8ba1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True))
op.add_column(
"file_store",
sa.Column(
"file_origin",
sa.String(),
nullable=False,
server_default="connector", # Default to connector
),
)
op.add_column(
"file_store",
sa.Column(
"file_type", sa.String(), nullable=False, server_default="text/plain"
),
)
op.add_column(
"file_store",
sa.Column(
"file_metadata",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.execute(
"""
UPDATE file_store
SET file_origin = CASE
WHEN file_name LIKE 'chat__%' THEN 'chat_upload'
ELSE 'connector'
END,
file_name = CASE
WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7)
ELSE file_name
END,
file_type = CASE
WHEN file_name LIKE 'chat__%' THEN 'image/png'
ELSE 'text/plain'
END
"""
)
def downgrade() -> None:
op.drop_column("file_store", "file_metadata")
op.drop_column("file_store", "file_type")
op.drop_column("file_store", "file_origin")
op.drop_column("file_store", "display_name")

View File

@ -0,0 +1,81 @@
"""Permission Auto Sync Framework
Revision ID: 72bdc9929a46
Revises: 475fcefe8826
Create Date: 2024-04-14 21:15:28.659634
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "72bdc9929a46"
down_revision = "475fcefe8826"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"email_to_external_user_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_user_id", sa.String(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"external_permission",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("external_permission_group", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"permission_sync_run",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("update_type", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.String(),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("permission_sync_run")
op.drop_table("external_permission")
op.drop_table("email_to_external_user_cache")

View File

@ -0,0 +1,51 @@
"""Chat Folders
Revision ID: 7547d982db8f
Revises: ef7da92f7213
Create Date: 2024-05-02 15:18:56.573347
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "7547d982db8f"
down_revision = "ef7da92f7213"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"chat_folder",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("name", sa.String(), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("chat_session", sa.Column("folder_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"chat_session_chat_folder_fk",
"chat_session",
"chat_folder",
["folder_id"],
["id"],
)
def downgrade() -> None:
op.drop_constraint(
"chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
)
op.drop_column("chat_session", "folder_id")
op.drop_table("chat_folder")

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "767f1c2a00eb"
down_revision = "dba7f71618f5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "76b60d407dfb"
down_revision = "b156fa702355"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,71 @@
"""Remove Remaining Enums
Revision ID: 776b3bbe9092
Revises: 4738e4b3bae1
Create Date: 2024-03-22 21:34:27.629444
"""
from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"
down_revision = "4738e4b3bae1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.alter_column(
"persona",
"search_type",
type_=sa.String,
existing_type=sa.Enum(SearchType, native_enum=False),
existing_nullable=False,
)
op.alter_column(
"persona",
"recency_bias",
type_=sa.String,
existing_type=sa.Enum(RecencyBiasSetting, native_enum=False),
existing_nullable=False,
)
# Because the indexmodelstatus enum does not have a mapping to a string type
# we need this workaround instead of directly changing the type
op.add_column("embedding_model", sa.Column("temp_status", sa.String))
op.execute("UPDATE embedding_model SET temp_status = status::text")
op.drop_column("embedding_model", "status")
op.alter_column("embedding_model", "temp_status", new_column_name="status")
op.execute("DROP TYPE IF EXISTS searchtype")
op.execute("DROP TYPE IF EXISTS recencybiassetting")
op.execute("DROP TYPE IF EXISTS indexmodelstatus")
def downgrade() -> None:
op.alter_column(
"persona",
"search_type",
type_=sa.Enum(SearchType, native_enum=False),
existing_type=sa.String(length=50),
existing_nullable=False,
)
op.alter_column(
"persona",
"recency_bias",
type_=sa.Enum(RecencyBiasSetting, native_enum=False),
existing_type=sa.String(length=50),
existing_nullable=False,
)
op.alter_column(
"embedding_model",
"status",
type_=sa.Enum(IndexModelStatus, native_enum=False),
existing_type=sa.String(length=50),
existing_nullable=False,
)

View File

@ -12,8 +12,8 @@ from sqlalchemy import String
# revision identifiers, used by Alembic.
revision = "77d07dffae64"
down_revision = "d61e513bef0a"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "78dbe7e38469"
down_revision = "7ccea01261f6"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "79acd316403a"
down_revision = "904e5138fffb"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "7ccea01261f6"
down_revision = "a570b80a5f20"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7da0ae5ad583"
down_revision = "e86866a9c78a"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "7da543f5672f"
down_revision = "febe9eaa0644"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7f726bad5367"
down_revision = "79acd316403a"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "7f99be1cb9f5"
down_revision = "78dbe7e38469"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.schema import Sequence, CreateSequence
# revision identifiers, used by Alembic.
revision = "800f48024ae9"
down_revision = "767f1c2a00eb"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "80696cf850ae"
down_revision = "15326fcec57e"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "891cd83c87a8"
down_revision = "76b60d407dfb"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8987770549c0"
down_revision = "ec3ec2eabf7b"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "8aabb57f3b49"
down_revision = "5e84129c8be3"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e26726b7683"
down_revision = "5809c0787398"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "904451035c9b"
down_revision = "3b25685ff73c"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "904e5138fffb"
down_revision = "891cd83c87a8"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,36 @@
"""Remove DocumentSource from Tag
Revision ID: 91fd3b470d1a
Revises: 173cae5bba26
Create Date: 2024-03-21 12:05:23.956734
"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "91fd3b470d1a"
down_revision = "173cae5bba26"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.alter_column(
"tag",
"source",
type_=sa.String(length=50),
existing_type=sa.Enum(DocumentSource, native_enum=False),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"tag",
"source",
type_=sa.Enum(DocumentSource, native_enum=False),
existing_type=sa.String(length=50),
existing_nullable=False,
)

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "9d97fecfab7f"
down_revision = "ffc707a226b4"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,27 @@
"""Add chosen_assistants to User table
Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
def downgrade() -> None:
op.drop_column("user", "chosen_assistants")

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a570b80a5f20"
down_revision = "904451035c9b"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ae62505e3acc"
down_revision = "7da543f5672f"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b082fec533f0"
down_revision = "df0c7ad8a076"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -15,8 +15,8 @@ from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "b156fa702355"
down_revision = "baf71f781b9e"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
searchtype_enum = ENUM(

View File

@ -0,0 +1,28 @@
"""fix-file-type-migration
Revision ID: b85f02ec1308
Revises: a3bfd0d64902
Create Date: 2024-05-31 18:09:26.658164
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b85f02ec1308"
down_revision = "a3bfd0d64902"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE file_store
SET file_origin = UPPER(file_origin)
"""
)
def downgrade() -> None:
# Let's not break anything on purpose :)
pass

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "baf71f781b9e"
down_revision = "50b683a8295c"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d5645c915d0e"
down_revision = "8e26726b7683"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d61e513bef0a"
down_revision = "46625e4745d4"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "d7111c1238cd"
down_revision = "465f78d9b7f9"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d929f0c1c6af"
down_revision = "8aabb57f3b49"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dba7f71618f5"
down_revision = "d5645c915d0e"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -9,18 +9,18 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.db.embedding_model import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
)
from danswer.db.models import IndexModelStatus
# revision identifiers, used by Alembic.
revision = "dbaa756c2ccf"
down_revision = "7f726bad5367"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
@ -40,6 +40,9 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
# since all index attempts must be associated with an embedding model,
# need to put something in here to avoid nulls. On server startup,
# this value will be overridden
EmbeddingModel = table(
"embedding_model",
column("id", Integer),
@ -53,20 +56,44 @@ def upgrade() -> None:
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
),
)
# insert an embedding model row that corresponds to the embedding model
# the user selected via env variables before this change. This is needed since
# all index_attempts must be associated with an embedding model, so without this
# we will run into violations of non-null constraints
old_embedding_model = get_old_default_embedding_model()
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": DOCUMENT_ENCODER_MODEL,
"model_dim": DOC_EMBEDDING_DIM,
"normalize": NORMALIZE_EMBEDDINGS,
"query_prefix": ASYM_QUERY_PREFIX,
"passage_prefix": ASYM_PASSAGE_PREFIX,
"index_name": "danswer_chunk",
"status": IndexModelStatus.PRESENT,
"model_name": old_embedding_model.model_name,
"model_dim": old_embedding_model.model_dim,
"normalize": old_embedding_model.normalize,
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": old_embedding_model.status,
}
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model(is_present=False)
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": new_embedding_model.model_name,
"model_dim": new_embedding_model.model_dim,
"normalize": new_embedding_model.normalize,
"query_prefix": new_embedding_model.query_prefix,
"passage_prefix": new_embedding_model.passage_prefix,
"index_name": new_embedding_model.index_name,
"status": IndexModelStatus.FUTURE,
}
],
)
op.add_column(
"index_attempt",
sa.Column("embedding_model_id", sa.Integer(), nullable=True),

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "df0c7ad8a076"
down_revision = "d7111c1238cd"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e0a68a81d434"
down_revision = "ae62505e3acc"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,38 @@
"""No Source Enum
Revision ID: e50154680a5c
Revises: fcd135795f21
Create Date: 2024-03-14 18:06:08.523106
"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "e50154680a5c"
down_revision = "fcd135795f21"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.alter_column(
"search_doc",
"source_type",
type_=sa.String(length=50),
existing_type=sa.Enum(DocumentSource, native_enum=False),
existing_nullable=False,
)
op.execute("DROP TYPE IF EXISTS documentsource")
def downgrade() -> None:
op.alter_column(
"search_doc",
"source_type",
type_=sa.Enum(DocumentSource, native_enum=False),
existing_type=sa.String(length=50),
existing_nullable=False,
)

View File

@ -11,8 +11,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "e6a4bbc13fe4"
down_revision = "b082fec533f0"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e86866a9c78a"
down_revision = "80696cf850ae"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,118 @@
"""Private Personas DocumentSets
Revision ID: e91df4e935ef
Revises: 91fd3b470d1a
Create Date: 2024-03-17 11:47:24.675881
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e91df4e935ef"
down_revision = "91fd3b470d1a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"document_set__user",
sa.Column("document_set_id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.ForeignKeyConstraint(
["document_set_id"],
["document_set.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("document_set_id", "user_id"),
)
op.create_table(
"persona__user",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("persona_id", "user_id"),
)
op.create_table(
"document_set__user_group",
sa.Column("document_set_id", sa.Integer(), nullable=False),
sa.Column(
"user_group_id",
sa.Integer(),
nullable=False,
),
sa.ForeignKeyConstraint(
["document_set_id"],
["document_set.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("document_set_id", "user_group_id"),
)
op.create_table(
"persona__user_group",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column(
"user_group_id",
sa.Integer(),
nullable=False,
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("persona_id", "user_group_id"),
)
op.add_column(
"document_set",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
# fill in is_public for existing rows
op.execute("UPDATE document_set SET is_public = true WHERE is_public IS NULL")
op.alter_column("document_set", "is_public", nullable=False)
op.add_column(
"persona",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
# fill in is_public for existing rows
op.execute("UPDATE persona SET is_public = true WHERE is_public IS NULL")
op.alter_column("persona", "is_public", nullable=False)
def downgrade() -> None:
op.drop_column("persona", "is_public")
op.drop_column("document_set", "is_public")
op.drop_table("persona__user")
op.drop_table("document_set__user")
op.drop_table("persona__user_group")
op.drop_table("document_set__user_group")

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ec3ec2eabf7b"
down_revision = "dbaa756c2ccf"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,31 @@
"""Remove Last Attempt Status from CC Pair
Revision ID: ec85f2b3c544
Revises: 3879338f8ba1
Create Date: 2024-05-23 21:39:46.126010
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ec85f2b3c544"
down_revision = "70f00c45c0f2"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("connector_credential_pair", "last_attempt_status")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_attempt_status",
sa.VARCHAR(),
autoincrement=False,
nullable=True,
),
)

View File

@ -0,0 +1,40 @@
"""Add overrides to the chat session
Revision ID: ecab2b3f1a3b
Revises: 38eda64af7fe
Create Date: 2024-04-01 19:08:21.359102
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "ecab2b3f1a3b"
down_revision = "38eda64af7fe"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_session",
sa.Column(
"llm_override",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"chat_session",
sa.Column(
"prompt_override",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_session", "prompt_override")
op.drop_column("chat_session", "llm_override")

View File

@ -0,0 +1,27 @@
"""Add files to ChatMessage
Revision ID: ef7da92f7213
Revises: 401c1ac29467
Create Date: 2024-04-28 16:59:33.199153
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "ef7da92f7213"
down_revision = "401c1ac29467"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "files")

View File

@ -0,0 +1,25 @@
"""Add pre-defined feedback
Revision ID: f1c6478c3fd8
Revises: 643a84a42a33
Create Date: 2024-05-09 18:11:49.210667
"""
from alembic import op
import sqlalchemy as sa
revision = "f1c6478c3fd8"
down_revision = "643a84a42a33"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_feedback",
sa.Column("predefined_feedback", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_feedback", "predefined_feedback")

View File

@ -0,0 +1,39 @@
"""Delete Tags with wrong Enum
Revision ID: fad14119fb92
Revises: 72bdc9929a46
Create Date: 2024-04-25 17:05:09.695703
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "fad14119fb92"
down_revision = "72bdc9929a46"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Some documents may lose their tags, but this is the only way, as the enum
# mapping may have changed since the tag switched to a string (it will be reindexed anyway)
op.execute(
"""
DELETE FROM document__tag
WHERE tag_id IN (
SELECT id FROM tag
WHERE source ~ '^[0-9]+$'
)
"""
)
op.execute(
"""
DELETE FROM tag
WHERE source ~ '^[0-9]+$'
"""
)
def downgrade() -> None:
pass

View File

@ -0,0 +1,39 @@
"""Add slack bot display type
Revision ID: fcd135795f21
Revises: 0a2b51deb0b8
Create Date: 2024-03-04 17:03:27.116284
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fcd135795f21"
down_revision = "0a2b51deb0b8"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"slack_bot_config",
sa.Column(
"response_type",
sa.Enum(
"QUOTES",
"CITATIONS",
name="slackbotresponsetype",
native_enum=False,
),
nullable=True,
),
)
op.execute(
"UPDATE slack_bot_config SET response_type = 'QUOTES' WHERE response_type IS NULL"
)
op.alter_column("slack_bot_config", "response_type", nullable=False)
def downgrade() -> None:
op.drop_column("slack_bot_config", "response_type")

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "febe9eaa0644"
down_revision = "57b53544726e"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "ffc707a226b4"
down_revision = "30c1d5744104"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any
from typing import cast
from danswer.auth.schemas import UserRole
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
return UserInfo(
id="__no_auth_user__",
email="anonymous@danswer.ai",
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
)
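A minimal usage sketch of the helpers above, assuming an in-memory stand-in for `DynamicConfigStore`. The `InMemoryStore` class and the `chosen_assistants` value are illustration-only assumptions; only the `store`/`load` calls the helpers actually make are implemented.

from typing import Any

from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.server.manage.models import UserPreferences


class InMemoryStore:
    """Hypothetical stand-in for a DynamicConfigStore backend (illustration only)."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def store(self, key: str, value: Any) -> None:
        self._data[key] = value

    def load(self, key: str) -> Any:
        try:
            return self._data[key]
        except KeyError:
            raise ConfigNotFoundError(key)


store = InMemoryStore()  # duck-typed in place of a real DynamicConfigStore
set_no_auth_user_preferences(store, UserPreferences(chosen_assistants=[1, 2]))
print(fetch_no_auth_user(store).preferences)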

View File

@ -23,8 +23,8 @@ from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserCreate
@ -33,15 +33,18 @@ from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
@ -69,6 +72,20 @@ def verify_auth_setting() -> None:
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def get_display_email(email: str | None, space_less: bool = False) -> str:
if email and email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN):
name = email.split("@")[0]
if name == DANSWER_API_KEY_PREFIX + UNNAMED_KEY_PLACEHOLDER:
return "Unnamed API Key"
if space_less:
return name
return name.replace("API_KEY__", "API Key: ")
return email or ""
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
@ -133,8 +150,8 @@ def send_user_verification_email(
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = SECRET
verification_token_secret = SECRET
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
async def create(
self,
@ -213,7 +230,10 @@ async def get_user_manager(
yield UserManager(user_db)
cookie_transport = CookieTransport(cookie_max_age=SESSION_EXPIRE_TIME_SECONDS)
cookie_transport = CookieTransport(
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
cookie_secure=WEB_DOMAIN.startswith("https"),
)
def get_database_strategy(
@ -276,13 +296,32 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
# take care of that in `double_check_user` ourselves. This is needed, since
# we want the /me endpoint to still return a user even if they are not
# yet verified, so that the frontend knows they exist
optional_valid_user = fastapi_users.current_user(active=True, optional=True)
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
async def double_check_user(
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
return user
async def optional_user(
request: Request,
user: User | None = Depends(optional_fastapi_current_user),
db_session: Session = Depends(get_session),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
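`fetch_versioned_implementation` lets the enterprise edition swap in its own `optional_user_` without the open-source code importing it directly. The helper itself is not shown in this diff; a rough sketch of how such a lookup could be built with importlib follows (the `ee.` module prefix is purely an assumption for illustration):

import importlib
from typing import Any


def fetch_versioned_implementation_sketch(module: str, attribute: str) -> Any:
    # try a hypothetical enterprise namespace first, then fall back to the
    # open-source module itself
    for candidate_module in (f"ee.{module}", module):
        try:
            return getattr(importlib.import_module(candidate_module), attribute)
        except (ModuleNotFoundError, AttributeError):
            continue
    raise RuntimeError(f"No implementation found for {module}.{attribute}")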
async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
) -> User | None:
if optional:
@ -304,15 +343,9 @@ async def double_check_user(
async def current_user(
request: Request,
user: User | None = Depends(optional_valid_user),
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_user),
) -> User | None:
double_check_user = fetch_versioned_implementation(
"danswer.auth.users", "double_check_user"
)
user = await double_check_user(request, user, db_session)
return user
return await double_check_user(user)
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:

View File

@ -1,6 +1,4 @@
import os
from datetime import timedelta
from pathlib import Path
from typing import cast
from celery import Celery # type: ignore
@ -10,16 +8,14 @@ from danswer.background.connector_deletion import delete_connector_credential_pa
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.file.utils import file_age_in_hours
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set
from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
@ -31,7 +27,6 @@ from danswer.db.tasks import get_latest_task
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -42,7 +37,7 @@ celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
_SYNC_BATCH_SIZE = 1000
_SYNC_BATCH_SIZE = 100
#####
@ -67,15 +62,18 @@ def cleanup_connector_credential_pair_task(
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
if not cc_pair:
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, db_session=db_session
)
if deletion_attempt_disallowed_reason:
raise ValueError(deletion_attempt_disallowed_reason)
try:
# The bulk of the work is in here, updates Postgres and Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
@ -98,15 +96,13 @@ def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
def _sync_document_batch(document_ids: list[str]) -> None:
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
)
# Acquires a lock on the documents so that no other process can modify them
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
@ -131,17 +127,21 @@ def sync_document_set_task(document_set_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
try:
documents_to_update = fetch_documents_for_document_set(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
)
for document_batch in batch_generator(
documents_to_update, _SYNC_BATCH_SIZE
):
_sync_document_batch(
document_ids=[document.id for document in document_batch]
cursor = None
while True:
document_batch, cursor = fetch_documents_for_document_set_paginated(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
last_document_id=cursor,
limit=_SYNC_BATCH_SIZE,
)
_sync_document_batch(
document_ids=[document.id for document in document_batch],
db_session=db_session,
)
if cursor is None:
break
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
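The loop above swaps the old load-everything-then-batch approach for cursor-based pagination: each call returns a batch plus a cursor, and the loop stops once the cursor comes back as None. The real query lives in `danswer.db.document_set` and is not shown in this diff; a hedged sketch of a keyset-paginated fetch of the same shape (document-set filtering omitted, `Document` model assumed) might look like:

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import Document  # assumed ORM model with a string primary key


def fetch_documents_paginated_sketch(
    db_session: Session, last_document_id: str | None, limit: int
) -> tuple[list[Document], str | None]:
    stmt = select(Document).order_by(Document.id).limit(limit)
    if last_document_id is not None:
        # keyset pagination: resume strictly after the last id already seen
        stmt = stmt.where(Document.id > last_document_id)
    docs = list(db_session.scalars(stmt))
    # only hand back a cursor if this page was full, i.e. more rows may remain
    cursor = docs[-1].id if len(docs) == limit else None
    return docs, cursor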
@ -182,7 +182,7 @@ def check_for_document_sets_sync_task() -> None:
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
db_session=db_session, include_outdated=True
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
@ -203,21 +203,6 @@ def check_for_document_sets_sync_task() -> None:
)
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
def clean_old_temp_files_task(
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
) -> None:
"""Files added via the File connector need to be deleted after ingestion
Currently handled asynchronously from the indexing job"""
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
full_file_path = Path(base_path) / file
if file_age_in_hours(full_file_path) > age_threshold_in_hours:
logger.info(f"Cleaning up uploaded file: {full_file_path}")
os.remove(full_file_path)
#####
# Celery Beat (Periodic Tasks) Settings
#####
@ -226,8 +211,4 @@ celery_app.conf.beat_schedule = {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"clean-old-temp-files": {
"task": "clean_old_temp_files_task",
"schedule": timedelta(minutes=30),
},
}

View File

@ -19,8 +19,8 @@ from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.document import delete_document_by_connector_credential_pair
from danswer.db.document import delete_documents_complete
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
@ -47,60 +47,65 @@ def _delete_connector_credential_pair_batch(
credential_id: int,
document_index: DocumentIndex,
) -> None:
"""
Removes a batch of document IDs from a cc-pair. If no other cc-pair uses a document anymore,
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
document_connector_cnts = get_document_connector_cnts(
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
):
document_connector_cnts = get_document_connector_cnts(
db_session=db_session, document_ids=document_ids
)
for document_id, access in access_for_documents.items()
]
logger.debug(f"Updating documents: {document_ids_to_update}")
document_index.update(update_requests=update_requests)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
delete_document_by_connector_credential_pair(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
)
for document_id, access in access_for_documents.items()
]
logger.debug(f"Updating documents: {document_ids_to_update}")
document_index.update(update_requests=update_requests)
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()
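`prepare_to_modify_documents` is now used as a context manager, so the document locks are scoped to the block instead of lapsing at some later commit. Its implementation is not part of this diff; a minimal sketch of a row-locking context manager of that kind, assuming the `Document` ORM model, could be:

from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import Document  # assumed ORM model


@contextmanager
def lock_documents_sketch(
    db_session: Session, document_ids: list[str]
) -> Iterator[None]:
    # SELECT ... FOR UPDATE takes row locks so no concurrent transaction can
    # modify these documents until this transaction ends
    db_session.execute(
        select(Document.id).where(Document.id.in_(document_ids)).with_for_update()
    )
    # the locks are released when the caller commits or rolls back; the context
    # manager exists to make the locked scope explicit
    yield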
def cleanup_synced_entities(

View File

@ -6,16 +6,16 @@ NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing import Process
from typing import Any
from typing import Literal
from typing import Optional
from torch import multiprocessing
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger
logger = setup_logger()
JobStatusType = (
Literal["error"]
| Literal["finished"]
@ -25,12 +25,28 @@ JobStatusType = (
)
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Ensure the parent proc's database connections are not touched
in the new connection pool.
Based on the recommended approach in the SQLAlchemy docs, found here:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
get_sqlalchemy_engine().dispose(close=False)
return func(*args, **kwargs)
@dataclass
class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
id: int
process: multiprocessing.Process | None = None
process: Optional["Process"] = None
def cancel(self) -> bool:
return self.release()
@ -95,7 +111,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = multiprocessing.Process(target=func, args=args, daemon=True)
process = Process(target=_initializer(func=func, args=args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()
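The `_initializer` wrapper follows the SQLAlchemy multiprocessing guidance linked in its docstring: a forked child must dispose of the connection pool it inherited before touching the database, so it never reuses the parent's sockets. A standalone sketch of that pattern, with a placeholder connection URL that is not project configuration:

from multiprocessing import Process

from sqlalchemy import create_engine, text

# placeholder URL for illustration only
engine = create_engine("postgresql+psycopg2://localhost/example")


def child_work() -> None:
    # drop the connections inherited from the parent's pool without closing
    # them out from under the parent (close=False)
    engine.dispose(close=False)
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))


if __name__ == "__main__":
    proc = Process(target=child_work, daemon=True)
    proc.start()
    proc.join()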

View File

@ -4,10 +4,13 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
import torch
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import (
_delete_connector_credential_pair_batch,
)
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
@ -19,10 +22,11 @@ from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@ -42,8 +46,14 @@ def _get_document_generator(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> GenerateDocumentsOutput:
"""NOTE: `start_time` and `end_time` are only used for poll connectors"""
) -> tuple[GenerateDocumentsOutput, bool]:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
Returns an iterator of document batches and whether the returned documents
are the complete list of existing documents of the connector. If the task is
of type LOAD_STATE, the list will be considered complete; otherwise, incomplete.
"""
task = attempt.connector.input_type
try:
@ -65,7 +75,7 @@ def _get_document_generator(
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
is_listing_complete = True
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
@ -78,12 +88,13 @@ def _get_document_generator(
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
is_listing_complete = False
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator
return doc_batch_generator, is_listing_complete
def _run_indexing(
@ -104,16 +115,6 @@ def _run_indexing(
# Secondary index syncs at the end when swapping
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
# Mark as started
mark_attempt_in_progress(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
@ -131,6 +132,7 @@ def _run_indexing(
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (db_embedding_model.status == IndexModelStatus.FUTURE),
db_session=db_session,
)
db_connector = index_attempt.connector
@ -158,19 +160,20 @@ def _run_indexing(
source_type=db_connector.source,
)
):
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
doc_batch_generator, is_listing_complete = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)
all_connector_doc_ids: set[str] = set()
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
@ -186,6 +189,7 @@ def _run_indexing(
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")
logger.debug(
@ -202,6 +206,7 @@ def _run_indexing(
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
all_connector_doc_ids.update(doc.id for doc in doc_batch)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@ -216,6 +221,40 @@ def _run_indexing(
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
# clean up all documents from the index that have not been returned from the connector
all_indexed_document_ids = {
d.id
for d in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
)
}
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids
)
logger.debug(
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
)
# delete docs from cc-pair and receive the number of completely deleted docs in return
_delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=db_connector.id,
credential_id=db_credential.id,
document_index=document_index,
)
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=len(doc_ids_to_remove),
)
run_end_dt = window_end
@ -224,7 +263,6 @@ def _run_indexing(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
@ -255,7 +293,6 @@ def _run_indexing(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
)
raise e
@ -270,7 +307,6 @@ def _run_indexing(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
run_dt=run_end_dt,
)
@ -282,7 +318,35 @@ def _run_indexing(
)
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress__no_commit(attempt)
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
db_session.commit()
return attempt
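Because the isolation level is applied per transaction, two workers racing to claim the same attempt cannot both move it from NOT_STARTED to IN_PROGRESS: the second worker's commit fails with a serialization error rather than silently double-claiming. An illustrative sketch of that claim-and-commit flow, reusing the `IndexAttempt` and `IndexingStatus` models from this diff (the sketch itself is not part of the commit):

from sqlalchemy.orm import Session

from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus


def try_claim_attempt_sketch(db_session: Session, index_attempt_id: int) -> bool:
    # applies only to the current transaction; discarded after the next commit
    db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"})
    attempt = db_session.get(IndexAttempt, index_attempt_id)
    if attempt is None or attempt.status != IndexingStatus.NOT_STARTED:
        return False
    attempt.status = IndexingStatus.IN_PROGRESS
    try:
        # a concurrent claimer's commit raises here instead of double-claiming
        db_session.commit()
    except Exception:
        db_session.rollback()
        return False
    return True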
def run_indexing_entrypoint(index_attempt_id: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
@ -291,17 +355,10 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
@ -309,10 +366,7 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
f"with credentials: '{attempt.credential_id}'"
)
_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
_run_indexing(db_session, attempt)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "

View File

@ -3,7 +3,6 @@ import time
from datetime import datetime
import dask
import torch
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
@ -15,21 +14,13 @@ from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@ -41,7 +32,12 @@ from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@ -54,22 +50,19 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
)
"""Util funcs"""
def _get_num_threads() -> int:
"""Get # of "threads" to use for ML models in an indexing job. By default uses
the torch implementation, which returns the # of physical cores on the machine.
"""
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
def _should_create_new_indexing(
connector: Connector,
last_index: IndexAttempt | None,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False
# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0: # Ingestion API
@ -124,17 +117,6 @@ def _mark_run_failed(
db_session=db_session,
failure_reason=failure_reason,
)
if (
index_attempt.connector_id is not None
and index_attempt.credential_id is not None
and index_attempt.embedding_model.status == IndexModelStatus.PRESENT
):
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_id,
credential_id=index_attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
)
"""Main funcs"""
@ -185,7 +167,11 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector, last_attempt, model, db_session
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
@ -193,16 +179,6 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
connector.id, credential.id, model.id, db_session
)
# CC-Pair will have the status that it should for the primary index
# Will be re-sync-ed once the indices are swapped
if model.status == IndexModelStatus.PRESENT:
update_connector_credential_pair(
db_session=db_session,
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
)
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
@ -254,6 +230,9 @@ def cleanup_indexing_jobs(
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to be frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
@ -328,12 +307,10 @@ def kickoff_indexing_jobs(
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
run_indexing_entrypoint, attempt.id, pure=False
)
else:
run = client.submit(
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
)
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
@ -348,49 +325,22 @@ def kickoff_indexing_jobs(
return existing_jobs_copy
def check_index_swap(db_session: Session) -> None:
"""Get count of cc-pairs and count of index_attempts for the new model grouped by
connector + credential; if the counts match, assume the new index is done building.
This does not take into consideration whether the attempt failed or not"""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = len(all_cc_pairs) - 1
embedding_model = get_secondary_db_embedding_model(db_session)
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
if not embedding_model:
return
unique_cc_indexings = count_unique_cc_pairs_with_index_attempts(
embedding_model_id=embedding_model.id, db_session=db_session
# So that first-time users aren't surprised by the really slow speed of the
# first batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if unique_cc_indexings > cc_pair_count:
raise RuntimeError("More unique indexings than cc pairs, should not occur")
if cc_pair_count == unique_cc_indexings:
# Swap indices
now_old_embedding_model = get_current_db_embedding_model(db_session)
update_embedding_model_status(
embedding_model=now_old_embedding_model,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_embedding_model_status(
embedding_model=embedding_model,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
@ -417,12 +367,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_secondary = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# Previous version did not always clean up cc-pairs well, leaving some connectors undeletable
# This ensures that bad states get cleaned up
mark_all_in_progress_cc_pairs_failed(db_session)
while True:
start = time.time()
@ -454,12 +398,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
def update__main() -> None:
# needed for CUDA to work with multiprocessing
# NOTE: needs to be done on application startup
# before any other torch code has been run
if not DASK_JOB_CLIENT_ENABLED:
torch.multiprocessing.set_start_method("spawn")
logger.info("Starting Indexing Loop")
update_loop()

View File

@ -1,168 +1,39 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from functools import lru_cache
from collections.abc import Sequence
from typing import cast
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
CONNECTOR_NAME_MAP = {
"web": "Website",
"requesttracker": "Request Tracker",
"github": "GitHub",
"file": "File Upload",
}
logger = setup_logger()
def clean_up_source(source_str: str) -> str:
if source_str in CONNECTOR_NAME_MAP:
return CONNECTOR_NAME_MAP[source_str]
return source_str.replace("_", " ").title()
def build_doc_context_str(
semantic_identifier: str,
source_type: DocumentSource,
content: str,
metadata_dict: dict[str, str | list[str]],
updated_at: datetime | None,
ind: int,
include_metadata: bool = True,
) -> str:
context_str = ""
if include_metadata:
context_str += f"DOCUMENT {ind}: {semantic_identifier}\n"
context_str += f"Source: {clean_up_source(source_type)}\n"
for k, v in metadata_dict.items():
if isinstance(v, list):
v_str = ", ".join(v)
context_str += f"{k.capitalize()}: {v_str}\n"
else:
context_str += f"{k.capitalize()}: {v}\n"
if updated_at:
update_str = updated_at.strftime("%B %d, %Y %H:%M")
context_str += f"Updated: {update_str}\n"
context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n"
return context_str
def build_complete_context_str(
context_docs: list[LlmDoc | InferenceChunk],
include_metadata: bool = True,
) -> str:
context_str = ""
for ind, doc in enumerate(context_docs, start=1):
context_str += build_doc_context_str(
semantic_identifier=doc.semantic_identifier,
source_type=doc.source_type,
content=doc.content,
metadata_dict=doc.metadata,
updated_at=doc.updated_at,
ind=ind,
include_metadata=include_metadata,
)
return context_str.strip()
@lru_cache()
def build_chat_system_message(
prompt: Prompt,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if system_prompt:
system_prompt += (
f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}."
)
else:
system_prompt = get_current_llm_day_time()
if not system_prompt:
return None, 0
token_count = len(llm_tokenizer_encode_func(system_prompt))
system_msg = SystemMessage(content=system_prompt)
return system_msg, token_count
def build_task_prompt_reminders(
prompt: Prompt,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
citation_str: str = CITATION_REMINDER,
language_hint_str: str = LANGUAGE_HINT,
) -> str:
base_task = prompt.task_prompt
citation_or_nothing = citation_str if prompt.include_citations else ""
language_hint_or_nothing = language_hint_str.lstrip() if use_language_hint else ""
return base_task + citation_or_nothing + language_hint_or_nothing
def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
content=inf_chunk.content,
# This one is using the combined content of all the chunks of the section
# In default settings, this is the same as just the content of base chunk
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
)
def map_document_id_order(
chunks: list[InferenceChunk | LlmDoc], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
@ -174,157 +45,6 @@ def map_document_id_order(
return order_mapping
def build_chat_user_message(
chat_message: ChatMessage,
prompt: Prompt,
context_docs: list[LlmDoc],
llm_tokenizer_encode_func: Callable,
all_doc_useful: bool,
user_prompt_template: str = CHAT_USER_PROMPT,
context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
ignore_str: str = DEFAULT_IGNORE_STATEMENT,
) -> tuple[HumanMessage, int]:
user_query = chat_message.message
if not context_docs:
# Simpler prompt for cases where there is no context
user_prompt = (
context_free_template.format(
task_prompt=prompt.task_prompt, user_query=user_query
)
if prompt.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_docs)
)
optional_ignore = "" if all_doc_useful else ignore_str
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
user_prompt = user_prompt_template.format(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=user_query,
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
def _get_usable_chunks(
chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
total_token_count = 0
usable_chunks = []
for chunk in chunks:
chunk_token_count = check_number_of_tokens(chunk.content)
if total_token_count + chunk_token_count > token_limit:
break
total_token_count += chunk_token_count
usable_chunks.append(chunk)
# try and return at least one chunk if possible. This chunk will
# get truncated later on in the pipeline. This would only occur if
# the first chunk is larger than the token limit (usually due to character
# count -> token count mismatches caused by special characters / non-ascii
# languages)
if not usable_chunks and chunks:
usable_chunks = [chunks[0]]
return usable_chunks
def get_usable_chunks(
chunks: list[InferenceChunk],
token_limit: int,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
usable_chunks: list[InferenceChunk] = []
for _ in range(min(offset + 1, 1)): # go through this process at least once
if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
raise ValueError(
"Chunks offset too large, should not retry this many times"
)
usable_chunks = _get_usable_chunks(
chunks=chunks[offset_into_chunks:], token_limit=token_limit
)
offset_into_chunks += len(usable_chunks)
return usable_chunks
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: int | None,
batch_offset: int = 0,
) -> list[int]:
"""
Gives back indices of chunks to pass into the LLM for Q&A.
Only selects chunks viable for Q&A, within the token limit, and prioritizes those selected
by the LLM in a separate flow (this can be turned off)
Note: the batch_offset calculation has to count the batches from the beginning each time, as
there's no way to know which chunks were included in the prior batches without recounting at the moment;
this is somewhat slow as it requires tokenizing all the chunks again
"""
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0
# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
for selection_target in [True, False]:
for ind, chunk in enumerate(chunks):
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
IGNORE_FOR_QA
):
continue
# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
# 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
token_count += chunk_token + 50
# Always use at least 1 chunk
if (
token_limit is None
or token_count <= token_limit
or not latest_batch_indices
):
latest_batch_indices.append(ind)
current_chunk_unused = False
else:
current_chunk_unused = True
if token_limit is not None and token_count >= token_limit:
if batch_index < batch_offset:
batch_index += 1
if current_chunk_unused:
latest_batch_indices = [ind]
token_count = chunk_token
else:
latest_batch_indices = []
token_count = 0
else:
return latest_batch_indices
return latest_batch_indices
def create_chat_chain(
chat_session_id: int,
db_session: Session,
@ -340,7 +60,7 @@ def create_chat_chain(
id_to_msg = {msg.id: msg for msg in all_chat_messages}
if not all_chat_messages:
raise ValueError("No messages in Chat Session")
raise RuntimeError("No messages in Chat Session")
root_message = all_chat_messages[0]
if root_message.parent_message is not None:
@ -370,7 +90,7 @@ def create_chat_chain(
def combine_message_chain(
messages: list[ChatMessage],
messages: list[ChatMessage] | list[PreviousMessage],
token_limit: int,
msg_limit: int | None = None,
) -> str:
@ -381,7 +101,7 @@ def combine_message_chain(
if msg_limit is not None:
messages = messages[-msg_limit:]
for message in reversed(messages):
for message in cast(list[ChatMessage] | list[PreviousMessage], reversed(messages)):
message_token_count = message.token_count
if total_token_count + message_token_count > token_limit:
@ -394,218 +114,58 @@ def combine_message_chain(
return "\n\n".join(message_strs)
_PER_MESSAGE_TOKEN_BUFFER = 7


def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
    """From the back, find the index of the last element to include
    before the list exceeds the maximum"""
    running_sum = 0

    last_ind = 0
    for i in range(len(lst) - 1, -1, -1):
        running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
        if running_sum > max_prompt_tokens:
            last_ind = i + 1
            break
    if last_ind >= len(lst):
        raise ValueError("Last message alone is too large!")
    return last_ind
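For intuition, a small worked example (illustrative values only): the running sum from the back, including the 7-token per-message buffer, determines how many of the oldest entries must be dropped.

# Illustrative sketch only -- assumes find_last_index and _PER_MESSAGE_TOKEN_BUFFER (7)
# as defined above. Token counts are listed oldest -> newest.
token_counts = [500, 400, 300, 200]
# Summing from the back with the +7 buffer: 207, then 514, then 921 > 900,
# so index 2 onward still fits and indices 0..1 must be dropped.
assert find_last_index(token_counts, max_prompt_tokens=900) == 2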
def drop_messages_history_overflow(
    system_msg: BaseMessage | None,
    system_token_count: int,
    history_msgs: list[BaseMessage],
    history_token_counts: list[int],
    final_msg: BaseMessage,
    final_msg_token_count: int,
    max_allowed_tokens: int,
) -> list[BaseMessage]:
    """As message history grows, messages need to be dropped starting from the furthest in the past.
    The System message should be kept if at all possible and the latest user input which is inserted in the
    prompt template must be included"""
    if len(history_msgs) != len(history_token_counts):
        # This should never happen
        raise ValueError("Need exactly 1 token count per message for tracking overflow")

    prompt: list[BaseMessage] = []

    # Start dropping from the history if necessary
    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
    ind_prev_msg_start = find_last_index(
        all_tokens, max_prompt_tokens=max_allowed_tokens
    )

    if system_msg and ind_prev_msg_start <= len(history_msgs):
        prompt.append(system_msg)

    prompt.extend(history_msgs[ind_prev_msg_start:])

    prompt.append(final_msg)

    return prompt
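A minimal usage sketch (not part of the diff); the message objects and token counts below are hypothetical.

# Illustrative sketch only -- assumes drop_messages_history_overflow and langchain-style
# SystemMessage / HumanMessage objects are in scope; token counts here are made up.
prompt_msgs = drop_messages_history_overflow(
    system_msg=system_message,            # hypothetical SystemMessage
    system_token_count=50,
    history_msgs=history,                 # hypothetical list[BaseMessage], oldest first
    history_token_counts=[120, 80, 95],   # one count per history message
    final_msg=latest_user_message,        # hypothetical HumanMessage
    final_msg_token_count=60,
    max_allowed_tokens=300,
)
# Oldest history messages are dropped first; the system message and the latest user
# message are kept whenever the budget allows.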
def in_code_block(llm_text: str) -> bool:
    count = llm_text.count(TRIPLE_BACKTICK)
    return count % 2 != 0

def extract_citations_from_stream(
    tokens: Iterator[str],
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
    stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
    llm_out = ""
    max_citation_num = len(context_docs)
    curr_segment = ""
    prepend_bracket = False
    cited_inds = set()
    hold = ""
    for raw_token in tokens:
        if stop_stream:
            next_hold = hold + raw_token

            if stop_stream in next_hold:
                break

            if next_hold == stop_stream[: len(next_hold)]:
                hold = next_hold
                continue

            token = next_hold
            hold = ""
        else:
            token = raw_token

        # Special case of [1][ where ][ is a single token
        # This is where the model attempts to do consecutive citations like [1][2]
        if prepend_bracket:
            curr_segment += "[" + curr_segment
            prepend_bracket = False

        curr_segment += token
        llm_out += token

        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(possible_citation_pattern, curr_segment)

        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
        citation_found = re.search(citation_pattern, curr_segment)

        if citation_found and not in_code_block(llm_out):
            numerical_value = int(citation_found.group(1))
            if 1 <= numerical_value <= max_citation_num:
                context_llm_doc = context_docs[
                    numerical_value - 1
                ]  # remove 1 index offset

                link = context_llm_doc.link
                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]

                # Use the citation number for the document's rank in
                # the search (or selected docs) results
                curr_segment = re.sub(
                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
                )

                if target_citation_num not in cited_inds:
                    cited_inds.add(target_citation_num)
                    yield CitationInfo(
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )

                if link:
                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)

                # In case there's another open bracket like [1][, don't want to match this
                possible_citation_found = None

        # if we see "[", but haven't seen the right side, hold back - this may be a
        # citation that needs to be replaced with a link
        if possible_citation_found:
            continue

        # Special case with back to back citations [1][2]
        if curr_segment and curr_segment[-1] == "[":
            curr_segment = curr_segment[:-1]
            prepend_bracket = True

        yield DanswerAnswerPiece(answer_piece=curr_segment)
        curr_segment = ""

    if curr_segment:
        if prepend_bracket:
            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
        else:
            yield DanswerAnswerPiece(answer_piece=curr_segment)
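A minimal consumption sketch (not part of the diff); the documents, token stream, and rank map are made up.

# Illustrative sketch only -- assumes extract_citations_from_stream, LlmDoc, and the
# streaming piece types above are in scope. The token iterator and documents are made up.
docs = [doc_a, doc_b]  # hypothetical LlmDoc objects, in retrieval order
rank_map = {doc_a.document_id: 1, doc_b.document_id: 2}

for piece in extract_citations_from_stream(
    tokens=iter(["The answer", " is here [2]", "."]),
    context_docs=docs,
    doc_id_to_rank_map=rank_map,
    stop_stream=None,
):
    # Yields DanswerAnswerPiece objects for display text (with [2] rewritten to the
    # document's rank and optionally linked) and a CitationInfo the first time each
    # document is cited.
    ...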
def get_prompt_tokens(prompt: Prompt) -> int:
    return (
        check_number_of_tokens(prompt.system_prompt)
        + check_number_of_tokens(prompt.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
    )


# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
    persona: Persona,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
    if we're trying to determine if the user should be able to select another document) then we just set an
    arbitrary "upper bound".
    """
    llm_name = GEN_AI_MODEL_VERSION
    if persona.llm_model_version_override:
        llm_name = persona.llm_model_version_override

    # if we can't find a number of tokens, just assume some common default
    max_input_tokens = (
        max_llm_token_override
        if max_llm_token_override
        else get_max_input_tokens(model_name=llm_name)
    )

    if persona.prompts:
        # TODO this may not always be the first prompt
        prompt_tokens = get_prompt_tokens(persona.prompts[0])
    else:
        raise RuntimeError("Persona has no prompts - this should never happen")

    user_input_tokens = (
        check_number_of_tokens(actual_user_input)
        if actual_user_input is not None
        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
    )

    return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER


def compute_max_llm_input_tokens(persona: Persona) -> int:
    """Maximum tokens allowed in the input to the LLM (of any type)."""
    llm_name = GEN_AI_MODEL_VERSION
    if persona.llm_model_version_override:
        llm_name = persona.llm_model_version_override

    input_tokens = get_max_input_tokens(model_name=llm_name)
    return input_tokens - _MISC_BUFFER
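A worked example of the document-token budget (illustrative arithmetic only; the constants stand in for values the helpers above would return).

# Illustrative arithmetic only -- made-up stand-ins for what the helpers would report.
max_input_tokens = 4096      # e.g. what get_max_input_tokens() might return for the model
prompt_tokens = 350          # get_prompt_tokens() for the persona's first prompt
user_input_tokens = 200      # tokens in the actual user question
# _MISC_BUFFER = 40 as defined above
available_for_docs = max_input_tokens - prompt_tokens - user_input_tokens - 40
assert available_for_docs == 3506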
def reorganize_citations(
    answer: str, citations: list[CitationInfo]
) -> tuple[str, list[CitationInfo]]:
    """For a complete, citation-aware response, we want to reorganize the citations so that
    they are in the order of the documents that were used in the response. This just looks nicer / avoids
    confusion ("Why is there [7] when only 2 documents are cited?")."""
    # Regular expression to find all instances of [[x]](LINK)
    pattern = r"\[\[(.*?)\]\]\((.*?)\)"

    all_citation_matches = re.findall(pattern, answer)

    new_citation_info: dict[int, CitationInfo] = {}
    for citation_match in all_citation_matches:
        try:
            citation_num = int(citation_match[0])
            if citation_num in new_citation_info:
                continue

            matching_citation = next(
                iter([c for c in citations if c.citation_num == int(citation_num)]),
                None,
            )
            if matching_citation is None:
                continue

            new_citation_info[citation_num] = CitationInfo(
                citation_num=len(new_citation_info) + 1,
                document_id=matching_citation.document_id,
            )
        except Exception:
            pass

    # Function to replace citations with their new number
    def slack_link_format(match: re.Match) -> str:
        link_text = match.group(1)
        try:
            citation_num = int(link_text)
            if citation_num in new_citation_info:
                link_text = new_citation_info[citation_num].citation_num
        except Exception:
            pass

        link_url = match.group(2)
        return f"[[{link_text}]]({link_url})"

    # Substitute all matches in the input text
    new_answer = re.sub(pattern, slack_link_format, answer)

    # if any citations weren't parsable, just add them back to be safe
    for citation in citations:
        if citation.citation_num not in new_citation_info:
            new_citation_info[citation.citation_num] = citation

    return new_answer, list(new_citation_info.values())
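A minimal usage sketch (not part of the diff); the answer text, URL, and document id are hypothetical.

# Illustrative sketch only -- assumes reorganize_citations and CitationInfo are in scope.
answer = "Danswer supports many connectors [[3]](https://docs.example.com/connectors)."
citations = [CitationInfo(citation_num=3, document_id="doc-123")]

new_answer, new_citations = reorganize_citations(answer, citations)
# The [[3]] marker is renumbered to [[1]] (citations are re-counted in order of
# appearance) and the returned CitationInfo objects carry the new numbers.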

View File

@ -13,7 +13,7 @@ from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Prompt as PromptDBModel
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
@ -24,7 +24,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user_id=None,
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
@ -34,7 +34,6 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
shared=True,
db_session=db_session,
commit=True,
)
@ -67,9 +66,7 @@ def load_personas_from_yaml(
prompts: list[PromptDBModel | None] | None = None
else:
prompts = [
get_prompt_by_name(
prompt_name, user_id=None, shared=True, db_session=db_session
)
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
@ -78,22 +75,26 @@ def load_personas_from_yaml(
if not prompts:
prompts = None
p_id = persona.get("id")
upsert_persona(
user_id=None,
persona_id=persona.get("id"),
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
default_persona=True,
shared=True,
is_public=True,
db_session=db_session,
)

View File

@ -5,10 +5,10 @@ from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.search.models import QueryFlow
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.search.models import SearchType
class LlmDoc(BaseModel):
@ -16,11 +16,13 @@ class LlmDoc(BaseModel):
document_id: str
content: str
blurb: str
semantic_identifier: str
source_type: DocumentSource
metadata: dict[str, str | list[str]]
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
# First chunk of info for streaming QA
@ -100,9 +102,21 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None
AnswerQuestionStreamReturn = Iterator[
DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
]
class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| StreamingError
)
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):

View File

@ -5,9 +5,9 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Default"
name: "Danswer"
description: >
Default Danswer Question Answering functionality.
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
prompts:
- "Answer-Question"
@ -39,22 +39,23 @@ personas:
document_sets: []
- name: "Summarize"
- id: 1
name: "GPT"
description: >
A less creative assistant which summarizes relevant documents but does not try to
extrapolate any answers for you.
Assistant with no access to documents. Chat with just the Language Model.
prompts:
- "Summarize"
num_chunks: 10
- "OnlyLLM"
num_chunks: 0
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
- name: "Paraphrase"
- id: 2
name: "Paraphrase"
description: >
The least creative default assistant that only provides quotes from the documents.
Assistant that is heavily constrained and only provides exact quotes from Connected Sources.
prompts:
- "Paraphrase"
num_chunks: 10

Some files were not shown because too many files have changed in this diff