diff --git a/ami/main/admin.py b/ami/main/admin.py index 0605c238e..bce4c19ac 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -245,8 +245,108 @@ def update_calculated_fields(self, request: HttpRequest, queryset: QuerySet[Even update_calculated_fields_for_events(qs=queryset) self.message_user(request, f"Updated {queryset.count()} events.") + @admin.action(description="Run Occurrence Tracking on selected events") + def run_tracking_on_events(self, request: HttpRequest, queryset: QuerySet[Event]): + from collections import defaultdict + + from django.contrib import messages + from django.template.response import TemplateResponse + + from ami.ml.post_processing.admin_forms import TrackingActionForm + + # Superuser-only: queues background jobs and exposes tunables that change + # determination scoring across an event. Project admins can request a run + # via a superuser; widening the gate is a separate decision. + if not request.user.is_superuser: + self.message_user(request, "Only superusers can trigger tracking jobs.", level=messages.ERROR) + return None + + if request.POST.get("confirm"): + form = TrackingActionForm(request.POST, events=queryset) + if not form.is_valid(): + # Re-render with errors. + return TemplateResponse( + request, + "admin/main/tracking_confirmation.html", + { + **self.admin_site.each_context(request), + "title": "Run Occurrence Tracking", + "queryset": queryset, + "scope_label": f"{queryset.count()} event(s)", + "scope_summary": ( + "One Job is enqueued per project. Each Job processes all " + "selected events from that project and is visible on the " + "Jobs admin page where you can watch its log stream." 
+ ), + "form": form, + "action_name": "run_tracking_on_events", + "action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME, + "opts": self.model._meta, + }, + ) + + config = form.to_config() + by_project: dict[int, list[int]] = defaultdict(list) + null_project_event_ids: list[int] = [] + for ev in queryset.values("pk", "project_id"): + if ev["project_id"] is None: + null_project_event_ids.append(ev["pk"]) + continue + by_project[ev["project_id"]].append(ev["pk"]) + + if null_project_event_ids: + self.message_user( + request, + f"Skipped {len(null_project_event_ids)} event(s) without a project: " + f"{null_project_event_ids}. Fix Event.project before tracking those.", + level=messages.WARNING, + ) + + jobs = [] + for project_id, event_ids in by_project.items(): + job = Job.objects.create( + name=f"Post-processing: Tracking on {len(event_ids)} event(s)", + project_id=project_id, + job_type_key="post_processing", + params={ + "task": "tracking", + "config": {**config, "event_ids": event_ids}, + }, + ) + job.enqueue() + jobs.append(job.pk) + + self.message_user( + request, + f"Queued Tracking for {sum(len(v) for v in by_project.values())} event(s) " + f"across {len(by_project)} project(s). Jobs: {jobs}", + ) + return None + + # GET / first POST without confirm — render the intermediate page. + form = TrackingActionForm(events=queryset) + return TemplateResponse( + request, + "admin/main/tracking_confirmation.html", + { + **self.admin_site.each_context(request), + "title": "Run Occurrence Tracking", + "queryset": queryset, + "scope_label": f"{queryset.count()} event(s)", + "scope_summary": ( + "One Job is enqueued per project. Each Job processes all " + "selected events from that project and is visible on the " + "Jobs admin page where you can watch its log stream." 
+ ), + "form": form, + "action_name": "run_tracking_on_events", + "action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME, + "opts": self.model._meta, + }, + ) + list_filter = ("deployment", "project", "start") - actions = [update_calculated_fields] + actions = [update_calculated_fields, run_tracking_on_events] @admin.register(SourceImage) @@ -668,10 +768,109 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI self.message_user(request, f"Queued Small Size Filter for {queryset.count()} capture set(s). Jobs: {jobs}") + @admin.action(description="Run Occurrence Tracking on selected capture sets") + def run_tracking(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]): + from django.contrib import messages + from django.template.response import TemplateResponse + + from ami.main.models import Event + from ami.ml.post_processing.admin_forms import TrackingActionForm + + # Superuser-only: queues background jobs and exposes tunables that change + # determination scoring across an event. Mirrors EventAdmin.run_tracking_on_events. + if not request.user.is_superuser: + self.message_user(request, "Only superusers can trigger tracking jobs.", level=messages.ERROR) + return None + + # Aggregate Event queryset across all selected collections; the form uses this + # to scope the feature-extraction-algorithm dropdown. + events_qs = Event.objects.filter(captures__collections__in=queryset).distinct() + + if request.POST.get("confirm"): + form = TrackingActionForm(request.POST, events=events_qs) + if not form.is_valid(): + return TemplateResponse( + request, + "admin/main/tracking_confirmation.html", + { + **self.admin_site.each_context(request), + "title": "Run Occurrence Tracking", + "queryset": queryset, + "scope_label": f"{queryset.count()} capture set(s)", + "scope_summary": ( + "One Job is enqueued per capture set. 
Each Job tracks every " + "event whose images belong to the set and is visible on the " + "Jobs admin page where you can watch its log stream." + ), + "form": form, + "action_name": "run_tracking", + "action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME, + "opts": self.model._meta, + }, + ) + + config = form.to_config() + jobs = [] + empty_collections: list[int] = [] + for collection in queryset: + event_ids = list( + Event.objects.filter(captures__collections=collection) + .values_list("pk", flat=True) + .distinct() + .order_by("pk") + ) + if not event_ids: + empty_collections.append(collection.pk) + continue + job = Job.objects.create( + name=f"Post-processing: Tracking on Capture Set {collection.pk}", + project=collection.project, + source_image_collection=collection, + job_type_key="post_processing", + params={ + "task": "tracking", + "config": {**config, "event_ids": event_ids}, + }, + ) + job.enqueue() + jobs.append(job.pk) + + if empty_collections: + self.message_user( + request, + f"Skipped {len(empty_collections)} capture set(s) with no events: {empty_collections}.", + level=messages.WARNING, + ) + self.message_user(request, f"Queued Tracking for {len(jobs)} capture set(s). Jobs: {jobs}") + return None + + # GET / first POST without confirm — render the intermediate page. + form = TrackingActionForm(events=events_qs) + return TemplateResponse( + request, + "admin/main/tracking_confirmation.html", + { + **self.admin_site.each_context(request), + "title": "Run Occurrence Tracking", + "queryset": queryset, + "scope_label": f"{queryset.count()} capture set(s)", + "scope_summary": ( + "One Job is enqueued per capture set. Each Job tracks every " + "event whose images belong to the set and is visible on the " + "Jobs admin page where you can watch its log stream." 
+ ), + "form": form, + "action_name": "run_tracking", + "action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME, + "opts": self.model._meta, + }, + ) + actions = [ populate_collection, populate_collection_async, run_small_size_filter, + run_tracking, ] # Hide images many-to-many field from form. This would list all source images in the database. diff --git a/ami/main/migrations/0084_add_pgvector_extension.py b/ami/main/migrations/0084_add_pgvector_extension.py new file mode 100644 index 000000000..bcde7aaec --- /dev/null +++ b/ami/main/migrations/0084_add_pgvector_extension.py @@ -0,0 +1,16 @@ +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0083_dedupe_taxalist_names"), + ] + + operations = [ + migrations.RunSQL( + sql="CREATE EXTENSION IF NOT EXISTS vector;", + # No-op on reverse: the extension may be shared with other features/databases, + # and dropping it can be restricted in some hosted environments. + reverse_sql=migrations.RunSQL.noop, + ), + ] diff --git a/ami/main/migrations/0085_classification_features_2048.py b/ami/main/migrations/0085_classification_features_2048.py new file mode 100644 index 000000000..b43c9c168 --- /dev/null +++ b/ami/main/migrations/0085_classification_features_2048.py @@ -0,0 +1,20 @@ +from django.db import migrations +import pgvector.django.vector + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0084_add_pgvector_extension"), + ] + + operations = [ + migrations.AddField( + model_name="classification", + name="features_2048", + field=pgvector.django.vector.VectorField( + dimensions=2048, + null=True, + help_text="Feature embedding from the model backbone", + ), + ), + ] diff --git a/ami/main/migrations/0086_detection_next_detection.py b/ami/main/migrations/0086_detection_next_detection.py new file mode 100644 index 000000000..78192ec9d --- /dev/null +++ b/ami/main/migrations/0086_detection_next_detection.py @@ -0,0 +1,23 @@ +import 
django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0085_classification_features_2048"), + ] + + operations = [ + migrations.AddField( + model_name="detection", + name="next_detection", + field=models.OneToOneField( + blank=True, + help_text="The detection that follows this one in the tracking sequence.", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="previous_detection", + to="main.detection", + ), + ), + ] diff --git a/ami/main/models.py b/ami/main/models.py index e91b395b7..8c3fd49b9 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -9,6 +9,7 @@ from io import BytesIO from typing import Final, final # noqa: F401 +import pgvector.django import PIL.Image import pydantic from django.apps import apps @@ -2590,6 +2591,11 @@ class Classification(BaseModel): null=True, help_text="The probabilities the model, calibrated by the model maker, likely the softmax output", ) + features_2048 = pgvector.django.VectorField( + dimensions=2048, + null=True, + help_text="Feature embedding from the model backbone", + ) category_map = models.ForeignKey("ml.AlgorithmCategoryMap", on_delete=models.PROTECT, null=True) algorithm = models.ForeignKey( @@ -2784,6 +2790,15 @@ class Detection(BaseModel): similarity_vector = models.JSONField(null=True, blank=True) + next_detection = models.OneToOneField( + "self", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="previous_detection", + help_text="The detection that follows this one in the tracking sequence.", + ) + # For type hints classifications: models.QuerySet["Classification"] source_image_id: int @@ -3363,9 +3378,17 @@ def update_occurrence_determination( new_score = top_identification.score elif not top_identification: top_prediction = occurrence.best_prediction - if top_prediction and top_prediction.taxon and top_prediction.taxon != current_determination: - new_determination = 
top_prediction.taxon - new_score = top_prediction.score + if top_prediction and top_prediction.taxon: + if top_prediction.taxon != current_determination: + new_determination = top_prediction.taxon + new_score = top_prediction.score + elif top_prediction.score != occurrence.determination_score: + # Taxon unchanged but a higher-scoring classification has appeared + # for the same taxon (e.g. tracking merged a new detection into the + # chain whose top species classification scored higher than the + # keeper's). Refresh the score so determination_score reflects the + # best evidence available across the occurrence's detections. + new_score = top_prediction.score if new_determination and new_determination != current_determination: logger.debug(f"Changing det. of {occurrence} from {current_determination} to {new_determination}") diff --git a/ami/main/test_admin.py b/ami/main/test_admin.py new file mode 100644 index 000000000..77d850055 --- /dev/null +++ b/ami/main/test_admin.py @@ -0,0 +1,117 @@ +"""Admin-action tests for the Occurrence Tracking trigger. + +Covers both entry-points (EventAdmin + SourceImageCollectionAdmin), exercising +the intermediate confirmation page, the per-project Job partition, and the +config-passthrough from form to ``Job.params['config']``. 
+""" +from django.contrib import admin as django_admin +from django.test import Client, TestCase +from django.urls import reverse + +from ami.jobs.models import Job +from ami.main.models import SourceImageCollection +from ami.tests.fixtures.main import create_captures, setup_test_project +from ami.users.models import User + + +class _AdminTrackingCase(TestCase): + def setUp(self) -> None: + self.superuser = User.objects.create_superuser( + email="trackadmin@example.com", + password="x", + ) + self.client = Client() + self.client.force_login(self.superuser) + + self.project, self.deployment = setup_test_project(reuse=False) + create_captures(deployment=self.deployment, num_nights=1, images_per_night=2, interval_minutes=1) + self.event = self.project.events.first() + assert self.event is not None + + +class TestEventAdminTrackingAction(_AdminTrackingCase): + def _post_action(self, data): + url = reverse("admin:main_event_changelist") + payload = { + "action": "run_tracking_on_events", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.event.pk)], + **data, + } + return self.client.post(url, data=payload) + + def test_renders_intermediate_page_without_confirm(self): + response = self._post_action({}) + self.assertEqual(response.status_code, 200) + self.assertIn(b"Run Occurrence Tracking", response.content) + self.assertIn(b"Tracking parameters", response.content) + # No Job created on the GET-equivalent step. + self.assertEqual(Job.objects.filter(project=self.project, job_type_key="post_processing").count(), 0) + + def test_creates_job_per_project_and_passes_config_through(self): + # Build a second event in a different project to exercise per-project partitioning. 
+ other_project, other_deployment = setup_test_project(reuse=False) + create_captures(deployment=other_deployment, num_nights=1, images_per_night=2, interval_minutes=1) + other_event = other_project.events.first() + assert other_event is not None + + url = reverse("admin:main_event_changelist") + response = self.client.post( + url, + data={ + "action": "run_tracking_on_events", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.event.pk), str(other_event.pk)], + "confirm": "yes", + "cost_threshold": "0.35", + "skip_if_human_identifications": "on", + # require_fresh_event intentionally omitted = unchecked. + "feature_extraction_algorithm_id": "", + }, + ) + self.assertEqual(response.status_code, 302) + + jobs = Job.objects.filter(job_type_key="post_processing").order_by("project_id") + self.assertEqual(jobs.count(), 2) + by_project = {j.project_id: j for j in jobs} + self.assertIn(self.project.pk, by_project) + self.assertIn(other_project.pk, by_project) + + for job in jobs: + cfg = job.params["config"] + self.assertEqual(cfg["cost_threshold"], 0.35) + self.assertTrue(cfg["skip_if_human_identifications"]) + self.assertFalse(cfg["require_fresh_event"]) + self.assertNotIn("feature_extraction_algorithm_id", cfg) + # Each job carries only its own project's events. 
+ self.assertEqual(len(cfg["event_ids"]), 1) + + self.assertEqual(by_project[self.project.pk].params["config"]["event_ids"], [self.event.pk]) + self.assertEqual(by_project[other_project.pk].params["config"]["event_ids"], [other_event.pk]) + + +class TestCollectionAdminTrackingAction(_AdminTrackingCase): + def setUp(self) -> None: + super().setUp() + self.collection = SourceImageCollection.objects.create( + project=self.project, name="Tracking admin test collection" + ) + self.collection.images.set(self.event.captures.all()) + + def test_creates_job_with_event_ids_from_collection(self): + url = reverse("admin:main_sourceimagecollection_changelist") + response = self.client.post( + url, + data={ + "action": "run_tracking", + django_admin.helpers.ACTION_CHECKBOX_NAME: [str(self.collection.pk)], + "confirm": "yes", + "cost_threshold": "0.2", + "skip_if_human_identifications": "on", + "require_fresh_event": "on", + "feature_extraction_algorithm_id": "", + }, + ) + self.assertEqual(response.status_code, 302) + + job = Job.objects.get(job_type_key="post_processing", source_image_collection=self.collection) + self.assertEqual(job.project_id, self.project.pk) + self.assertEqual(job.params["config"]["event_ids"], [self.event.pk]) diff --git a/ami/main/tests.py b/ami/main/tests.py index 0e7672e1c..6b0e8637e 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -4329,3 +4329,68 @@ def test_source_image_cached_counts_refresh_on_threshold_change(self): f"SourceImage {image.pk} cache stale after raising threshold: " f"cache={image.detections_count}, fresh={image.get_detections_count()}", ) + + +class TestOccurrenceDeterminationScoreRefresh(TestCase): + """Regression test for `update_occurrence_determination`. + + Tracking can fold a new detection into an existing occurrence whose top + species classification scores higher than the keeper's. 
The taxon stays the + same across the chain, but the keeper's `determination_score` should + refresh to reflect the strongest evidence available — not stay pinned to + the score from the first detection. + """ + + def setUp(self) -> None: + from ami.ml.models import Algorithm + + self.project, self.deployment = setup_test_project() + create_captures(deployment=self.deployment, num_nights=1, images_per_night=2) + create_taxa(project=self.project) + self.taxon = Taxon.objects.filter(projects=self.project).first() + self.algorithm = Algorithm.objects.create(name="Test Species Classifier", version=1) + + def _make_detection_with_classification(self, source_image: SourceImage, score: float) -> Detection: + occurrence = Occurrence.objects.create( + project=self.project, + event=source_image.event, + deployment=self.deployment, + determination=self.taxon, + determination_score=score, + ) + detection = Detection.objects.create( + source_image=source_image, + occurrence=occurrence, + timestamp=source_image.timestamp, + ) + Classification.objects.create( + detection=detection, + taxon=self.taxon, + algorithm=self.algorithm, + score=score, + terminal=True, + timestamp=source_image.timestamp, + ) + return detection + + def test_score_refreshes_when_taxon_unchanged(self) -> None: + """When best_prediction has the same taxon but a higher score, the + keeper's determination_score must update on save().""" + captures = list(self.deployment.captures.order_by("timestamp")[:2]) + keeper_det = self._make_detection_with_classification(captures[0], score=0.20) + absorbed_det = self._make_detection_with_classification(captures[1], score=0.55) + keeper = keeper_det.occurrence + absorbed = absorbed_det.occurrence + assert keeper is not None and absorbed is not None + + # Simulate tracking merge: reassign absorbed detection to keeper. 
+ absorbed_det.occurrence = keeper + absorbed_det.save() + Occurrence.objects.filter(pk=absorbed.pk).delete() + + # Keeper.save() should pick up the higher score from the merged-in detection. + keeper.save() + keeper.refresh_from_db() + + self.assertEqual(keeper.determination, self.taxon) + self.assertAlmostEqual(keeper.determination_score, 0.55, places=4) diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index c259e4aea..0f7c1217c 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -688,7 +688,7 @@ def create_classification( if existing_classification: # @TODO remove this after all existing classifications have been updated (added 2024-12-20) - NEW_FIELDS = ["logits", "scores", "terminal", "category_map"] + NEW_FIELDS = ["logits", "scores", "terminal", "category_map", "features_2048"] logger.debug( "Duplicate classification found: " f"{existing_classification.taxon} from {existing_classification.algorithm}, " @@ -705,6 +705,9 @@ def create_classification( if field == "category_map": # Use the foreign key from the classification algorithm setattr(existing_classification, field, classification_algo.category_map) + elif field == "features_2048": + # The pipeline response carries this as `features`; the DB column is `features_2048`. 
+ setattr(existing_classification, field, classification_resp.features) else: # Get the value from the classification response setattr(existing_classification, field, getattr(classification_resp, field)) @@ -722,6 +725,7 @@ def create_classification( timestamp=classification_resp.timestamp or now(), logits=classification_resp.logits, scores=classification_resp.scores, + features_2048=classification_resp.features, terminal=classification_resp.terminal, category_map=classification_algo.category_map, ) diff --git a/ami/ml/post_processing/admin_forms.py b/ami/ml/post_processing/admin_forms.py new file mode 100644 index 000000000..41bac52b3 --- /dev/null +++ b/ami/ml/post_processing/admin_forms.py @@ -0,0 +1,117 @@ +"""Forms for triggering post-processing tasks from the Django admin. + +Each form is the single source of truth for the human-readable labels, +help-text, and validation rules of one task's tunable parameters. The form's +`cleaned_data` becomes the ``config`` dict on the resulting Job. + +Algorithm scope (which events / collection / queryset) lives outside the form +because it varies per admin entry-point — see the per-action helpers in +``ami/main/admin.py``. +""" +from __future__ import annotations + +from django import forms +from django.db.models import QuerySet + +from ami.main.models import Classification, Event +from ami.ml.models import Algorithm +from ami.ml.post_processing.tracking_task import DEFAULT_TRACKING_PARAMS + + +def _feature_algorithm_choices_for_events(events: QuerySet[Event]) -> list[tuple[int, str]]: + """Algorithms that produced ``features_2048`` on the given events. + + Scoped to the operator's selection so the dropdown stays bounded on + production-sized DBs and never reveals algorithms from other projects. 
+ """ + algorithm_ids = ( + Classification.objects.filter( + detection__source_image__event__in=events, + features_2048__isnull=False, + algorithm_id__isnull=False, + ) + .values_list("algorithm_id", flat=True) + .distinct() + ) + return [ + (a.pk, f"{a.name} (#{a.pk})") for a in Algorithm.objects.filter(pk__in=list(algorithm_ids)).order_by("name") + ] + + +class TrackingActionForm(forms.Form): + """Knobs surfaced when an admin triggers Occurrence Tracking. + + Pass the events the action will run on as ``events`` so the + feature-extraction-algorithm dropdown is scoped correctly. The class is + constructed once for the GET (rendering defaults) and once for the POST + (validating submitted values); pass the same queryset both times. + """ + + cost_threshold = forms.FloatField( + label="Cost threshold", + initial=DEFAULT_TRACKING_PARAMS.cost_threshold, + min_value=0.0, + help_text=( + "Maximum sum of (1 - cosine similarity) + (1 - IoU) + (1 - box ratio) + " + "(distance / image diagonal) for two detections to be considered the " + "same individual. Lower = stricter matching, fewer false links. " + "Default 0.2 is calibrated against synthetic features; tune per dataset." + ), + ) + + skip_if_human_identifications = forms.BooleanField( + label="Skip events with human identifications", + initial=DEFAULT_TRACKING_PARAMS.skip_if_human_identifications, + required=False, + help_text=( + "If checked, events that already have any user-confirmed identification " + "are skipped to preserve manual review work. Recommended on." + ), + ) + + require_fresh_event = forms.BooleanField( + label="Require fresh event (v1)", + initial=DEFAULT_TRACKING_PARAMS.require_fresh_event, + required=False, + help_text=( + "v1 only handles events where every detection has its own auto-created " + "occurrence (1:1) and no chain links exist. Skip already-tracked events. " + "Re-tracking lands in v2." 
+ ), + ) + + feature_extraction_algorithm_id = forms.ChoiceField( + label="Feature extraction algorithm", + required=False, + help_text=( + "Override the algorithm whose embeddings are used for matching. Leave " + "blank to auto-detect (works when only one feature-extracting algorithm " + "ran on the event). Required when multiple algorithms have produced " + "embeddings on the same event." + ), + ) + + def __init__(self, *args, events: QuerySet[Event] | None = None, **kwargs): + super().__init__(*args, **kwargs) + choices: list[tuple[str, str]] = [("", "— auto-detect —")] + if events is not None: + choices.extend((str(pk), label) for pk, label in _feature_algorithm_choices_for_events(events)) + self.fields["feature_extraction_algorithm_id"].choices = choices + + def to_config(self) -> dict: + """Return the ``cleaned_data`` shape the TrackingTask expects in ``Job.params['config']``. + + Drops the algorithm override when blank so ``_params()`` falls through to + auto-detection rather than logging an "unknown key" warning. 
+ """ + if not self.is_valid(): + raise ValueError(f"TrackingActionForm has errors: {self.errors.as_text()}") + config = { + "cost_threshold": self.cleaned_data["cost_threshold"], + "skip_if_human_identifications": self.cleaned_data["skip_if_human_identifications"], + "require_fresh_event": self.cleaned_data["require_fresh_event"], + } + algo_id_str = self.cleaned_data.get("feature_extraction_algorithm_id") or "" + if algo_id_str: + config["feature_extraction_algorithm_id"] = int(algo_id_str) + return config diff --git a/ami/ml/post_processing/registry.py b/ami/ml/post_processing/registry.py index c85f607f9..258724f39 100644 --- a/ami/ml/post_processing/registry.py +++ b/ami/ml/post_processing/registry.py @@ -1,8 +1,10 @@ # Registry of available post-processing tasks from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask +from ami.ml.post_processing.tracking_task import TrackingTask POSTPROCESSING_TASKS = { SmallSizeFilterTask.key: SmallSizeFilterTask, + TrackingTask.key: TrackingTask, } diff --git a/ami/ml/post_processing/tests/__init__.py b/ami/ml/post_processing/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/ml/post_processing/tests/test_tracking_task.py b/ami/ml/post_processing/tests/test_tracking_task.py new file mode 100644 index 000000000..2dac6cf22 --- /dev/null +++ b/ami/ml/post_processing/tests/test_tracking_task.py @@ -0,0 +1,143 @@ +import logging +from collections import defaultdict + +import numpy as np +from django.test import TestCase +from django.utils import timezone + +from ami.jobs.models import Job +from ami.main.models import Classification, Detection, Occurrence +from ami.ml.models import Algorithm +from ami.ml.post_processing.tracking_task import ( + DEFAULT_TRACKING_PARAMS, + TrackingTask, + assign_occurrences_by_tracking_images, +) +from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project + +logger = logging.getLogger(__name__) + + +class 
TestTracking(TestCase): + def setUp(self) -> None: + self.project, self.deployment = setup_test_project(reuse=False) + # 1 night, 5 captures spaced 1 minute apart so they group into one event. + create_captures(deployment=self.deployment, num_nights=1, images_per_night=5, interval_minutes=1) + create_taxa(self.project) + create_occurrences(deployment=self.deployment, num=6) + + self.event = self.project.events.first() + assert self.event is not None + self.source_images = list(self.event.captures.order_by("timestamp")) + + # Source images need dimensions for the cost function. + for img in self.source_images: + if not img.width or not img.height: + img.width = 4096 + img.height = 2160 + img.save(update_fields=["width", "height"]) + + self.algorithm = self._assign_mock_features_to_occurrence_detections(self.event) + + # Capture ground-truth groupings so we can compare after re-tracking. + self.ground_truth_groups = defaultdict(set) + for occ in Occurrence.objects.filter(event=self.event): + for det_id in Detection.objects.filter(occurrence=occ).values_list("id", flat=True): + self.ground_truth_groups[occ.pk].add(det_id) + + Detection.objects.filter(source_image__event=self.event).update(next_detection=None) + + def _assign_mock_features_to_occurrence_detections( + self, event, algorithm_name: str = "MockTrackingAlgorithm" + ) -> Algorithm: + algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_name, key="mock-tracking-algo") + rng = np.random.default_rng(seed=42) + + for occurrence in event.occurrences.all(): + base_vector = rng.random(2048) + for det in occurrence.detections.all(): + noisy = base_vector + rng.normal(0, 0.001, size=2048) + Classification.objects.update_or_create( + detection=det, + algorithm=algorithm, + defaults={ + "timestamp": timezone.now(), + "features_2048": noisy.tolist(), + "terminal": True, + "score": 1.0, + }, + ) + return algorithm + + def test_tracking_reproduces_occurrence_groups(self): + # v1 fresh-data scenario: pipeline 
already created 1:1 detection/occurrence. + # Wipe only chain links so tracking has to rebuild them; occurrences stay so + # event_is_fresh() passes and tracking runs. + Detection.objects.filter(source_image__event=self.event).update(next_detection=None) + + # Sanity-check the fresh invariant before running. + orphans = Detection.objects.filter(source_image__event=self.event, occurrence__isnull=True).count() + self.assertEqual(orphans, 0, "Test setup expects every detection to have an occurrence") + + assign_occurrences_by_tracking_images( + event=self.event, + logger=logger, + algorithm=self.algorithm, + params=DEFAULT_TRACKING_PARAMS, + ) + + new_groups = { + occ.pk: set(Detection.objects.filter(occurrence=occ).values_list("id", flat=True)) + for occ in Occurrence.objects.filter(event=self.event) + } + + self.assertEqual( + len(new_groups), + len(self.ground_truth_groups), + f"Expected {len(self.ground_truth_groups)} groups, got {len(new_groups)}", + ) + + gt_values = list(self.ground_truth_groups.values()) + for new_set in new_groups.values(): + self.assertIn( + new_set, + gt_values, + f"Reconstructed group {new_set} does not match any ground-truth group", + ) + + +class TestTrackingTaskResolveEvents(TestCase): + """Scope-resolution unit tests for ``TrackingTask._resolve_events``.""" + + def setUp(self) -> None: + self.project, self.deployment = setup_test_project(reuse=False) + create_captures(deployment=self.deployment, num_nights=1, images_per_night=2, interval_minutes=1) + self.event = self.project.events.first() + assert self.event is not None + + def test_resolve_events_from_event_ids(self): + task = TrackingTask(logger=logger, event_ids=[self.event.pk]) + events = task._resolve_events() + self.assertEqual([e.pk for e in events], [self.event.pk]) + + def test_resolve_events_raises_when_no_event_ids(self): + task = TrackingTask(logger=logger) + with self.assertRaises(ValueError): + task._resolve_events() + + def 
test_resolve_events_drops_cross_project_ids(self): + # Make a foreign event in a different project; the job project should win. + other_project, other_deployment = setup_test_project(reuse=False) + create_captures(deployment=other_deployment, num_nights=1, images_per_night=2, interval_minutes=1) + foreign_event = other_project.events.first() + assert foreign_event is not None + + job = Job.objects.create( + name="Tracking scope test", + project=self.project, + job_type_key="post_processing", + params={"task": "tracking", "config": {"event_ids": [self.event.pk, foreign_event.pk]}}, + ) + task = TrackingTask(job=job, event_ids=[self.event.pk, foreign_event.pk]) + events = task._resolve_events() + self.assertEqual([e.pk for e in events], [self.event.pk]) diff --git a/ami/ml/post_processing/tracking_task.py b/ami/ml/post_processing/tracking_task.py new file mode 100644 index 000000000..0e62883d2 --- /dev/null +++ b/ami/ml/post_processing/tracking_task.py @@ -0,0 +1,485 @@ +import dataclasses +import logging +import math +import typing +from collections.abc import Iterable + +import numpy as np +from django.db import transaction +from django.db.models import Count + +from ami.main.models import Classification, Detection, Event, Occurrence, SourceImage +from ami.ml.models import Algorithm +from ami.ml.post_processing.base import BasePostProcessingTask + +if typing.TYPE_CHECKING: + pass + + +@dataclasses.dataclass +class TrackingParams: + # cost_threshold: max sum of (1-cosine) + (1-IoU) + (1-box_ratio) + (distance/diag). + # WARNING: calibrated against synthetic features in tests. Real backbone embeddings + # have very different statistical properties (sparsity, norm distribution); tune + # per-dataset before relying on the default. 
+ cost_threshold: float = 0.2 + skip_if_human_identifications: bool = True + require_completely_processed_session: bool = False + # v1 only operates on fresh data: every detection has its own auto-created + # occurrence (1:1) and no chain links exist yet. Re-tracking previously-tracked + # data is a v2 concern (see PR #1272 for incremental/append-prepend plan). + require_fresh_event: bool = True + feature_extraction_algorithm_id: int | None = None + + +DEFAULT_TRACKING_PARAMS = TrackingParams() + + +def cosine_similarity(v1: Iterable[float], v2: Iterable[float]) -> float: + a = np.array(v1) + b = np.array(v2) + sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + return float(np.clip(sim, 0.0, 1.0)) + + +def iou(bb1, bb2) -> float: + xA = max(bb1[0], bb2[0]) + yA = max(bb1[1], bb2[1]) + xB = min(bb1[2], bb2[2]) + yB = min(bb1[3], bb2[3]) + inter = max(0, xB - xA + 1) * max(0, yB - yA + 1) + area1 = (bb1[2] - bb1[0] + 1) * (bb1[3] - bb1[1] + 1) + area2 = (bb2[2] - bb2[0] + 1) * (bb2[3] - bb2[1] + 1) + union = area1 + area2 - inter + return inter / union if union > 0 else 0.0 + + +def box_ratio(bb1, bb2) -> float: + area1 = (bb1[2] - bb1[0] + 1) * (bb1[3] - bb1[1] + 1) + area2 = (bb2[2] - bb2[0] + 1) * (bb2[3] - bb2[1] + 1) + return min(area1, area2) / max(area1, area2) + + +def distance_ratio(bb1, bb2, img_diag: float) -> float: + cx1 = (bb1[0] + bb1[2]) / 2 + cy1 = (bb1[1] + bb1[3]) / 2 + cx2 = (bb2[0] + bb2[2]) / 2 + cy2 = (bb2[1] + bb2[3]) / 2 + dist = math.sqrt((cx2 - cx1) ** 2 + (cy2 - cy1) ** 2) + return dist / img_diag if img_diag > 0 else 1.0 + + +def image_diagonal(width: int, height: int) -> int: + return int(math.ceil(math.sqrt(width**2 + height**2))) + + +def total_cost(f1, f2, bb1, bb2, diag) -> float: + return ( + (1 - cosine_similarity(f1, f2)) + + (1 - iou(bb1, bb2)) + + (1 - box_ratio(bb1, bb2)) + + distance_ratio(bb1, bb2, diag) + ) + + +def get_unique_feature_algorithm_for_event(event: Event) -> tuple[Algorithm | None, list[Algorithm]]: 
+ """ + Return ``(unique_algorithm, all_candidates)``. + + If exactly one feature-extraction algorithm produced ``features_2048`` for this + event, returns that algorithm and a single-element list. Otherwise returns + ``(None, candidates)`` so the caller can either skip with a warning or require + the operator to pass an explicit ``feature_extraction_algorithm_id``. + """ + algo_ids = ( + Classification.objects.filter( + detection__source_image__event=event, + features_2048__isnull=False, + algorithm_id__isnull=False, + ) + .values_list("algorithm_id", flat=True) + .distinct() + ) + candidates = list(Algorithm.objects.filter(pk__in=list(algo_ids))) + if len(candidates) == 1: + return candidates[0], candidates + return None, candidates + + +def event_is_fresh(event: Event) -> tuple[bool, str]: + """ + Fresh = every detection in the event has an occurrence AND every occurrence + in the event has exactly one detection. v1 tracking only operates on fresh + data (the state after pipeline processing creates 1:1 detection/occurrence + auto-mappings, before any chain consolidation). 
+ """ + orphan_detections = Detection.objects.filter( + source_image__event=event, + occurrence__isnull=True, + ).count() + if orphan_detections: + return False, f"{orphan_detections} detection(s) without an occurrence" + + multi_detection_occurrences = ( + Occurrence.objects.filter(event=event).annotate(_n=Count("detections")).filter(_n__gt=1).count() + ) + if multi_detection_occurrences: + return False, f"{multi_detection_occurrences} occurrence(s) already span >1 detection" + + return True, "" + + +def event_fully_processed(event: Event, logger: logging.Logger, algorithm: Algorithm) -> bool: + total = event.captures.count() + processed = ( + event.captures.filter( + detections__classifications__features_2048__isnull=False, + detections__classifications__algorithm=algorithm, + ) + .distinct() + .count() + ) + if processed < total: + logger.info(f"Event {event.pk} not fully processed: {processed}/{total} captures") + return False + return True + + +def get_feature_vector(detection: Detection, algorithm: Algorithm): + return ( + detection.classifications.filter(features_2048__isnull=False, algorithm=algorithm) + .order_by("-timestamp") + .values_list("features_2048", flat=True) + .first() + ) + + +def assign_occurrences_from_detection_chains(source_images: list[SourceImage], logger: logging.Logger) -> None: + """ + Walk chains via ``Detection.next_detection`` and consolidate each chain into + a single occurrence using a merge-into-first strategy: + + - Pick the first existing occurrence in the chain as the keeper. + - Reassign every other detection in the chain to the keeper. + - Delete now-empty sibling occurrences. + - If no detection in the chain has an occurrence yet, create one. + + Designed for fresh-event input (1:1 detection/occurrence). v2 incremental tracking + can reuse this primitive for prepend/append: keeper survives, new detections fold in. 
+ """ + visited: set[int] = set() + created = 0 + merged = 0 + existing = Occurrence.objects.filter(detections__source_image__in=source_images).distinct().count() + + for image in source_images: + for det in image.detections.all(): + if det.pk in visited: + continue + try: + has_prior = det.previous_detection is not None + except Detection.DoesNotExist: + has_prior = False + if has_prior: + continue + + chain: list[Detection] = [] + current: Detection | None = det + while current and current.pk not in visited: + chain.append(current) + visited.add(current.pk) + current = current.next_detection + + old_occ_ids = {d.occurrence_id for d in chain if d.occurrence_id} + all_assigned = all(d.occurrence_id is not None for d in chain) + + # Coherent: every detection assigned and all share one occurrence. Nothing to do. + if len(old_occ_ids) == 1 and all_assigned: + continue + + # Pick keeper: first existing occurrence in chain order. + keeper: Occurrence | None = None + for d in chain: + if d.occurrence_id: + keeper = d.occurrence + break + + if keeper is None: + keeper = Occurrence.objects.create( + event=chain[0].source_image.event, + deployment=chain[0].source_image.deployment, + project=chain[0].source_image.project, + ) + created += 1 + + # Reassign chain detections to keeper. + for d in chain: + if d.occurrence_id != keeper.pk: + d.occurrence = keeper + d.save() + + # Delete now-empty sibling occurrences. v1's fresh-event invariant guarantees + # these have no Identifications attached (nothing has been ratified yet), so + # CASCADE on Identification.occurrence is harmless. v2 must instead reassign + # Identification.occurrence to the keeper before deleting. 
+ for occ_id in old_occ_ids - {keeper.pk}: + try: + Occurrence.objects.filter(id=occ_id).delete() + merged += 1 + except Exception as e: + logger.error(f"Failed to delete occurrence {occ_id}: {e}") + + keeper.save() + + new_count = Occurrence.objects.filter(detections__source_image__in=source_images).distinct().count() + removed = existing - new_count + if removed > 0: + logger.info(f"Merged {merged} sibling occurrences into chain keepers (net -{removed}).") + logger.info( + f"Materialized {created} new occurrences across {len(source_images)} images. " + f"Occurrences before: {existing}, after: {new_count}. Detections processed: {len(visited)}." + ) + + +def pair_detections( + current_detections: list[Detection], + next_detections: list[Detection], + image_width: int, + image_height: int, + cost_threshold: float, + algorithm: Algorithm, + logger: logging.Logger, +) -> None: + """ + Greedy lowest-cost matching between two adjacent images. Sets `next_detection` + on each detection in `current_detections` for the best partner in `next_detections`, + if that partner's cost is below `cost_threshold` and not already claimed. + """ + diag = image_diagonal(image_width, image_height) + candidates: list[tuple[Detection, Detection, float]] = [] + + # Cache feature lookups: one query per detection instead of O(m*n). + current_vectors = {det.pk: get_feature_vector(det, algorithm) for det in current_detections} + next_vectors = {nxt.pk: get_feature_vector(nxt, algorithm) for nxt in next_detections} + + for det in current_detections: + det_vec = current_vectors[det.pk] + if det_vec is None: + continue + for nxt in next_detections: + nxt_vec = next_vectors[nxt.pk] + if nxt_vec is None: + continue + cost = total_cost(det_vec, nxt_vec, det.bbox, nxt.bbox, diag) + if cost < cost_threshold: + candidates.append((det, nxt, cost)) + + # Secondary keys (det.pk, nxt.pk) keep tied costs deterministic across runs. 
+ candidates.sort(key=lambda x: (x[2], x[0].pk, x[1].pk)) + + claimed_current: set[int] = set() + claimed_next: set[int] = set() + + for det, nxt, cost in candidates: + if det.id in claimed_current or nxt.id in claimed_next: + continue + # Detach any existing inbound link to `nxt` before reassigning. + try: + prior: Detection | None = nxt.previous_detection + except Detection.DoesNotExist: + prior = None + if prior is not None: + prior.next_detection = None + prior.save() + + det.next_detection = nxt + det.save() + claimed_current.add(det.id) + claimed_next.add(nxt.id) + logger.debug(f"Linked detection {det.id} -> {nxt.id} (cost {cost:.4f})") + + +def assign_occurrences_by_tracking_images( + event: Event, + logger: logging.Logger, + algorithm: Algorithm, + params: TrackingParams = DEFAULT_TRACKING_PARAMS, + progress_cb: typing.Callable[[float], None] | None = None, +) -> None: + source_images = list(event.captures.order_by("timestamp")) + if len(source_images) < 2: + logger.warning(f"Event {event.pk}: not enough images to track ({len(source_images)})") + return + + transitions = len(source_images) - 1 + skipped_transitions = 0 + # Per-event atomic boundary: a crash mid-event rolls back chain links + occurrence + # consolidation for THIS event only, leaving other events in the job intact. + with transaction.atomic(): + for i in range(transitions): + cur = source_images[i] + nxt = source_images[i + 1] + + if not cur.width or not cur.height: + logger.warning( + f"Image {cur.pk} has no dimensions; skipping transition {i + 1}/{transitions} " + f"for event {event.pk}." 
def assign_occurrences_by_tracking_images(
    event: Event,
    logger: logging.Logger,
    algorithm: Algorithm,
    params: TrackingParams = DEFAULT_TRACKING_PARAMS,
    progress_cb: typing.Callable[[float], None] | None = None,
) -> None:
    """
    Track detections across the event's captures in timestamp order, then
    consolidate the resulting ``next_detection`` chains into occurrences.
    """
    source_images = list(event.captures.order_by("timestamp"))
    if len(source_images) < 2:
        logger.warning(f"Event {event.pk}: not enough images to track ({len(source_images)})")
        return

    transitions = len(source_images) - 1
    skipped_transitions = 0
    # Per-event atomic boundary: a crash mid-event rolls back chain links + occurrence
    # consolidation for THIS event only, leaving other events in the job intact.
    with transaction.atomic():
        for step in range(transitions):
            current_image = source_images[step]
            next_image = source_images[step + 1]

            if not current_image.width or not current_image.height:
                logger.warning(
                    f"Image {current_image.pk} has no dimensions; skipping transition {step + 1}/{transitions} "
                    f"for event {event.pk}."
                )
                skipped_transitions += 1
                if progress_cb:
                    progress_cb((step + 1) / transitions)
                continue

            pair_detections(
                list(current_image.detections.all()),
                list(next_image.detections.all()),
                image_width=current_image.width,
                image_height=current_image.height,
                cost_threshold=params.cost_threshold,
                algorithm=algorithm,
                logger=logger,
            )
            if progress_cb:
                progress_cb((step + 1) / transitions)

        if skipped_transitions:
            logger.info(
                f"Event {event.pk}: skipped {skipped_transitions}/{transitions} transitions "
                "due to missing image dimensions."
            )

        assign_occurrences_from_detection_chains(source_images, logger)


class TrackingTask(BasePostProcessingTask):
    """
    Reconstruct occurrences in a SourceImageCollection by tracking detections across
    consecutive captures using feature embeddings + bbox geometry. Updates each
    Detection's `next_detection` link and creates one Occurrence per chain.
    """

    key = "tracking"
    name = "Occurrence Tracking"

    # Scope keys live outside TrackingParams (which is reserved for algorithm tunables).
    # Mirrors the pattern: scope = where to run; params = how to run.
    _SCOPE_CONFIG_KEYS = frozenset({"event_ids"})

    def _params(self) -> TrackingParams:
        """Build TrackingParams from job config overrides, warning on unrecognized keys."""
        known_keys = {f.name for f in dataclasses.fields(TrackingParams)}
        overrides = {k: v for k, v in self.config.items() if k in known_keys}
        unknown = set(self.config) - known_keys - self._SCOPE_CONFIG_KEYS
        if unknown:
            self.logger.warning(f"Ignoring unknown tracking config keys: {sorted(unknown)}")
        return dataclasses.replace(DEFAULT_TRACKING_PARAMS, **overrides)

    def _resolve_events(self) -> list[Event]:
        """
        Returns events to track from ``config["event_ids"]``.

        Both admin entry-points (EventAdmin, SourceImageCollectionAdmin) compute the
        event id list at action time and pass it through ``config``. The task only
        understands events; collections are flattened to event ids by the trigger.

        If a job is attached, every resolved event must belong to ``job.project``;
        cross-project IDs are dropped with a warning. This guards against a trigger
        that smuggles event IDs from a project the operator can't see.
        """
        event_ids = self.config.get("event_ids")
        if not event_ids:
            raise ValueError("Tracking task requires `event_ids` in config.")
        scoped = Event.objects.filter(pk__in=event_ids)
        if self.job and self.job.project_id:
            foreign_ids = list(scoped.exclude(project_id=self.job.project_id).values_list("pk", flat=True))
            if foreign_ids:
                self.logger.warning(
                    f"Dropping {len(foreign_ids)} event(s) outside job project "
                    f"{self.job.project_id}: {foreign_ids}"
                )
            scoped = scoped.filter(project_id=self.job.project_id)
        events = list(scoped.order_by("created_at").distinct())
        missing = set(event_ids) - {e.pk for e in events}
        if missing:
            self.logger.warning(f"Tracking requested {sorted(missing)} but those events were not found.")
        return events

    def run(self) -> None:
        """Track every resolved event, skipping those that fail the configured preconditions."""
        params = self._params()
        self.logger.info(f"Tracking starting with params: {params}")

        events = self._resolve_events()
        total = len(events)
        collection_ref = (
            f" (job collection #{self.job.source_image_collection.pk})"
            if self.job and self.job.source_image_collection
            else ""
        )
        self.logger.info(f"Tracking: {total} events{collection_ref}")

        for idx, event in enumerate(events, start=1):
            self.logger.info(f"Tracking event {idx}/{total} (id={event.pk})")

            # Precondition 1: fresh 1:1 detection/occurrence state.
            if params.require_fresh_event:
                fresh, reason = event_is_fresh(event)
                if not fresh:
                    self.logger.info(
                        f"Skipping event {event.pk}: not fresh ({reason}). "
                        "v1 only handles 1:1 detection/occurrence input. "
                        "Re-tracking previously-tracked data lands in v2 (incremental)."
                    )
                    continue

            # Resolve the feature-extraction algorithm: explicit config wins,
            # otherwise it must be unambiguous from the event's classifications.
            if params.feature_extraction_algorithm_id is not None:
                algorithm = Algorithm.objects.filter(pk=params.feature_extraction_algorithm_id).first()
                if algorithm is None:
                    self.logger.warning(
                        f"Configured feature_extraction_algorithm_id="
                        f"{params.feature_extraction_algorithm_id} not found; skipping event {event.pk}."
                    )
                    continue
                self.logger.info(f"Using configured feature-extraction algorithm {algorithm.pk} for event {event.pk}.")
            else:
                algorithm, candidates = get_unique_feature_algorithm_for_event(event)
                if algorithm is None:
                    if candidates:
                        candidate_names = [f"#{a.pk} {a.name}" for a in candidates]
                        self.logger.warning(
                            f"Event {event.pk}: detections classified by {len(candidates)} different "
                            f"feature-extraction algorithms ({candidate_names}). Pass "
                            "feature_extraction_algorithm_id in the job config to disambiguate. Skipping."
                        )
                    else:
                        self.logger.warning(
                            f"Event {event.pk}: no detections with feature embeddings. "
                            "Run the processing pipeline first. Skipping."
                        )
                    continue

            # Precondition 2: never clobber events a human has already ratified.
            if (
                params.skip_if_human_identifications
                and Occurrence.objects.filter(event=event, identifications__isnull=False).exists()
            ):
                self.logger.info(f"Skipping event {event.pk}: has human identifications.")
                continue

            # Precondition 3 (optional): every capture classified by the chosen algorithm.
            if params.require_completely_processed_session and not event_fully_processed(
                event, logger=self.logger, algorithm=algorithm
            ):
                self.logger.info(f"Skipping event {event.pk}: not fully processed.")
                continue

            # Default-arg binding freezes idx/total per iteration (avoids the
            # late-binding closure pitfall inside the loop).
            def _stage_progress(p: float, _idx=idx, _total=total) -> None:
                # Aggregate per-event progress into overall task progress.
                overall = ((_idx - 1) + p) / _total
                self.update_progress(overall)

            assign_occurrences_by_tracking_images(
                event=event,
                logger=self.logger,
                algorithm=algorithm,
                params=params,
                progress_cb=_stage_progress,
            )

        self.update_progress(1.0)
        self.logger.info("Tracking finished.")
+ overall = ((_idx - 1) + p) / _total + self.update_progress(overall) + + assign_occurrences_by_tracking_images( + event=event, + logger=self.logger, + algorithm=algorithm, + params=params, + progress_cb=_stage_progress, + ) + + self.update_progress(1.0) + self.logger.info("Tracking finished.") diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 9322e4116..a2a49d646 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -133,6 +133,20 @@ class ClassificationResponse(pydantic.BaseModel): ) scores: list[float] = [] logits: list[float] | None = None + features: list[float] | None = pydantic.Field( + default=None, + description=( + "Optional feature embedding vector from the model backbone, used for tracking and similarity search. " + "Must be exactly 2048 floats to match the Classification.features_2048 column." + ), + ) + + @pydantic.validator("features") + def _features_length(cls, v): + if v is not None and len(v) != 2048: + raise ValueError(f"features must be length 2048 to match Classification.features_2048, got {len(v)}") + return v + inference_time: float | None = None algorithm: AlgorithmReference terminal: bool = True diff --git a/ami/templates/admin/main/tracking_confirmation.html b/ami/templates/admin/main/tracking_confirmation.html new file mode 100644 index 000000000..5bd0ff57e --- /dev/null +++ b/ami/templates/admin/main/tracking_confirmation.html @@ -0,0 +1,45 @@ +{% extends "admin/base_site.html" %} + +{% load i18n admin_urls %} + +{% block breadcrumbs %} + +{% endblock breadcrumbs %} +{% block content %} +
+ {% csrf_token %} +

+ You are about to run Occurrence Tracking on {{ scope_label }}. +

+ {% if scope_summary %}

{{ scope_summary }}

{% endif %} + {% if form.non_field_errors %} + + {% endif %} +
+

Tracking parameters

+ {% for field in form %} +
+ {{ field.errors }} + {{ field.label_tag }} + {{ field }} + {% if field.help_text %}
{{ field.help_text|safe }}
{% endif %} +
+ {% endfor %} +
+ {% for obj in queryset %}{% endfor %} + + +
+ + Cancel +
+
+{% endblock content %} diff --git a/compose/local/postgres/Dockerfile b/compose/local/postgres/Dockerfile index 5f864a4a0..627e7cf98 100644 --- a/compose/local/postgres/Dockerfile +++ b/compose/local/postgres/Dockerfile @@ -1,6 +1,11 @@ FROM postgres:16 # FROM esgn/pgtuned:latest +# Install pgvector (required by ami.main migration 0084 for Classification.features_2048). +RUN apt-get update \ + && apt-get install -y --no-install-recommends postgresql-16-pgvector \ + && rm -rf /var/lib/apt/lists/* + COPY ./compose/local/postgres/maintenance /usr/local/bin/maintenance RUN chmod +x /usr/local/bin/maintenance/* RUN mv /usr/local/bin/maintenance/* /usr/local/bin \ diff --git a/requirements/base.txt b/requirements/base.txt index 089a3d80a..4b8005a0b 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -19,6 +19,7 @@ django-pydantic-field==0.3.10 sentry-sdk==1.40.4 # https://github.com/getsentry/sentry-python django-cachalot==2.6.3 numpy==2.1 +pgvector==0.3.6 # https://github.com/pgvector/pgvector-python # Django # ------------------------------------------------------------------------------