diff --git a/osf/external/internet_archive/tasks.py b/osf/external/internet_archive/tasks.py
index 48cac22a28d..da04962c69d 100644
--- a/osf/external/internet_archive/tasks.py
+++ b/osf/external/internet_archive/tasks.py
@@ -4,10 +4,25 @@
 from framework.celery_tasks import app
 from framework.postcommit_tasks.handlers import get_task_from_postcommit_queue, enqueue_postcommit_task
 from osf.utils.workflows import RegistrationModerationStates
+from osf.models import RegistrationProvider
+from osf.management.commands.populate_internet_archives_collections import create_ia_subcollection
 
 from website import settings
 
 
+@app.task(max_retries=5, default_retry_delay=60)
+def _create_ia_provider_subcollection(provider_id):
+    provider = RegistrationProvider.objects.get(_id=provider_id)
+    resp = create_ia_subcollection(provider, settings.ID_VERSION, dry_run=False)
+    if resp and resp.status_code not in (200, 409):
+        raise Exception(f'Failed to create IA subcollection for {provider_id}: {resp.status_code}')
+
+
+def create_ia_provider_subcollection(provider):
+    if settings.IA_ARCHIVE_ENABLED:
+        enqueue_postcommit_task(_create_ia_provider_subcollection, (provider._id,), {}, celery=True)
+
+
 @app.task(max_retries=5, default_retry_delay=60, ignore_results=False)
 def _archive_to_ia(node_id):
     requests_retry_session().post(f'{settings.OSF_PIGEON_URL}archive/{node_id}')
diff --git a/osf/management/commands/populate_internet_archives_collections.py b/osf/management/commands/populate_internet_archives_collections.py
index 0a6c8e386ed..ffd1253c62c 100644
--- a/osf/management/commands/populate_internet_archives_collections.py
+++ b/osf/management/commands/populate_internet_archives_collections.py
@@ -60,7 +60,7 @@ def update_ia_subcollection(provider, version_id, dry_run):
 def populate_internet_archives_collections(version_id, dry_run=False):
     for provider in RegistrationProvider.objects.all():
         resp = create_ia_subcollection(provider, version_id, dry_run)
-        if resp.status_code == 409:
+        if resp and resp.status_code == 409:
             update_ia_subcollection(provider, version_id, dry_run)
 
 
diff --git a/osf/models/provider.py b/osf/models/provider.py
index 92681173240..2aa1294ce20 100644
--- a/osf/models/provider.py
+++ b/osf/models/provider.py
@@ -365,6 +365,13 @@ def is_moderator(self, user):
         return False
 
 
+@receiver(post_save, sender=RegistrationProvider)
+def sync_internet_archive_collection(sender, instance, created, **kwargs):
+    if created:
+        from osf.external.internet_archive.tasks import create_ia_provider_subcollection
+        create_ia_provider_subcollection(instance)
+
+
 class PreprintProvider(AbstractProvider):
     """
     Model representing a provider of preprints.
diff --git a/osf/models/registrations.py b/osf/models/registrations.py
index f2001175c11..472e04afc55 100644
--- a/osf/models/registrations.py
+++ b/osf/models/registrations.py
@@ -179,10 +179,9 @@ def find_ia_backlog():
         return Registration.objects.filter(
             (models.Q(ia_url__isnull=True) | models.Q(ia_url='')),
             is_public=True,
-            identifiers__category='doi'
         ).exclude(
             moderation_state='withdrawn',
-        )
+        ).distinct()
 
     @staticmethod
     def find_doi_backlog():
diff --git a/osf_tests/test_pigeon.py b/osf_tests/test_pigeon.py
index 561b26b2eb4..e8e82548266 100644
--- a/osf_tests/test_pigeon.py
+++ b/osf_tests/test_pigeon.py
@@ -1,7 +1,35 @@
 from unittest import mock
 import pytest
-from osf_tests.factories import RegistrationFactory, AuthUserFactory, EmbargoFactory, NodeFactory
-from osf.external.internet_archive.tasks import _archive_to_ia, _update_ia_metadata
+from osf_tests.factories import RegistrationFactory, AuthUserFactory, EmbargoFactory, NodeFactory, RegistrationProviderFactory
+from osf.external.internet_archive.tasks import _archive_to_ia, _update_ia_metadata, _create_ia_provider_subcollection
+from osf.models import Registration
+
+
+@pytest.mark.django_db
+class TestFindIABacklog:
+
+    def test_includes_public_registration_without_doi(self):
+        reg = RegistrationFactory(is_public=True)
+        assert not reg.identifiers.filter(category='doi').exists()
+        assert reg in Registration.find_ia_backlog()
+
+    def test_includes_public_registration_with_doi(self):
+        reg = RegistrationFactory(is_public=True, has_doi=True)
+        assert reg in Registration.find_ia_backlog()
+
+    def test_excludes_withdrawn(self):
+        reg = RegistrationFactory(is_public=True)
+        Registration.objects.filter(pk=reg.pk).update(moderation_state='withdrawn')
+        assert reg not in Registration.find_ia_backlog()
+
+    def test_excludes_private(self):
+        reg = RegistrationFactory()
+        assert reg not in Registration.find_ia_backlog()
+
+    def test_excludes_already_archived(self):
+        reg = RegistrationFactory(is_public=True)
+        Registration.objects.filter(pk=reg.pk).update(ia_url='https://archive.org/details/test')
+        assert reg not in Registration.find_ia_backlog()
 
 
 @pytest.mark.django_db
@@ -79,3 +107,11 @@ def test_pigeon_archive_schema_response(self, schema_response, mock_pigeon, mock
             mock.call(_archive_to_ia, (schema_response.parent._id,), {}, celery=True),
             mock.call(_archive_to_ia, (registration_with_child.nodes[0]._id,), {}, celery=True)
         ], any_order=True)
+
+    @pytest.mark.enable_enqueue_task
+    @pytest.mark.enable_implicit_clean
+    def test_new_provider_creates_ia_subcollection(self, mock_pigeon, mock_celery):
+        provider = RegistrationProviderFactory()
+        mock_celery.assert_any_call(
+            _create_ia_provider_subcollection, (provider._id,), {}, celery=True
+        )