diff --git a/indexer/indexer.go b/indexer/indexer.go index a1abe725..122113b7 100644 --- a/indexer/indexer.go +++ b/indexer/indexer.go @@ -7,6 +7,7 @@ import ( "api.audius.co/config" dbv1 "api.audius.co/database" + "api.audius.co/jobs" "api.audius.co/logging" "connectrpc.com/connect" corev1 "github.com/OpenAudio/go-openaudio/pkg/api/core/v1" @@ -66,9 +67,35 @@ func (ci *CoreIndexer) Start(ctx context.Context) error { eg.Go(func() error { return ci.run(ctx) }) + ci.startParityJobs(ctx) return eg.Wait() } +// startParityJobs schedules the periodic jobs that mirror what the legacy +// Python discovery-provider celery beat used to run. Each job's ScheduleEvery +// launches its own goroutine and exits when ctx is cancelled, so we don't +// need to add them to the errgroup — they self-manage. +// +// Intervals match apps' celery beat_schedule in src/app.py where applicable. +// update_delist_statuses isn't in apps' beat (apps invokes it externally), +// so we pick a conservative default. +func (ci *CoreIndexer) startParityJobs(ctx context.Context) { + jobs.NewHourlyPlayCountsJob(ci.Config, ci.pool). + ScheduleEvery(ctx, 30*time.Second) + + jobs.NewPrunePlaysJob(ci.Config, ci.pool). + ScheduleEvery(ctx, 30*time.Second) + + jobs.NewUserListeningHistoryJob(ci.Config, ci.pool). + ScheduleEvery(ctx, 5*time.Second) + + jobs.NewTrendingJob(ci.Config, ci.pool). + ScheduleEvery(ctx, 10*time.Second) + + jobs.NewUpdateDelistStatusesJob(ci.Config, ci.pool). + ScheduleEvery(ctx, 5*time.Minute) +} + func (ci *CoreIndexer) run(ctx context.Context) error { go logging.SyncOnTicks(ctx, ci.logger, time.Second*10) var height int64 diff --git a/jobs/delegate_auth.go b/jobs/delegate_auth.go new file mode 100644 index 00000000..5a94dbab --- /dev/null +++ b/jobs/delegate_auth.go @@ -0,0 +1,95 @@ +package jobs + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/ethereum/go-ethereum/crypto" +) + +// signedHTTPGet performs an HTTP GET signed with a delegate private key. +// Mirrors apps' src/utils/auth_helpers.py:signed_get(): +// +// timestamp = round(time.time() * 1000) +// signature = sign(timestamp, private_key) +// nonce = f"{timestamp}:{signature.hex()}" +// Authorization: Basic base64(nonce) +// +// where sign() does: +// digest = keccak(timestamp_text) +// signature = ETH-personal-sign(digest_hex, private_key) +// +// `delegatePrivateKey` is the 0x-prefixed (or unprefixed) hex string from +// config.Cfg.DelegatePrivateKey. +func signedHTTPGet(client *http.Client, url, delegatePrivateKey string) (*http.Response, error) { + auth, err := basicAuthNonce(delegatePrivateKey, time.Now()) + if err != nil { + return nil, fmt.Errorf("build auth: %w", err) + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", auth) + + if client == nil { + client = http.DefaultClient + } + return client.Do(req) +} + +// drainResponse closes the body of a non-2xx response and returns an error +// including the status + first 512 bytes of body for context. +func drainResponse(resp *http.Response) error { + defer resp.Body.Close() + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return nil + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) +} + +// basicAuthNonce returns the Authorization header value used by apps' +// signed_get. Exported through signedHTTPGet but extracted for testing. +func basicAuthNonce(delegatePrivateKey string, now time.Time) (string, error) { + pkHex := strings.TrimPrefix(delegatePrivateKey, "0x") + pkBytes, err := hex.DecodeString(pkHex) + if err != nil { + return "", fmt.Errorf("decode private key: %w", err) + } + pk, err := crypto.ToECDSA(pkBytes) + if err != nil { + return "", fmt.Errorf("parse private key: %w", err) + } + + // Python: timestamp = round(time.time() * 1000) + tsStr := fmt.Sprintf("%d", now.UnixMilli()) + + // Python: digest = Web3.keccak(text=timestamp_str).hex() + digestBytes := crypto.Keccak256([]byte(tsStr)) + + // Python: encode_defunct(hexstr=digest_hex) -> EIP-191 personal-sign envelope + // over the BYTES that digest_hex decodes to (i.e. the raw keccak digest). + // We replicate that here. + prefix := fmt.Sprintf("\x19Ethereum Signed Message:\n%d", len(digestBytes)) + personalDigest := crypto.Keccak256(append([]byte(prefix), digestBytes...)) + + sig, err := crypto.Sign(personalDigest, pk) + if err != nil { + return "", fmt.Errorf("sign: %w", err) + } + // go-ethereum's crypto.Sign returns recovery id 0 or 1; web3.py returns + // 27 or 28. Normalize to match apps. + if sig[64] < 27 { + sig[64] += 27 + } + + nonce := tsStr + ":" + hex.EncodeToString(sig) + return "Basic " + base64.StdEncoding.EncodeToString([]byte(nonce)), nil +} diff --git a/jobs/delegate_auth_test.go b/jobs/delegate_auth_test.go new file mode 100644 index 00000000..91c443e0 --- /dev/null +++ b/jobs/delegate_auth_test.go @@ -0,0 +1,40 @@ +package jobs + +import ( + "encoding/base64" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBasicAuthNonce_Shape verifies the Authorization header has the +// expected "Basic base64(:)" structure and decodes cleanly. +// Cross-format verification against Python apps is left to integration +// testing — the signing primitives (keccak + eth personal-sign) are +// covered by go-ethereum's own tests. +func TestBasicAuthNonce_Shape(t *testing.T) { + // Test key from api/'s default dev config. + key := "13422b9affd75ff80f94f1ea394e6a6097830cb58cda2d3542f37464ecaee7df" + now := time.UnixMilli(1700000000000) + + auth, err := basicAuthNonce(key, now) + require.NoError(t, err) + + require.True(t, strings.HasPrefix(auth, "Basic ")) + decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic ")) + require.NoError(t, err) + + parts := strings.SplitN(string(decoded), ":", 2) + require.Len(t, parts, 2) + assert.Equal(t, "1700000000000", parts[0]) + // signature is 65 bytes = 130 hex chars + assert.Len(t, parts[1], 130) +} + +func TestBasicAuthNonce_BadKey(t *testing.T) { + _, err := basicAuthNonce("not-hex", time.Now()) + assert.Error(t, err) +} diff --git a/jobs/index_hourly_play_counts.go b/jobs/index_hourly_play_counts.go new file mode 100644 index 00000000..2604ef64 --- /dev/null +++ b/jobs/index_hourly_play_counts.go @@ -0,0 +1,152 @@ +package jobs + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "api.audius.co/config" + "api.audius.co/database" + "api.audius.co/logging" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" +) + +// HourlyPlayCountsJob rolls up plays into the hourly_play_counts table. +// Mirrors apps/packages/discovery-provider/src/tasks/index_hourly_play_counts.py. +// +// On each run: +// - reads the last processed plays.id checkpoint from indexing_checkpoints, +// - buckets newer plays by date_trunc('hour', created_at), +// - upserts into hourly_play_counts with sum-on-conflict, +// - advances the checkpoint to max(plays.id) seen. +// +// hourly_play_counts powers GET /v1/metrics/plays. If this job stops running +// the endpoint returns an empty array (it doesn't error). +type HourlyPlayCountsJob struct { + pool database.DbPool + logger *zap.Logger + + mutex sync.Mutex + isRunning bool +} + +// HourlyPlayCountsCheckpoint is the indexing_checkpoints.tablename used to +// remember the highest plays.id already rolled up. +const HourlyPlayCountsCheckpoint = "hourly_play_counts" + +func NewHourlyPlayCountsJob(cfg config.Config, pool database.DbPool) *HourlyPlayCountsJob { + return &HourlyPlayCountsJob{ + pool: pool, + logger: logging.NewZapLogger(cfg).Named("HourlyPlayCountsJob"), + } +} + +// ScheduleEvery runs the job every `interval` until the context is cancelled. +func (j *HourlyPlayCountsJob) ScheduleEvery(ctx context.Context, interval time.Duration) *HourlyPlayCountsJob { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + j.Run(ctx) + case <-ctx.Done(): + j.logger.Info("Job shutting down") + return + } + } + }() + return j +} + +// Run executes the job once. +func (j *HourlyPlayCountsJob) Run(ctx context.Context) { + if err := j.run(ctx); err != nil { + j.logger.Error("Job run failed", zap.Error(err)) + } +} + +func (j *HourlyPlayCountsJob) run(ctx context.Context) error { + j.mutex.Lock() + if j.isRunning { + j.mutex.Unlock() + return fmt.Errorf("job is already running") + } + j.isRunning = true + j.mutex.Unlock() + defer func() { + j.mutex.Lock() + j.isRunning = false + j.mutex.Unlock() + }() + + prev, err := getCheckpoint(ctx, j.pool, HourlyPlayCountsCheckpoint) + if err != nil { + return fmt.Errorf("read checkpoint: %w", err) + } + + var newMax *int64 + err = j.pool.QueryRow(ctx, "SELECT MAX(id) FROM plays").Scan(&newMax) + if err != nil { + return fmt.Errorf("read max(plays.id): %w", err) + } + if newMax == nil || *newMax == prev { + j.logger.Debug("No new plays since last run") + return nil + } + + // Bucket and upsert in a single statement: the SUM-on-conflict means we + // can safely process the same row twice (idempotent on full batches), + // but the checkpoint advances per run so the steady-state cost stays + // bounded. + res, err := j.pool.Exec(ctx, ` + INSERT INTO hourly_play_counts (hourly_timestamp, play_count) + SELECT date_trunc('hour', created_at) AS hourly_timestamp, + COUNT(*) AS play_count + FROM plays + WHERE id > @prev AND id <= @new_max + GROUP BY date_trunc('hour', created_at) + ON CONFLICT (hourly_timestamp) + DO UPDATE SET play_count = hourly_play_counts.play_count + EXCLUDED.play_count + `, pgx.NamedArgs{ + "prev": prev, + "new_max": *newMax, + }) + if err != nil { + return fmt.Errorf("upsert hourly buckets: %w", err) + } + + if err := saveCheckpoint(ctx, j.pool, HourlyPlayCountsCheckpoint, *newMax); err != nil { + return fmt.Errorf("save checkpoint: %w", err) + } + + j.logger.Info("Hourly play counts updated", + zap.Int64("prev_checkpoint", prev), + zap.Int64("new_checkpoint", *newMax), + zap.Int64("hours_touched", res.RowsAffected())) + return nil +} + +// getCheckpoint reads the named checkpoint value (last_checkpoint column), +// returning 0 when the row does not yet exist. +func getCheckpoint(ctx context.Context, pool database.DbPool, name string) (int64, error) { + var v int64 + err := pool.QueryRow(ctx, "SELECT last_checkpoint FROM indexing_checkpoints WHERE tablename = $1", name).Scan(&v) + if errors.Is(err, pgx.ErrNoRows) { + return 0, nil + } + return v, err +} + +// saveCheckpoint upserts the named checkpoint. +func saveCheckpoint(ctx context.Context, pool database.DbPool, name string, value int64) error { + _, err := pool.Exec(ctx, ` + INSERT INTO indexing_checkpoints (tablename, last_checkpoint) + VALUES ($1, $2) + ON CONFLICT (tablename) DO UPDATE SET last_checkpoint = EXCLUDED.last_checkpoint + `, name, value) + return err +} diff --git a/jobs/index_hourly_play_counts_test.go b/jobs/index_hourly_play_counts_test.go new file mode 100644 index 00000000..e38ea4f5 --- /dev/null +++ b/jobs/index_hourly_play_counts_test.go @@ -0,0 +1,89 @@ +package jobs + +import ( + "context" + "testing" + "time" + + "api.audius.co/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestHourlyPlayCountsJob_FullPipeline verifies the job buckets new plays +// into the right hour, advances the checkpoint, and is idempotent across +// runs (subsequent runs with no new plays don't change anything). +func TestHourlyPlayCountsJob_FullPipeline(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + + ctx := context.Background() + t1 := time.Date(2025, 1, 15, 10, 17, 30, 0, time.UTC) // bucketed to 10:00 + t2 := time.Date(2025, 1, 15, 10, 45, 0, 0, time.UTC) // same 10:00 bucket + t3 := time.Date(2025, 1, 15, 11, 5, 0, 0, time.UTC) // 11:00 bucket + + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 1, "wallet": "0x01"}, + }, + "tracks": { + {"track_id": 100, "owner_id": 1, "title": "T"}, + }, + "plays": { + {"id": 1, "user_id": 1, "play_item_id": 100, "created_at": t1}, + {"id": 2, "user_id": 1, "play_item_id": 100, "created_at": t2}, + {"id": 3, "user_id": 1, "play_item_id": 100, "created_at": t3}, + }, + }) + + job := NewHourlyPlayCountsJob(newTestConfig(), pool) + + require.NoError(t, job.run(ctx)) + + type row struct { + Ts time.Time + Count int32 + } + var rows []row + r, err := pool.Query(ctx, "SELECT hourly_timestamp, play_count FROM hourly_play_counts ORDER BY hourly_timestamp") + require.NoError(t, err) + for r.Next() { + var x row + require.NoError(t, r.Scan(&x.Ts, &x.Count)) + rows = append(rows, x) + } + r.Close() + + require.Len(t, rows, 2, "expected 2 hourly buckets") + assert.Equal(t, time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC), rows[0].Ts.UTC()) + assert.Equal(t, int32(2), rows[0].Count) + assert.Equal(t, time.Date(2025, 1, 15, 11, 0, 0, 0, time.UTC), rows[1].Ts.UTC()) + assert.Equal(t, int32(1), rows[1].Count) + + // Checkpoint advanced to max(id). + cp, err := getCheckpoint(ctx, pool, HourlyPlayCountsCheckpoint) + require.NoError(t, err) + assert.Equal(t, int64(3), cp) + + // Re-running with no new plays must be a no-op (counts unchanged). + require.NoError(t, job.run(ctx)) + var firstBucketCount int32 + require.NoError(t, pool.QueryRow(ctx, + "SELECT play_count FROM hourly_play_counts WHERE hourly_timestamp = $1", + time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC), + ).Scan(&firstBucketCount)) + assert.Equal(t, int32(2), firstBucketCount, "second run should not double-count") +} + +func TestHourlyPlayCountsJob_NoPlays(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + + job := NewHourlyPlayCountsJob(newTestConfig(), pool) + require.NoError(t, job.run(context.Background())) + + var n int + require.NoError(t, pool.QueryRow(context.Background(), + "SELECT COUNT(*) FROM hourly_play_counts").Scan(&n)) + assert.Equal(t, 0, n) +} diff --git a/jobs/index_trending.go b/jobs/index_trending.go new file mode 100644 index 00000000..fb6ca87a --- /dev/null +++ b/jobs/index_trending.go @@ -0,0 +1,377 @@ +package jobs + +import ( + "context" + "fmt" + "sync" + "time" + + "api.audius.co/config" + "api.audius.co/database" + "api.audius.co/logging" + "go.uber.org/zap" +) + +// TrendingJob recomputes trending scores. Ports the score-computation half +// of apps/packages/discovery-provider/src/tasks/index_trending.py. +// +// On each run: +// 1. Refresh aggregate_interval_plays MV (feeds week/month listen counts). +// 2. Refresh trending_params MV (per-track inputs). +// 3. For each (entity_type, version), run a delete-then-bulk-insert into +// track_trending_scores / playlist_trending_scores using the strategy's +// SQL template. +// +// What is intentionally NOT ported here: +// - Trending notifications (top-10 mover diff against the notification +// table). Lands with the challenges/notifications work. +// - index_tastemaker (writes challenge events). Depends on the challenge +// bus, which is a separate effort. +// +// Trending score templates are copied verbatim from +// apps/packages/discovery-provider/src/trending_strategies/{pnagD,AnlGe}_*.py +// so the scoring stays bit-identical to discovery. +type TrendingJob struct { + pool database.DbPool + logger *zap.Logger + + mutex sync.Mutex + isRunning bool +} + +func NewTrendingJob(cfg config.Config, pool database.DbPool) *TrendingJob { + return &TrendingJob{ + pool: pool, + logger: logging.NewZapLogger(cfg).Named("TrendingJob"), + } +} + +// ScheduleEvery runs the job every `interval` until the context is cancelled. +func (j *TrendingJob) ScheduleEvery(ctx context.Context, interval time.Duration) *TrendingJob { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + j.Run(ctx) + case <-ctx.Done(): + j.logger.Info("Job shutting down") + return + } + } + }() + return j +} + +// Run executes the job once. +func (j *TrendingJob) Run(ctx context.Context) { + if err := j.run(ctx); err != nil { + j.logger.Error("Job run failed", zap.Error(err)) + } +} + +func (j *TrendingJob) run(ctx context.Context) error { + j.mutex.Lock() + if j.isRunning { + j.mutex.Unlock() + return fmt.Errorf("job is already running") + } + j.isRunning = true + j.mutex.Unlock() + defer func() { + j.mutex.Lock() + j.isRunning = false + j.mutex.Unlock() + }() + + start := time.Now() + + for _, mv := range []string{"aggregate_interval_plays", "trending_params"} { + if _, err := j.pool.Exec(ctx, "REFRESH MATERIALIZED VIEW "+mv); err != nil { + return fmt.Errorf("refresh %s: %w", mv, err) + } + } + + if err := j.computeTrendingTracks(ctx, "TRACKS", "pnagD", trackParamsPnagD); err != nil { + return fmt.Errorf("trending tracks pnagD: %w", err) + } + if err := j.computeTrendingTracks(ctx, "TRACKS", "AnlGe", trackParamsAnlGe); err != nil { + return fmt.Errorf("trending tracks AnlGe: %w", err) + } + // Underground uses the same TRACKS strategies but registers under a + // different trending_type key; apps only registers pnagD for it. + if err := j.computeTrendingTracks(ctx, "UNDERGROUND_TRACKS", "pnagD", trackParamsPnagD); err != nil { + return fmt.Errorf("trending underground tracks pnagD: %w", err) + } + + if err := j.computeTrendingPlaylists(ctx, "PLAYLISTS", "pnagD"); err != nil { + return fmt.Errorf("trending playlists pnagD: %w", err) + } + + j.logger.Info("Trending scores recomputed", zap.Duration("duration", time.Since(start))) + return nil +} + +// trackScoreParams holds the scalar weights used by a tracks strategy. The +// only behavior difference between pnagD and AnlGe is whether the trailing +// multiplier is `karma` or `(1 + LOG(1 + karma))`. +type trackScoreParams struct { + karmaExpr string +} + +var ( + trackParamsPnagD = trackScoreParams{karmaExpr: "tp.karma"} + trackParamsAnlGe = trackScoreParams{karmaExpr: "(1 + LOG(1 + tp.karma))"} +) + +// Score weights matching apps' constants in pnagD_trending_tracks_strategy.py +// (and AnlGe — they share the same constants). +const ( + trendingN = 1 + trendingF = 50 + trendingO = 1 + trendingR = 0.25 + trendingI = 0.01 + trendingQ = 100000.0 + trendingY = 3 + trendingWk = 7 + trendingMo = 30 +) + +// computeTrendingTracks runs the three-time-range score insert for one +// (trending_type, version) combination. Mirrors update_track_score_query +// from apps' track strategies. +func (j *TrendingJob) computeTrendingTracks(ctx context.Context, trendingType, version string, p trackScoreParams) error { + tx, err := j.pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + if _, err := tx.Exec(ctx, + "DELETE FROM track_trending_scores WHERE type = $1 AND version = $2", + trendingType, version, + ); err != nil { + return err + } + + // Week + month follow the same shape; we generate one query per range. + for _, r := range []struct { + name string + days int + listenField string + repostField string + saveField string + }{ + {"week", trendingWk, "aip.week_listen_counts", "tp.repost_week_count", "tp.save_week_count"}, + {"month", trendingMo, "aip.month_listen_counts", "tp.repost_month_count", "tp.save_month_count"}, + } { + q := fmt.Sprintf(` + INSERT INTO track_trending_scores + (track_id, genre, type, version, time_range, score, created_at) + SELECT + tp.track_id, + tp.genre, + $1, $2, $3, + CASE + WHEN tp.owner_follower_count < $4 THEN 0 + WHEN EXTRACT(DAYS FROM now() - ( + CASE + WHEN tp.release_date > now() THEN aip.created_at + ELSE GREATEST(tp.release_date, aip.created_at) + END + )) > $5 THEN GREATEST( + 1.0 / $6, + POW($6, GREATEST( + -10, + 1.0 - 1.0 * EXTRACT(DAYS FROM now() - ( + CASE + WHEN tp.release_date > now() THEN aip.created_at + ELSE GREATEST(tp.release_date, aip.created_at) + END + )) / $5 + )) + ) * ( + $7 * %s + $8 * %s + $9 * %s + $10 * tp.repost_count + $11 * tp.save_count + ) * %s + ELSE ( + $7 * %s + $8 * %s + $9 * %s + $10 * tp.repost_count + $11 * tp.save_count + ) * %s + END, + now() + FROM trending_params tp + INNER JOIN aggregate_interval_plays aip ON tp.track_id = aip.track_id + `, + r.listenField, r.repostField, r.saveField, p.karmaExpr, + r.listenField, r.repostField, r.saveField, p.karmaExpr, + ) + if _, err := tx.Exec(ctx, q, + trendingType, version, r.name, + trendingY, r.days, trendingQ, + trendingN, trendingF, trendingO, trendingR, trendingI, + ); err != nil { + return fmt.Errorf("%s range: %w", r.name, err) + } + } + + // All-time uses aggregate_plays.count rather than aggregate_interval_plays. + q := fmt.Sprintf(` + INSERT INTO track_trending_scores + (track_id, genre, type, version, time_range, score, created_at) + SELECT + tp.track_id, tp.genre, $1, $2, 'allTime', + CASE + WHEN tp.owner_follower_count < $3 THEN 0 + ELSE ($4 * ap.count + $5 * tp.repost_count + $6 * tp.save_count) * %s + END, + now() + FROM trending_params tp + INNER JOIN aggregate_plays ap ON tp.track_id = ap.play_item_id + INNER JOIN tracks t ON ap.play_item_id = t.track_id + WHERE t.is_current IS TRUE + AND t.is_delete IS FALSE + AND t.is_unlisted IS FALSE + AND t.stem_of IS NULL + `, p.karmaExpr) + if _, err := tx.Exec(ctx, q, + trendingType, version, + trendingY, trendingN, trendingR, trendingI, + ); err != nil { + return fmt.Errorf("allTime range: %w", err) + } + + return tx.Commit(ctx) +} + +// computeTrendingPlaylists recomputes playlist_trending_scores for the given +// (type, version). Ports apps' pnagD playlist score template. +func (j *TrendingJob) computeTrendingPlaylists(ctx context.Context, trendingType, version string) error { + tx, err := j.pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + if _, err := tx.Exec(ctx, + "DELETE FROM playlist_trending_scores WHERE type = $1 AND version = $2", + trendingType, version, + ); err != nil { + return err + } + + // Each time range gets its own query because the interval (':week days' + // vs ':month days' vs ':year days') is part of the WHERE clauses. + for _, r := range []struct { + name string + days int + }{ + {"week", trendingWk}, + {"month", trendingMo}, + {"year", 365}, + } { + intervalLit := fmt.Sprintf("%d days", r.days) + q := ` + INSERT INTO playlist_trending_scores + (playlist_id, type, version, time_range, score, created_at) + WITH saves_and_reposts AS ( + SELECT user_id, repost_item_id AS item_id + FROM reposts + WHERE is_current IS TRUE + AND is_delete IS FALSE + AND repost_type = 'playlist' + AND repost_item_id IN (SELECT playlist_id FROM playlists WHERE is_current IS TRUE AND is_delete IS FALSE AND is_private IS FALSE) + AND created_at >= NOW() - $1::interval + UNION ALL + SELECT user_id, save_item_id AS item_id + FROM saves + WHERE is_current IS TRUE + AND is_delete IS FALSE + AND save_type = 'playlist' + AND save_item_id IN (SELECT playlist_id FROM playlists WHERE is_current IS TRUE AND is_delete IS FALSE AND is_private IS FALSE) + AND created_at >= NOW() - $1::interval + ), + filtered_users AS ( + SELECT sr.user_id, sr.item_id + FROM saves_and_reposts sr + JOIN users u ON sr.user_id = u.user_id + WHERE (u.cover_photo IS NOT NULL OR u.cover_photo_sizes IS NOT NULL) + AND (u.profile_picture IS NOT NULL OR u.profile_picture_sizes IS NOT NULL) + AND u.bio IS NOT NULL + AND u.is_current IS TRUE + ), + karma_scores AS ( + SELECT item_id, CAST(SUM(au.follower_count) AS integer) AS karma + FROM filtered_users f + JOIN aggregate_user au ON f.user_id = au.user_id + GROUP BY item_id + ), + windowed_saves AS ( + SELECT save_item_id, COUNT(*) AS week_count + FROM saves + WHERE is_current IS TRUE + AND is_delete IS FALSE + AND save_type = 'playlist' + AND created_at > NOW() - $1::interval + GROUP BY save_item_id + ), + windowed_reposts AS ( + SELECT repost_item_id, COUNT(*) AS week_count + FROM reposts + WHERE is_current IS TRUE + AND is_delete IS FALSE + AND repost_type = 'playlist' + AND created_at > NOW() - $1::interval + GROUP BY repost_item_id + ) + SELECT + p.playlist_id, + $2, $3, $4, + CASE + WHEN au.follower_count < $5 THEN 0 + WHEN EXTRACT(DAYS FROM now() - ( + CASE WHEN p.release_date > now() THEN p.created_at + ELSE GREATEST(p.release_date, p.created_at) END + )) > $6 THEN GREATEST( + 1.0 / $7, + POW($7, GREATEST( + -10, + 1.0 - 1.0 * EXTRACT(DAYS FROM now() - ( + CASE WHEN p.release_date > now() THEN p.created_at + ELSE GREATEST(p.release_date, p.created_at) END + )) / $6 + )) + ) * ( + $8 * 1 + $9 * COALESCE(rp.week_count, 0) + $10 * COALESCE(s.week_count, 0) + + $11 * COALESCE(ap.repost_count, 0) + $12 * COALESCE(ap.save_count, 0) + ) * COALESCE(k.karma, 1) + ELSE ( + $8 * 1 + $9 * COALESCE(rp.week_count, 0) + $10 * COALESCE(s.week_count, 0) + + $11 * COALESCE(ap.repost_count, 0) + $12 * COALESCE(ap.save_count, 0) + ) * COALESCE(k.karma, 1) + END, + now() + FROM playlists p + INNER JOIN aggregate_user au ON p.playlist_owner_id = au.user_id + LEFT JOIN aggregate_playlist ap ON p.playlist_id = ap.playlist_id + LEFT JOIN windowed_saves s ON p.playlist_id = s.save_item_id + LEFT JOIN windowed_reposts rp ON p.playlist_id = rp.repost_item_id + LEFT JOIN karma_scores k ON p.playlist_id = k.item_id + WHERE p.is_current IS TRUE + AND p.is_delete IS FALSE + AND p.is_private IS FALSE + AND jsonb_array_length(p.playlist_contents->'track_ids') >= 3 + ` + if _, err := tx.Exec(ctx, q, + intervalLit, + trendingType, version, r.name, + trendingY, r.days, trendingQ, + trendingN, trendingF, trendingO, trendingR, trendingI, + ); err != nil { + return fmt.Errorf("%s range: %w", r.name, err) + } + } + + return tx.Commit(ctx) +} + diff --git a/jobs/index_trending_test.go b/jobs/index_trending_test.go new file mode 100644 index 00000000..63a5600d --- /dev/null +++ b/jobs/index_trending_test.go @@ -0,0 +1,53 @@ +package jobs + +import ( + "context" + "testing" + "time" + + "api.audius.co/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestTrendingJob_PopulatesScores is a smoke test against a seeded DB: +// after a run, track_trending_scores has rows for each (type, version, +// time_range) tuple we register. The actual score values follow apps' +// formula verbatim and aren't independently validated here — drift would +// surface in parity comparison against discovery, not in unit tests. +func TestTrendingJob_PopulatesScores(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + ctx := context.Background() + + now := time.Now() + database.Seed(pool, database.FixtureMap{ + "users": {{"user_id": 1, "wallet": "0x01", "name": "Alice"}}, + "tracks": {{ + "track_id": 100, "owner_id": 1, "title": "T", + "release_date": now.Add(-2 * 24 * time.Hour), + "created_at": now.Add(-2 * 24 * time.Hour), + }}, + "plays": { + {"id": 1, "user_id": 1, "play_item_id": 100, "created_at": now.Add(-1 * time.Hour)}, + }, + }) + + job := NewTrendingJob(newTestConfig(), pool) + require.NoError(t, job.run(ctx)) + + // Track scores: 3 ranges (week, month, allTime) × 2 versions (pnagD, + // AnlGe) × 2 types (TRACKS, UNDERGROUND_TRACKS via pnagD only) = + // 2 (week/month + allTime) × 3 + 3 = 9 (TRACKS pnagD: 3, TRACKS AnlGe: 3, UNDERGROUND_TRACKS pnagD: 3) + // At minimum we expect 9 rows for the seeded track. + var nTracks int + require.NoError(t, pool.QueryRow(ctx, "SELECT COUNT(*) FROM track_trending_scores WHERE track_id = 100").Scan(&nTracks)) + assert.GreaterOrEqual(t, nTracks, 9, "expected at least 9 score rows for the seeded track") + + // Re-running should not duplicate (DELETE-then-INSERT per + // (type, version)). + require.NoError(t, job.run(ctx)) + var nAfter int + require.NoError(t, pool.QueryRow(ctx, "SELECT COUNT(*) FROM track_trending_scores WHERE track_id = 100").Scan(&nAfter)) + assert.Equal(t, nTracks, nAfter, "second run must not duplicate rows") +} diff --git a/jobs/index_user_listening_history.go b/jobs/index_user_listening_history.go new file mode 100644 index 00000000..0f5675a4 --- /dev/null +++ b/jobs/index_user_listening_history.go @@ -0,0 +1,313 @@ +package jobs + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + "time" + + "api.audius.co/config" + "api.audius.co/database" + "api.audius.co/logging" + "go.uber.org/zap" +) + +// rawPlay is one row from the plays table fed into the listening-history merge. +type rawPlay struct { + ID int64 + UserID int64 + TrackID int64 + CreatedAt time.Time +} + +// UserListeningHistoryJob materializes per-user listening history from the +// plays table into user_listening_history.listening_history JSONB. +// Mirrors apps/packages/discovery-provider/src/tasks/user_listening_history/. +// +// On each run: +// - reads the last processed plays.id checkpoint from indexing_checkpoints, +// - selects the next UserListeningHistoryBatchSize plays with non-null user_id, +// - merges each user's new plays into their existing history (per-track +// latest timestamp + cumulative play_count, sorted desc by timestamp, +// capped at UserListeningHistoryTrackLimit entries), +// - rewrites the listening_history JSONB blob for each touched user, +// - advances the checkpoint to the highest play.id processed. +type UserListeningHistoryJob struct { + pool database.DbPool + logger *zap.Logger + + batchSize int + limit int + + mutex sync.Mutex + isRunning bool +} + +const ( + // UserListeningHistoryCheckpoint is the indexing_checkpoints.tablename + // used to track the highest plays.id already merged. + UserListeningHistoryCheckpoint = "user_listening_history" + // UserListeningHistoryBatchSize caps how many plays are merged per run. + UserListeningHistoryBatchSize = 100_000 + // UserListeningHistoryTrackLimit caps how many tracks each user's + // listening_history JSONB retains. + UserListeningHistoryTrackLimit = 1000 +) + +func NewUserListeningHistoryJob(cfg config.Config, pool database.DbPool) *UserListeningHistoryJob { + return &UserListeningHistoryJob{ + pool: pool, + logger: logging.NewZapLogger(cfg).Named("UserListeningHistoryJob"), + batchSize: UserListeningHistoryBatchSize, + limit: UserListeningHistoryTrackLimit, + } +} + +// ScheduleEvery runs the job every `interval` until the context is cancelled. +func (j *UserListeningHistoryJob) ScheduleEvery(ctx context.Context, interval time.Duration) *UserListeningHistoryJob { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + j.Run(ctx) + case <-ctx.Done(): + j.logger.Info("Job shutting down") + return + } + } + }() + return j +} + +// Run executes the job once. +func (j *UserListeningHistoryJob) Run(ctx context.Context) { + if err := j.run(ctx); err != nil { + j.logger.Error("Job run failed", zap.Error(err)) + } +} + +// listenEntry mirrors apps' ListenHistory.to_dict() shape so the persisted +// JSONB stays compatible with reads in apps and api/. +type listenEntry struct { + TrackID int64 `json:"track_id"` + Timestamp string `json:"timestamp"` + PlayCount int64 `json:"play_count"` +} + +func (j *UserListeningHistoryJob) run(ctx context.Context) error { + j.mutex.Lock() + if j.isRunning { + j.mutex.Unlock() + return fmt.Errorf("job is already running") + } + j.isRunning = true + j.mutex.Unlock() + defer func() { + j.mutex.Lock() + j.isRunning = false + j.mutex.Unlock() + }() + + prev, err := getCheckpoint(ctx, j.pool, UserListeningHistoryCheckpoint) + if err != nil { + return fmt.Errorf("read checkpoint: %w", err) + } + + rows, err := j.pool.Query(ctx, ` + SELECT id, user_id, play_item_id, created_at + FROM plays + WHERE id > $1 AND user_id IS NOT NULL + ORDER BY id ASC + LIMIT $2 + `, prev, j.batchSize) + if err != nil { + return fmt.Errorf("read plays: %w", err) + } + + var newPlays []rawPlay + for rows.Next() { + var p rawPlay + if err := rows.Scan(&p.ID, &p.UserID, &p.TrackID, &p.CreatedAt); err != nil { + rows.Close() + return fmt.Errorf("scan play: %w", err) + } + newPlays = append(newPlays, p) + } + rows.Close() + if err := rows.Err(); err != nil { + return fmt.Errorf("iter plays: %w", err) + } + + if len(newPlays) == 0 { + return nil + } + + // Group new plays by user. + byUser := make(map[int64][]rawPlay) + for _, p := range newPlays { + byUser[p.UserID] = append(byUser[p.UserID], p) + } + + userIDs := make([]int64, 0, len(byUser)) + for uid := range byUser { + userIDs = append(userIDs, uid) + } + + // Load existing histories for any users with new plays in one query. + existing := make(map[int64][]listenEntry, len(userIDs)) + existingRows, err := j.pool.Query(ctx, + "SELECT user_id, listening_history FROM user_listening_history WHERE user_id = ANY($1)", + userIDs) + if err != nil { + return fmt.Errorf("read existing history: %w", err) + } + for existingRows.Next() { + var uid int64 + var raw []byte + if err := existingRows.Scan(&uid, &raw); err != nil { + existingRows.Close() + return fmt.Errorf("scan history: %w", err) + } + var entries []listenEntry + if len(raw) > 0 { + if err := json.Unmarshal(raw, &entries); err != nil { + j.logger.Warn("Could not parse existing listening_history; treating as empty", + zap.Int64("user_id", uid), zap.Error(err)) + entries = nil + } + } + existing[uid] = entries + } + existingRows.Close() + if err := existingRows.Err(); err != nil { + return fmt.Errorf("iter history: %w", err) + } + + maxID := newPlays[len(newPlays)-1].ID + + // Merge per user and write back. We do per-user UPSERTs rather than a + // batched COPY because the merge is non-trivial (per-track latest + // timestamp + cumulative play_count) and a single batched write would + // require recomputing every user's merged blob in SQL. + updated := 0 + for uid, plays := range byUser { + merged := mergeListeningHistory(existing[uid], plays, j.limit) + blob, err := json.Marshal(merged) + if err != nil { + return fmt.Errorf("marshal merged history: %w", err) + } + _, err = j.pool.Exec(ctx, ` + INSERT INTO user_listening_history (user_id, listening_history) + VALUES ($1, $2::jsonb) + ON CONFLICT (user_id) DO UPDATE SET listening_history = EXCLUDED.listening_history + `, uid, blob) + if err != nil { + return fmt.Errorf("upsert history for user %d: %w", uid, err) + } + updated++ + } + + if err := saveCheckpoint(ctx, j.pool, UserListeningHistoryCheckpoint, maxID); err != nil { + return fmt.Errorf("save checkpoint: %w", err) + } + + j.logger.Info("Indexed user listening history", + zap.Int("plays_processed", len(newPlays)), + zap.Int("users_touched", updated), + zap.Int64("new_checkpoint", maxID)) + return nil +} + +// mergeListeningHistory combines an existing history with new plays for the +// same user. For each track, the entry keeps the latest timestamp and the +// cumulative play_count (existing.play_count + count of new plays for that +// track). The result is sorted by timestamp desc and capped at `limit`. +// +// Mirrors apps' separate_new_plays + sort_listening_history_desc_by_timestamp +// behavior. Exported for tests. +func mergeListeningHistory(existing []listenEntry, newPlays []rawPlay, limit int) []listenEntry { + type acc struct { + latest time.Time + latestStr string + count int64 + } + byTrack := make(map[int64]*acc) + + // Seed with existing entries. + for _, e := range existing { + t, _ := parseHistoryTimestamp(e.Timestamp) + byTrack[e.TrackID] = &acc{ + latest: t, + latestStr: e.Timestamp, + count: e.PlayCount, + } + } + + // Apply new plays. + for _, p := range newPlays { + a, ok := byTrack[p.TrackID] + if !ok { + a = &acc{} + byTrack[p.TrackID] = a + } + if p.CreatedAt.After(a.latest) { + a.latest = p.CreatedAt + a.latestStr = formatHistoryTimestamp(p.CreatedAt) + } + a.count++ + } + + out := make([]listenEntry, 0, len(byTrack)) + for trackID, a := range byTrack { + out = append(out, listenEntry{ + TrackID: trackID, + Timestamp: a.latestStr, + PlayCount: a.count, + }) + } + sort.SliceStable(out, func(i, k int) bool { + // Latest first; ties broken by track_id desc to keep deterministic order. + ti, _ := parseHistoryTimestamp(out[i].Timestamp) + tk, _ := parseHistoryTimestamp(out[k].Timestamp) + if !ti.Equal(tk) { + return ti.After(tk) + } + return out[i].TrackID > out[k].TrackID + }) + if len(out) > limit { + out = out[:limit] + } + return out +} + +// formatHistoryTimestamp matches Python's `str(datetime)` shape closely +// enough to be parseable by existing consumers. Python uses +// "YYYY-MM-DD HH:MM:SS[.ffffff][+TZ]" — Go's "2006-01-02 15:04:05.999999" +// is the equivalent for timezone-naive timestamps stored in `plays`. +func formatHistoryTimestamp(t time.Time) string { + return t.UTC().Format("2006-01-02 15:04:05.999999") +} + +// parseHistoryTimestamp accepts both the format we write and a few +// permissive variants apps might have written (with trailing TZ). +func parseHistoryTimestamp(s string) (time.Time, error) { + for _, layout := range []string{ + "2006-01-02 15:04:05.999999", + "2006-01-02 15:04:05", + "2006-01-02 15:04:05.999999-07:00", + "2006-01-02 15:04:05-07:00", + time.RFC3339Nano, + time.RFC3339, + } { + if t, err := time.Parse(layout, s); err == nil { + return t, nil + } + } + // Fall back to zero time on bad input so the entry sorts last. + return time.Time{}, errors.New("unrecognized timestamp") +} diff --git a/jobs/index_user_listening_history_test.go b/jobs/index_user_listening_history_test.go new file mode 100644 index 00000000..1ef8d403 --- /dev/null +++ b/jobs/index_user_listening_history_test.go @@ -0,0 +1,99 @@ +package jobs + +import ( + "context" + "encoding/json" + "testing" + "time" + + "api.audius.co/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUserListeningHistoryJob_InsertsFirstHistory(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + ctx := context.Background() + + t1 := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + t2 := time.Date(2025, 1, 15, 11, 0, 0, 0, time.UTC) + + database.Seed(pool, database.FixtureMap{ + "users": {{"user_id": 1, "wallet": "0x01"}}, + "tracks": {{"track_id": 100, "owner_id": 1, "title": "T"}, {"track_id": 101, "owner_id": 1, "title": "T2"}}, + "plays": { + {"id": 1, "user_id": 1, "play_item_id": 100, "created_at": t1}, + {"id": 2, "user_id": 1, "play_item_id": 100, "created_at": t2}, // 2 plays of track 100 + {"id": 3, "user_id": 1, "play_item_id": 101, "created_at": t1}, // 1 play of track 101 + }, + }) + + job := NewUserListeningHistoryJob(newTestConfig(), pool) + require.NoError(t, job.run(ctx)) + + var raw []byte + require.NoError(t, pool.QueryRow(ctx, + "SELECT listening_history FROM user_listening_history WHERE user_id = 1").Scan(&raw)) + var entries []listenEntry + require.NoError(t, json.Unmarshal(raw, &entries)) + + require.Len(t, entries, 2) + // Sorted desc by timestamp — track 100 last played at t2. + assert.Equal(t, int64(100), entries[0].TrackID) + assert.Equal(t, int64(2), entries[0].PlayCount) + assert.Equal(t, int64(101), entries[1].TrackID) + assert.Equal(t, int64(1), entries[1].PlayCount) +} + +func TestUserListeningHistoryJob_MergesAcrossRuns(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + ctx := context.Background() + + t1 := time.Date(2025, 1, 15, 10, 0, 0, 0, time.UTC) + t2 := time.Date(2025, 1, 15, 11, 0, 0, 0, time.UTC) + + database.Seed(pool, database.FixtureMap{ + "users": {{"user_id": 1, "wallet": "0x01"}}, + "tracks": {{"track_id": 100, "owner_id": 1, "title": "T"}}, + "plays": { + {"id": 1, "user_id": 1, "play_item_id": 100, "created_at": t1}, + }, + }) + + job := NewUserListeningHistoryJob(newTestConfig(), pool) + require.NoError(t, job.run(ctx)) + + // Insert a new play with id past the checkpoint. + _, err := pool.Exec(ctx, + "INSERT INTO plays (id, user_id, play_item_id, created_at) VALUES (2, 1, 100, $1)", t2) + require.NoError(t, err) + + require.NoError(t, job.run(ctx)) + + var raw []byte + require.NoError(t, pool.QueryRow(ctx, + "SELECT listening_history FROM user_listening_history WHERE user_id = 1").Scan(&raw)) + var entries []listenEntry + require.NoError(t, json.Unmarshal(raw, &entries)) + + require.Len(t, entries, 1) + assert.Equal(t, int64(100), entries[0].TrackID) + assert.Equal(t, int64(2), entries[0].PlayCount, "play_count should sum across runs") +} + +func TestMergeListeningHistory_CapsToLimit(t *testing.T) { + now := time.Now() + plays := make([]rawPlay, 0, 50) + for i := 0; i < 50; i++ { + plays = append(plays, rawPlay{ + TrackID: int64(1000 + i), + CreatedAt: now.Add(time.Duration(i) * time.Minute), + }) + } + merged := mergeListeningHistory(nil, plays, 10) + assert.Len(t, merged, 10) + // Newest plays first. + assert.Equal(t, int64(1049), merged[0].TrackID) +} diff --git a/jobs/prune_plays.go b/jobs/prune_plays.go new file mode 100644 index 00000000..ab727b65 --- /dev/null +++ b/jobs/prune_plays.go @@ -0,0 +1,135 @@ +package jobs + +import ( + "context" + "fmt" + "sync" + "time" + + "api.audius.co/config" + "api.audius.co/database" + "api.audius.co/logging" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" +) + +// PrunePlaysJob deletes old rows from the plays table to bound table size. +// Mirrors apps/packages/discovery-provider/src/tasks/prune_plays.py. +// +// Each run deletes up to PrunePlaysBatchSize plays older than +// (now - PrunePlaysAge). At the apps default schedule (30s interval), +// this can prune ~144k plays per hour. +// +// Important: deletes happen in oldest-first order so backlog is drained +// monotonically; if the job lags, the next run picks up where it left off. +type PrunePlaysJob struct { + pool database.DbPool + logger *zap.Logger + + // age is how old a play has to be before it's eligible for pruning. + // Configurable to make testing reasonable; production uses + // PrunePlaysAge. + age time.Duration + // batchSize caps the per-run delete. Configurable for the same reason. + batchSize int + + mutex sync.Mutex + isRunning bool +} + +const ( + // PrunePlaysAge matches apps: rows older than ~400 days are pruned. + PrunePlaysAge = 400 * 24 * time.Hour + // PrunePlaysBatchSize matches apps' DEFAULT_MAX_BATCH. + PrunePlaysBatchSize = 50_000 +) + +func NewPrunePlaysJob(cfg config.Config, pool database.DbPool) *PrunePlaysJob { + return &PrunePlaysJob{ + pool: pool, + logger: logging.NewZapLogger(cfg).Named("PrunePlaysJob"), + age: PrunePlaysAge, + batchSize: PrunePlaysBatchSize, + } +} + +// WithAge overrides the prune age (for tests). +func (j *PrunePlaysJob) WithAge(age time.Duration) *PrunePlaysJob { + j.age = age + return j +} + +// WithBatchSize overrides the per-run cap (for tests). +func (j *PrunePlaysJob) WithBatchSize(n int) *PrunePlaysJob { + j.batchSize = n + return j +} + +// ScheduleEvery runs the job every `interval` until the context is cancelled. +func (j *PrunePlaysJob) ScheduleEvery(ctx context.Context, interval time.Duration) *PrunePlaysJob { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + j.Run(ctx) + case <-ctx.Done(): + j.logger.Info("Job shutting down") + return + } + } + }() + return j +} + +// Run executes the job once. +func (j *PrunePlaysJob) Run(ctx context.Context) { + if err := j.run(ctx); err != nil { + j.logger.Error("Job run failed", zap.Error(err)) + } +} + +func (j *PrunePlaysJob) run(ctx context.Context) error { + j.mutex.Lock() + if j.isRunning { + j.mutex.Unlock() + return fmt.Errorf("job is already running") + } + j.isRunning = true + j.mutex.Unlock() + defer func() { + j.mutex.Lock() + j.isRunning = false + j.mutex.Unlock() + }() + + cutoff := time.Now().Add(-j.age) + + res, err := j.pool.Exec(ctx, ` + WITH to_prune AS ( + SELECT id + FROM plays + WHERE created_at <= @cutoff + ORDER BY created_at ASC + LIMIT @batch + ) + DELETE FROM plays + WHERE id IN (SELECT id FROM to_prune) + `, pgx.NamedArgs{ + "cutoff": cutoff, + "batch": j.batchSize, + }) + if err != nil { + return fmt.Errorf("delete plays: %w", err) + } + + deleted := res.RowsAffected() + if deleted > 0 { + j.logger.Info("Pruned plays", + zap.Int64("deleted", deleted), + zap.Time("cutoff", cutoff), + zap.Int("batch_size", j.batchSize)) + } + return nil +} diff --git a/jobs/prune_plays_test.go b/jobs/prune_plays_test.go new file mode 100644 index 00000000..48a8a9b3 --- /dev/null +++ b/jobs/prune_plays_test.go @@ -0,0 +1,80 @@ +package jobs + +import ( + "context" + "testing" + "time" + + "api.audius.co/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrunePlaysJob_DeletesOnlyOld(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + ctx := context.Background() + + old1 := time.Now().Add(-500 * 24 * time.Hour) + old2 := time.Now().Add(-450 * 24 * time.Hour) + recent := time.Now().Add(-1 * time.Hour) + + database.Seed(pool, database.FixtureMap{ + "users": { + {"user_id": 1, "wallet": "0x01"}, + }, + "tracks": { + {"track_id": 100, "owner_id": 1, "title": "T"}, + }, + "plays": { + {"id": 1, "user_id": 1, "play_item_id": 100, "created_at": old1}, + {"id": 2, "user_id": 1, "play_item_id": 100, "created_at": old2}, + {"id": 3, "user_id": 1, "play_item_id": 100, "created_at": recent}, + }, + }) + + job := NewPrunePlaysJob(newTestConfig(), pool). + WithAge(400 * 24 * time.Hour). + WithBatchSize(100) + + require.NoError(t, job.run(ctx)) + + var remaining []int64 + r, err := pool.Query(ctx, "SELECT id FROM plays ORDER BY id") + require.NoError(t, err) + for r.Next() { + var id int64 + require.NoError(t, r.Scan(&id)) + remaining = append(remaining, id) + } + r.Close() + + assert.Equal(t, []int64{3}, remaining, "only the recent play should remain") +} + +func TestPrunePlaysJob_RespectsBatchSize(t *testing.T) { + pool := database.CreateTestDatabase(t, "test_jobs") + defer pool.Close() + ctx := context.Background() + + old := time.Now().Add(-500 * 24 * time.Hour) + database.Seed(pool, database.FixtureMap{ + "users": {{"user_id": 1, "wallet": "0x01"}}, + "tracks": {{"track_id": 100, "owner_id": 1, "title": "T"}}, + "plays": { + {"id": 1, "user_id": 1, "play_item_id": 100, "created_at": old}, + {"id": 2, "user_id": 1, "play_item_id": 100, "created_at": old.Add(time.Second)}, + {"id": 3, "user_id": 1, "play_item_id": 100, "created_at": old.Add(2 * time.Second)}, + }, + }) + + job := NewPrunePlaysJob(newTestConfig(), pool). + WithAge(400 * 24 * time.Hour). + WithBatchSize(2) + + require.NoError(t, job.run(ctx)) + + var n int + require.NoError(t, pool.QueryRow(ctx, "SELECT COUNT(*) FROM plays").Scan(&n)) + assert.Equal(t, 1, n, "1 row should remain after deleting 2") +} diff --git a/jobs/update_delist_statuses.go b/jobs/update_delist_statuses.go new file mode 100644 index 00000000..921b9bc9 --- /dev/null +++ b/jobs/update_delist_statuses.go @@ -0,0 +1,499 @@ +package jobs + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "api.audius.co/config" + "api.audius.co/database" + "api.audius.co/logging" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" +) + +// UpdateDelistStatusesJob polls a trusted-notifier endpoint for delist +// status updates and applies them locally. Mirrors apps' +// src/tasks/update_delist_statuses.py. +// +// For each entity type (users, tracks) the job: +// 1. Reads the per-host/entity cursor from delist_status_cursor. +// 2. GETs statuses/?batchSize=...&cursor=... with a delegate +// signed Authorization header. +// 3. Inserts the raw status rows into user_delist_statuses / +// track_delist_statuses (audit log). +// 4. Applies the most-recent status per id to users.is_available / +// tracks.is_available + is_delete / is_deactivated, holding back any +// status whose target row hasn't been indexed yet (matches apps' +// wait-for-indexing logic). +// 5. Advances the cursor to the last `createdAt` from the batch. +// +// TrustedNotifierEndpoint is currently hardcoded to notifier.audius.co per +// the agreed scoping; promote to a config knob if/when multi-notifier +// support is needed. +type UpdateDelistStatusesJob struct { + pool database.DbPool + logger *zap.Logger + endpoint string + delegateKey string + httpClient *http.Client + + mutex sync.Mutex + isRunning bool +} + +const ( + // TrustedNotifierEndpoint is the canonical trusted notifier base URL. + // apps reads this from config; we hardcode it for now and revisit if + // multi-notifier support is needed. + TrustedNotifierEndpoint = "https://notifier.audius.co/" + // delistBatchSize matches apps' DELIST_BATCH_SIZE. + delistBatchSize = 5000 +) + +func NewUpdateDelistStatusesJob(cfg config.Config, pool database.DbPool) *UpdateDelistStatusesJob { + return &UpdateDelistStatusesJob{ + pool: pool, + logger: logging.NewZapLogger(cfg).Named("UpdateDelistStatusesJob"), + endpoint: TrustedNotifierEndpoint, + delegateKey: cfg.DelegatePrivateKey, + httpClient: &http.Client{Timeout: 30 * time.Second}, + } +} + +// ScheduleEvery runs the job every `interval` until the context is cancelled. +func (j *UpdateDelistStatusesJob) ScheduleEvery(ctx context.Context, interval time.Duration) *UpdateDelistStatusesJob { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + j.Run(ctx) + case <-ctx.Done(): + j.logger.Info("Job shutting down") + return + } + } + }() + return j +} + +// Run executes the job once. +func (j *UpdateDelistStatusesJob) Run(ctx context.Context) { + if err := j.run(ctx); err != nil { + j.logger.Error("Job run failed", zap.Error(err)) + } +} + +func (j *UpdateDelistStatusesJob) run(ctx context.Context) error { + if j.delegateKey == "" { + j.logger.Debug("Skipping: delegate private key not configured") + return nil + } + + j.mutex.Lock() + if j.isRunning { + j.mutex.Unlock() + return fmt.Errorf("job is already running") + } + j.isRunning = true + j.mutex.Unlock() + defer func() { + j.mutex.Lock() + j.isRunning = false + j.mutex.Unlock() + }() + + // apps uses the latest indexed block timestamp as the "current block + // timestamp" to decide whether to wait for missing entities. Without + // that hook here we just use now() — equivalent behavior when the + // indexer is caught up, and slightly more eager when it isn't (we + // skip waiting for entities whose delist was created before now, + // which is the safe default). + currentBlockTime := time.Now() + + for _, entity := range []string{"users", "tracks"} { + if err := j.processEntity(ctx, entity, currentBlockTime); err != nil { + j.logger.Error("processing delist entity failed", + zap.String("entity", entity), zap.Error(err)) + } + } + return nil +} + +type delistUserResp struct { + Result struct { + Users []delistUser `json:"users"` + } `json:"result"` +} + +type delistTrackResp struct { + Result struct { + Tracks []delistTrack `json:"tracks"` + } `json:"result"` +} + +type delistUser struct { + CreatedAt string `json:"createdAt"` + UserID int64 `json:"userId"` + Delisted bool `json:"delisted"` + Reason string `json:"reason"` +} + +type delistTrack struct { + CreatedAt string `json:"createdAt"` + TrackID int64 `json:"trackId"` + OwnerID int64 `json:"ownerId"` + TrackCID string `json:"trackCid"` + Delisted bool `json:"delisted"` + Reason string `json:"reason"` +} + +func (j *UpdateDelistStatusesJob) processEntity(ctx context.Context, entity string, currentBlockTime time.Time) error { + cursorBefore, err := j.readCursor(ctx, entity) + if err != nil { + return fmt.Errorf("read cursor: %w", err) + } + + pollURL := fmt.Sprintf("%sstatuses/%s?batchSize=%d", j.endpoint, entity, delistBatchSize) + if !cursorBefore.IsZero() { + // apps formats the cursor as RFC3339-ish "%Y-%m-%dT%H:%M:%S.%fZ" then + // URL-escapes it. time.RFC3339Nano includes the fractional seconds + // we need; queryEscape gives the same %-encoding apps applies. + pollURL += "&cursor=" + url.QueryEscape(cursorBefore.UTC().Format("2006-01-02T15:04:05.999999Z")) + } + + resp, err := signedHTTPGet(j.httpClient, pollURL, j.delegateKey) + if err != nil { + return fmt.Errorf("signed GET: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + // drainResponse closes the body and constructs an error. + return fmt.Errorf("notifier returned non-2xx: %w", drainResponse(resp)) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return fmt.Errorf("read body: %w", err) + } + + if entity == "users" { + return j.processUsersBatch(ctx, body, currentBlockTime) + } + return j.processTracksBatch(ctx, body, currentBlockTime) +} + +func (j *UpdateDelistStatusesJob) processUsersBatch(ctx context.Context, body []byte, currentBlockTime time.Time) error { + var parsed delistUserResp + if err := json.Unmarshal(body, &parsed); err != nil { + return fmt.Errorf("parse users response: %w", err) + } + users := parsed.Result.Users + if len(users) == 0 { + return nil + } + + tx, err := j.pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + // Insert raw audit-log rows. + if err := insertUserDelistStatuses(ctx, tx, users); err != nil { + return fmt.Errorf("insert user_delist_statuses: %w", err) + } + + // Build per-user-latest-status map (input is asc by createdAt so later + // wins). + type entry struct { + delisted bool + createdAt time.Time + } + latest := make(map[int64]entry, len(users)) + for _, u := range users { + ts, err := parseDelistTimestamp(u.CreatedAt) + if err != nil { + j.logger.Warn("skipping user with unparseable createdAt", + zap.Int64("user_id", u.UserID), zap.String("createdAt", u.CreatedAt)) + continue + } + latest[u.UserID] = entry{delisted: u.Delisted, createdAt: ts} + } + + userIDs := make([]int64, 0, len(latest)) + for id := range latest { + userIDs = append(userIDs, id) + } + existing, err := selectExistingIDs(ctx, tx, "SELECT user_id FROM users WHERE user_id = ANY($1) AND is_current = true", userIDs) + if err != nil { + return fmt.Errorf("select existing users: %w", err) + } + + // Wait-for-indexing: if any missing user's delist timestamp is after + // the current block timestamp, halt this batch — apps' invariant. + for id, e := range latest { + if _, ok := existing[id]; ok { + continue + } + if e.createdAt.After(currentBlockTime) { + j.logger.Info("waiting for unindexed user; skipping batch", + zap.Int64("user_id", id)) + return nil + } + } + + for id := range existing { + e := latest[id] + var sql string + if e.delisted { + sql = `UPDATE users SET is_available = false, is_deactivated = true + WHERE user_id = $1 AND is_current = true AND is_available = true` + } else { + sql = `UPDATE users SET is_available = true, is_deactivated = false + WHERE user_id = $1 AND is_current = true AND is_available = false` + } + if _, err := tx.Exec(ctx, sql, id); err != nil { + return fmt.Errorf("update user %d availability: %w", id, err) + } + } + + // Advance cursor to the last createdAt in the batch. + lastCreatedAt, err := parseDelistTimestamp(users[len(users)-1].CreatedAt) + if err != nil { + return fmt.Errorf("parse final cursor: %w", err) + } + if err := upsertDelistCursor(ctx, tx, j.endpoint, "USERS", lastCreatedAt); err != nil { + return fmt.Errorf("save cursor: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return err + } + + j.logger.Info("Applied user delist batch", + zap.Int("batch_size", len(users)), + zap.Int("updated", len(existing))) + return nil +} + +func (j *UpdateDelistStatusesJob) processTracksBatch(ctx context.Context, body []byte, currentBlockTime time.Time) error { + var parsed delistTrackResp + if err := json.Unmarshal(body, &parsed); err != nil { + return fmt.Errorf("parse tracks response: %w", err) + } + tracks := parsed.Result.Tracks + if len(tracks) == 0 { + return nil + } + + tx, err := j.pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + if err := insertTrackDelistStatuses(ctx, tx, tracks); err != nil { + return fmt.Errorf("insert track_delist_statuses: %w", err) + } + + type entry struct { + delisted bool + createdAt time.Time + } + latest := make(map[int64]entry, len(tracks)) + for _, t := range tracks { + ts, err := parseDelistTimestamp(t.CreatedAt) + if err != nil { + j.logger.Warn("skipping track with unparseable createdAt", + zap.Int64("track_id", t.TrackID), zap.String("createdAt", t.CreatedAt)) + continue + } + latest[t.TrackID] = entry{delisted: t.Delisted, createdAt: ts} + } + + trackIDs := make([]int64, 0, len(latest)) + for id := range latest { + trackIDs = append(trackIDs, id) + } + existing, err := selectExistingIDs(ctx, tx, "SELECT track_id FROM tracks WHERE track_id = ANY($1) AND is_current = true", trackIDs) + if err != nil { + return fmt.Errorf("select existing tracks: %w", err) + } + + for id, e := range latest { + if _, ok := existing[id]; ok { + continue + } + if e.createdAt.After(currentBlockTime) { + j.logger.Info("waiting for unindexed track; skipping batch", + zap.Int64("track_id", id)) + return nil + } + } + + for id := range existing { + e := latest[id] + var sql string + if e.delisted { + sql = `UPDATE tracks SET is_available = false, is_delete = true + WHERE track_id = $1 AND is_current = true AND is_available = true` + } else { + sql = `UPDATE tracks SET is_available = true, is_delete = false + WHERE track_id = $1 AND is_current = true AND is_available = false` + } + if _, err := tx.Exec(ctx, sql, id); err != nil { + return fmt.Errorf("update track %d availability: %w", id, err) + } + } + + lastCreatedAt, err := parseDelistTimestamp(tracks[len(tracks)-1].CreatedAt) + if err != nil { + return fmt.Errorf("parse final cursor: %w", err) + } + if err := upsertDelistCursor(ctx, tx, j.endpoint, "TRACKS", lastCreatedAt); err != nil { + return fmt.Errorf("save cursor: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return err + } + + j.logger.Info("Applied track delist batch", + zap.Int("batch_size", len(tracks)), + zap.Int("updated", len(existing))) + return nil +} + +func (j *UpdateDelistStatusesJob) readCursor(ctx context.Context, entity string) (time.Time, error) { + entityEnum := strings.ToUpper(entity) + var t time.Time + err := j.pool.QueryRow(ctx, + "SELECT created_at FROM delist_status_cursor WHERE host = $1 AND entity = $2", + j.endpoint, entityEnum, + ).Scan(&t) + if errors.Is(err, pgx.ErrNoRows) { + return time.Time{}, nil + } + return t, err +} + +func upsertDelistCursor(ctx context.Context, tx pgx.Tx, host, entity string, createdAt time.Time) error { + _, err := tx.Exec(ctx, ` + INSERT INTO delist_status_cursor (host, entity, created_at) + VALUES ($1, $2::delist_entity, $3) + ON CONFLICT (host, entity) DO UPDATE SET created_at = EXCLUDED.created_at + `, host, entity, createdAt) + return err +} + +func insertUserDelistStatuses(ctx context.Context, tx pgx.Tx, users []delistUser) error { + if len(users) == 0 { + return nil + } + createdAts := make([]string, len(users)) + userIDs := make([]int64, len(users)) + delisteds := make([]bool, len(users)) + reasons := make([]string, len(users)) + for i, u := range users { + createdAts[i] = u.CreatedAt + userIDs[i] = u.UserID + delisteds[i] = u.Delisted + reasons[i] = u.Reason + } + _, err := tx.Exec(ctx, ` + INSERT INTO user_delist_statuses (created_at, user_id, delisted, reason) + SELECT + unnest($1::timestamptz[]), + unnest($2::int[]), + unnest($3::bool[]), + unnest($4::delist_user_reason[]) + ON CONFLICT DO NOTHING + `, createdAts, userIDs, delisteds, reasons) + return err +} + +func insertTrackDelistStatuses(ctx context.Context, tx pgx.Tx, tracks []delistTrack) error { + if len(tracks) == 0 { + return nil + } + createdAts := make([]string, len(tracks)) + trackIDs := make([]int64, len(tracks)) + ownerIDs := make([]int64, len(tracks)) + cids := make([]string, len(tracks)) + delisteds := make([]bool, len(tracks)) + reasons := make([]string, len(tracks)) + for i, t := range tracks { + createdAts[i] = t.CreatedAt + trackIDs[i] = t.TrackID + ownerIDs[i] = t.OwnerID + cids[i] = t.TrackCID + delisteds[i] = t.Delisted + reasons[i] = t.Reason + } + _, err := tx.Exec(ctx, ` + INSERT INTO track_delist_statuses (created_at, track_id, owner_id, track_cid, delisted, reason) + SELECT + unnest($1::timestamptz[]), + unnest($2::int[]), + unnest($3::int[]), + unnest($4::text[]), + unnest($5::bool[]), + unnest($6::delist_track_reason[]) + ON CONFLICT DO NOTHING + `, createdAts, trackIDs, ownerIDs, cids, delisteds, reasons) + return err +} + +// selectExistingIDs runs a SELECT that returns one int64 per row and returns +// the result as a set. +func selectExistingIDs(ctx context.Context, tx pgx.Tx, sql string, ids []int64) (map[int64]struct{}, error) { + if len(ids) == 0 { + return map[int64]struct{}{}, nil + } + rows, err := tx.Query(ctx, sql, ids) + if err != nil { + return nil, err + } + defer rows.Close() + out := make(map[int64]struct{}, len(ids)) + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + out[id] = struct{}{} + } + return out, rows.Err() +} + +// parseDelistTimestamp accepts the two formats apps' notifier emits: +// +// "YYYY-MM-DD HH:MM:SS.ffffff+00" +// "YYYY-MM-DD HH:MM:SS+00" +// +// plus the RFC3339 forms in case the notifier upgrades. +func parseDelistTimestamp(s string) (time.Time, error) { + for _, layout := range []string{ + "2006-01-02 15:04:05.999999+00", + "2006-01-02 15:04:05+00", + time.RFC3339Nano, + time.RFC3339, + "2006-01-02T15:04:05.999999Z", + "2006-01-02T15:04:05Z", + } { + if t, err := time.Parse(layout, s); err == nil { + return t, nil + } + } + return time.Time{}, fmt.Errorf("unrecognized timestamp: %q", s) +} +