141 lines
3.8 KiB
Go

package trending
import (
	"database/sql"
	"errors"
	"fmt"
	"time"
)
// SchemaVersion is the version of the trending tables' schema, used to track
// database migrations. CreateTablesIfNotExist records it in the
// trending_schema_version table when that table is first initialized.
const SchemaVersion = 1
// CreateTablesIfNotExist ensures that the database objects used by the trending
// system exist:
//
//   - the trending_history table,
//   - the idx_trending_history_time_kind index on (calculation_time DESC, kind),
//   - the trending_schema_version table, which is seeded with SchemaVersion
//     when it is empty.
//
// Every statement is idempotent, so the function can run on every startup.
// The SQL uses PostgreSQL syntax (SERIAL, TIMESTAMPTZ, JSONB, $n placeholders,
// and ON CONFLICT, which requires PostgreSQL 9.5 or later).
//
// NOTE(review): in PostgreSQL, concurrent CREATE TABLE/INDEX IF NOT EXISTS
// statements can still fail against each other. If several instances may start
// simultaneously, consider serializing this function (e.g. with an advisory lock).
func CreateTablesIfNotExist(db *sql.DB) error {
	// Create the trending_history table if it doesn't exist.
	if _, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS trending_history (
			id SERIAL PRIMARY KEY,
			calculation_time TIMESTAMPTZ NOT NULL,
			kind INTEGER NOT NULL,
			trending_data JSONB NOT NULL
		)
	`); err != nil {
		return fmt.Errorf("failed to create trending_history table: %w", err)
	}

	// Index calculation_time and kind for the history lookups.
	if _, err := db.Exec(`
		CREATE INDEX IF NOT EXISTS idx_trending_history_time_kind
		ON trending_history (calculation_time DESC, kind)
	`); err != nil {
		return fmt.Errorf("failed to create index on trending_history: %w", err)
	}

	// Create the schema version table if it doesn't exist.
	if _, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS trending_schema_version (
			version INTEGER PRIMARY KEY,
			updated_at TIMESTAMPTZ NOT NULL
		)
	`); err != nil {
		return fmt.Errorf("failed to create trending_schema_version table: %w", err)
	}

	// Seed the version only when the table is empty, so an existing (possibly
	// newer) version row is left alone.
	var count int
	if err := db.QueryRow(`SELECT COUNT(*) FROM trending_schema_version`).Scan(&count); err != nil {
		return fmt.Errorf("failed to check trending_schema_version: %w", err)
	}
	if count == 0 {
		// Two processes starting against a fresh database can both observe
		// count == 0 above. ON CONFLICT DO NOTHING turns the second insert of the
		// same version into a no-op. Without it, the insert would hit a
		// primary-key violation and fail that process's startup.
		if _, err := db.Exec(`
			INSERT INTO trending_schema_version (version, updated_at)
			VALUES ($1, $2)
			ON CONFLICT (version) DO NOTHING
		`, SchemaVersion, time.Now()); err != nil {
			return fmt.Errorf("failed to initialize trending_schema_version: %w", err)
		}
	}
	return nil
}
// GetLatestTrendingFromHistory retrieves the most recent trending data for the
// specified kind. It returns the decoded posts together with the time they
// were calculated.
//
// If no trending_history row exists for kind, it returns an empty (non-nil)
// slice, the zero time.Time and a nil error, so "no data yet" is not treated
// as a failure. Query and decode errors are returned wrapped.
//
// NOTE(review): the query filters by kind and orders by calculation_time DESC,
// but the index idx_trending_history_time_kind leads with calculation_time.
// An index on (kind, calculation_time DESC) would let this lookup seek directly
// to the newest row. Consider one if this query is hot.
func GetLatestTrendingFromHistory(db *sql.DB, kind int) ([]Post, time.Time, error) {
	var (
		trendingData    []byte
		calculationTime time.Time
	)
	err := db.QueryRow(`
		SELECT trending_data, calculation_time
		FROM trending_history
		WHERE kind = $1
		ORDER BY calculation_time DESC
		LIMIT 1
	`, kind).Scan(&trendingData, &calculationTime)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Empty but non-nil: encoding/json renders this as [] rather than null.
		return []Post{}, time.Time{}, nil
	case err != nil:
		return nil, time.Time{}, fmt.Errorf("failed to get latest trending data: %w", err)
	}

	posts, err := UnmarshalPosts(trendingData)
	if err != nil {
		return nil, time.Time{}, fmt.Errorf("failed to unmarshal trending posts: %w", err)
	}
	return posts, calculationTime, nil
}
// GetTrendingHistoryForKind retrieves trending history for the specified kind,
// newest calculation first. limit is the maximum number of records to return
// and offset is the number of newest records to skip (for pagination). Both
// are passed to the SQL LIMIT/OFFSET clauses unchanged.
//
// The result is nil when no rows match. On any query, scan, decode or
// iteration error it returns a nil slice and the wrapped error; partial
// results are never returned.
func GetTrendingHistoryForKind(db *sql.DB, kind int, limit, offset int) ([]TrendingHistoryEntry, error) {
	rows, err := db.Query(`
		SELECT id, calculation_time, trending_data
		FROM trending_history
		WHERE kind = $1
		ORDER BY calculation_time DESC
		LIMIT $2 OFFSET $3
	`, kind, limit, offset)
	if err != nil {
		return nil, fmt.Errorf("failed to query trending history: %w", err)
	}
	defer rows.Close()

	var entries []TrendingHistoryEntry
	for rows.Next() {
		var (
			entry        TrendingHistoryEntry
			trendingData []byte
		)
		if err := rows.Scan(&entry.ID, &entry.CalculationTime, &trendingData); err != nil {
			return nil, fmt.Errorf("failed to scan trending history entry: %w", err)
		}
		posts, err := UnmarshalPosts(trendingData)
		if err != nil {
			return nil, fmt.Errorf("failed to unmarshal trending posts: %w", err)
		}
		entry.Posts = posts
		entries = append(entries, entry)
	}
	// rows.Next returns false both at the end of the result set and when
	// fetching the next row fails. Only rows.Err tells the two apart. Without
	// this check, a truncated page would be reported as a successful result.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate trending history: %w", err)
	}
	return entries, nil
}
// TrendingHistoryEntry represents a historical record of trending data: one
// trending_history row with its trending_data column decoded into posts. The
// row's kind column is not carried in the entry.
type TrendingHistoryEntry struct {
	// ID is the trending_history.id of the row (a SERIAL primary key).
	ID              int       `json:"id"`
	// CalculationTime is the row's calculation_time value.
	CalculationTime time.Time `json:"calculation_time"`
	// Posts is the row's trending_data, decoded by UnmarshalPosts.
	Posts           []Post    `json:"posts"`
}