Add save-bulk lister to check origins prior their insertion in database
This new and special lister enables to verify a list of origins to archive provided by users (for instance through the Web API). Its purpose is to avoid polluting the scheduler database with origins that cannot be loaded into the archive. Each origin is identified by an URL and a visit type. For a given visit type the lister is checking if the origin URL can be found and if the visit type is valid. The supported visit types are those for VCS (bzr, cvs, hg, git and svn) plus the one for loading a tarball content into the archive. Accepted origins are inserted or upserted in the scheduler database. Rejected origins are stored in the lister state. Related to #4709
This commit is contained in:
parent
6618cf341c
commit
af24960bc2
10 changed files with 763 additions and 0 deletions
13
swh/lister/save_bulk/__init__.py
Normal file
13
swh/lister/save_bulk/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Copyright (C) 2024 The Software Heritage developers
|
||||
# See the AUTHORS file at the top-level directory of this distribution
|
||||
# License: GNU General Public License version 3, or any later version
|
||||
# See top-level LICENSE file for more information
|
||||
|
||||
|
||||
def register():
|
||||
from .lister import SaveBulkLister
|
||||
|
||||
return {
|
||||
"lister": SaveBulkLister,
|
||||
"task_modules": [f"{__name__}.tasks"],
|
||||
}
|
416
swh/lister/save_bulk/lister.py
Normal file
416
swh/lister/save_bulk/lister.py
Normal file
|
@ -0,0 +1,416 @@
|
|||
# Copyright (C) 2024 The Software Heritage developers
|
||||
# See the AUTHORS file at the top-level directory of this distribution
|
||||
# License: GNU General Public License version 3, or any later version
|
||||
# See top-level LICENSE file for more information
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
import socket
|
||||
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
from breezy.builtins import cmd_info
|
||||
from dulwich.porcelain import ls_remote
|
||||
from mercurial import hg, ui
|
||||
from requests import ConnectionError, RequestException
|
||||
from subvertpy import SubversionException, client
|
||||
from subvertpy.ra import Auth, get_username_provider
|
||||
|
||||
from swh.lister.utils import is_tarball
|
||||
from swh.scheduler.interface import SchedulerInterface
|
||||
from swh.scheduler.model import ListedOrigin
|
||||
|
||||
from ..pattern import CredentialsType, Lister
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log_invalid_origin_type_for_url(
|
||||
origin_url: str, origin_type: str, err_msg: Optional[str] = None
|
||||
):
|
||||
msg = f"Origin URL {origin_url} does not target a {origin_type}."
|
||||
if err_msg:
|
||||
msg += f"\nError details: {err_msg}"
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def is_valid_tarball_url(origin_url: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Checks if an URL targets a tarball using a set of heuritiscs.
|
||||
|
||||
Args:
|
||||
origin_url: The URL to check
|
||||
|
||||
Returns:
|
||||
a tuple whose first member indicates if the URL targets a tarball and
|
||||
second member holds an optional error message if check failed
|
||||
"""
|
||||
exc_str = None
|
||||
try:
|
||||
ret, _ = is_tarball([origin_url])
|
||||
except Exception as e:
|
||||
ret = False
|
||||
exc_str = str(e)
|
||||
if not ret:
|
||||
_log_invalid_origin_type_for_url(origin_url, "tarball", exc_str)
|
||||
return ret, exc_str
|
||||
|
||||
|
||||
def is_valid_git_url(origin_url: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if an URL targets a public git repository by attempting to list
|
||||
its remote refs.
|
||||
|
||||
Args:
|
||||
origin_url: The URL to check
|
||||
|
||||
Returns:
|
||||
a tuple whose first member indicates if the URL targets a public git
|
||||
repository and second member holds an error message if check failed
|
||||
"""
|
||||
try:
|
||||
ls_remote(origin_url)
|
||||
except Exception as e:
|
||||
exc_str = str(e)
|
||||
_log_invalid_origin_type_for_url(origin_url, "public git repository", exc_str)
|
||||
return False, exc_str
|
||||
else:
|
||||
return True, None
|
||||
|
||||
|
||||
def is_valid_svn_url(origin_url: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if an URL targets a public subversion repository by attempting to get
|
||||
repository information.
|
||||
|
||||
Args:
|
||||
origin_url: The URL to check
|
||||
|
||||
Returns:
|
||||
a tuple whose first member indicates if the URL targets a public subversion
|
||||
repository and second member holds an error message if check failed
|
||||
"""
|
||||
svn_client = client.Client(auth=Auth([get_username_provider()]))
|
||||
try:
|
||||
svn_client.info(quote(origin_url, safe="/:!$&'()*+,=@").rstrip("/"))
|
||||
except SubversionException as e:
|
||||
exc_str = str(e)
|
||||
_log_invalid_origin_type_for_url(
|
||||
origin_url, "public subversion repository", exc_str
|
||||
)
|
||||
return False, exc_str
|
||||
else:
|
||||
return True, None
|
||||
|
||||
|
||||
def is_valid_hg_url(origin_url: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if an URL targets a public mercurial repository by attempting to connect
|
||||
to the remote repository.
|
||||
|
||||
Args:
|
||||
origin_url: The URL to check
|
||||
|
||||
Returns:
|
||||
a tuple whose first member indicates if the URL targets a public mercurial
|
||||
repository and second member holds an error message if check failed
|
||||
"""
|
||||
hgui = ui.ui()
|
||||
hgui.setconfig(b"ui", b"interactive", False)
|
||||
try:
|
||||
hg.peer(hgui, {}, origin_url.encode())
|
||||
except Exception as e:
|
||||
exc_str = str(e)
|
||||
_log_invalid_origin_type_for_url(
|
||||
origin_url, "public mercurial repository", exc_str
|
||||
)
|
||||
return False, exc_str
|
||||
else:
|
||||
return True, None
|
||||
|
||||
|
||||
def is_valid_bzr_url(origin_url: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if an URL targets a public bazaar repository by attempting to get
|
||||
repository information.
|
||||
|
||||
Args:
|
||||
origin_url: The URL to check
|
||||
|
||||
Returns:
|
||||
a tuple whose first member indicates if the URL targets a public bazaar
|
||||
repository and second member holds an error message if check failed
|
||||
"""
|
||||
try:
|
||||
cmd_info().run_argv_aliases([origin_url])
|
||||
except Exception as e:
|
||||
exc_str = str(e)
|
||||
_log_invalid_origin_type_for_url(
|
||||
origin_url, "public bazaar repository", exc_str
|
||||
)
|
||||
return False, exc_str
|
||||
else:
|
||||
return True, None
|
||||
|
||||
|
||||
def is_valid_cvs_url(origin_url: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if an URL matches one of the formats expected by the CVS loader of
|
||||
Software Heritage.
|
||||
|
||||
Args:
|
||||
origin_url: The URL to check
|
||||
|
||||
Returns:
|
||||
a tuple whose first member indicates if the URL matches one of the formats
|
||||
expected by the CVS loader and second member holds an error message if
|
||||
check failed.
|
||||
"""
|
||||
err_msg = None
|
||||
rsync_url_format = "rsync://<hostname>[.*/]<project_name>/<module_name>"
|
||||
pserver_url_format = (
|
||||
"pserver://<usernmame>@<hostname>[.*/]<project_name>/<module_name>"
|
||||
)
|
||||
err_msg_prefix = (
|
||||
"The origin URL for the CVS repository is malformed, it should match"
|
||||
)
|
||||
|
||||
parsed_url = urlparse(origin_url)
|
||||
ret = (
|
||||
parsed_url.scheme in ("rsync", "pserver")
|
||||
and len(parsed_url.path.strip("/").split("/")) >= 2
|
||||
)
|
||||
if parsed_url.scheme == "rsync":
|
||||
if not ret:
|
||||
err_msg = f"{err_msg_prefix} '{rsync_url_format}'"
|
||||
elif parsed_url.scheme == "pserver":
|
||||
ret = ret and parsed_url.username is not None
|
||||
if not ret:
|
||||
err_msg = f"{err_msg_prefix} '{pserver_url_format}'"
|
||||
else:
|
||||
err_msg = f"{err_msg_prefix} '{rsync_url_format}' or '{pserver_url_format}'"
|
||||
|
||||
if not ret:
|
||||
_log_invalid_origin_type_for_url(origin_url, "CVS", err_msg)
|
||||
|
||||
return ret, err_msg
|
||||
|
||||
|
||||
CONNECTION_ERROR = "A connection error occurred when requesting origin URL."
|
||||
HTTP_ERROR = "An HTTP error occurred when requesting origin URL"
|
||||
HOSTNAME_ERROR = "The hostname could not be resolved."
|
||||
|
||||
|
||||
VISIT_TYPE_ERROR: Dict[str, str] = {
|
||||
"tarball-directory": "The origin URL does not target a tarball.",
|
||||
"git": "The origin URL does not target a public git repository.",
|
||||
"svn": "The origin URL does not target a public subversion repository.",
|
||||
"hg": "The origin URL does not target a public mercurial repository.",
|
||||
"bzr": "The origin URL does not target a public bazaar repository.",
|
||||
"cvs": "The origin URL does not target a public CVS repository.",
|
||||
}
|
||||
|
||||
|
||||
class SubmittedOrigin(TypedDict):
|
||||
origin_url: str
|
||||
visit_type: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RejectedOrigin:
|
||||
origin_url: str
|
||||
visit_type: str
|
||||
reason: str
|
||||
exception: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaveBulkListerState:
|
||||
"""Stored lister state"""
|
||||
|
||||
rejected_origins: List[RejectedOrigin] = field(default_factory=list)
|
||||
"""
|
||||
List of origins rejected by the lister.
|
||||
"""
|
||||
|
||||
|
||||
SaveBulkListerPage = List[SubmittedOrigin]
|
||||
|
||||
|
||||
class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]):
|
||||
"""The save-bulk lister enables to verify a list of origins to archive provided
|
||||
by an HTTP endpoint. Its purpose is to avoid polluting the scheduler database with
|
||||
origins that cannot be loaded into the archive.
|
||||
|
||||
Each origin is identified by an URL and a visit type. For a given visit type the
|
||||
lister is checking if the origin URL can be found and if the visit type is valid.
|
||||
|
||||
The HTTP endpoint must return an origins list in a paginated way through the use
|
||||
of two integer query parameters: ``page`` indicates the page to fetch and `per_page`
|
||||
corresponds the number of origins in a page.
|
||||
The endpoint must return a JSON list in the following format:
|
||||
|
||||
.. code-block:: JSON
|
||||
|
||||
[
|
||||
{
|
||||
"origin_url": "https://git.example.org/user/project",
|
||||
"visit_type": "git"
|
||||
},
|
||||
{
|
||||
"origin_url": "https://example.org/downloads/project.tar.gz",
|
||||
"visit_type": "tarball-directory"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
The supported visit types are those for VCS (``bzr``, ``cvs``, ``hg``, ``git``
|
||||
and ``svn``) plus the one for loading a tarball content into the archive
|
||||
(``tarball-directory``).
|
||||
|
||||
Accepted origins are inserted or upserted in the scheduler database.
|
||||
|
||||
Rejected origins are stored in the lister state.
|
||||
"""
|
||||
|
||||
LISTER_NAME = "save-bulk"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
instance: str,
|
||||
scheduler: SchedulerInterface,
|
||||
credentials: Optional[CredentialsType] = None,
|
||||
max_origins_per_page: Optional[int] = None,
|
||||
max_pages: Optional[int] = None,
|
||||
enable_origins: bool = True,
|
||||
per_page: int = 1000,
|
||||
):
|
||||
super().__init__(
|
||||
scheduler=scheduler,
|
||||
credentials=credentials,
|
||||
url=url,
|
||||
instance=instance,
|
||||
max_origins_per_page=max_origins_per_page,
|
||||
max_pages=max_pages,
|
||||
enable_origins=enable_origins,
|
||||
)
|
||||
self.rejected_origins: Set[RejectedOrigin] = set()
|
||||
self.per_page = per_page
|
||||
|
||||
def state_from_dict(self, d: Dict[str, Any]) -> SaveBulkListerState:
|
||||
return SaveBulkListerState(
|
||||
rejected_origins=[
|
||||
RejectedOrigin(**rej) for rej in d.get("rejected_origins", [])
|
||||
]
|
||||
)
|
||||
|
||||
def state_to_dict(self, state: SaveBulkListerState) -> Dict[str, Any]:
|
||||
return {"rejected_origins": [asdict(rej) for rej in state.rejected_origins]}
|
||||
|
||||
def get_pages(self) -> Iterator[SaveBulkListerPage]:
|
||||
current_page = 1
|
||||
origins = self.session.get(
|
||||
self.url, params={"page": current_page, "per_page": self.per_page}
|
||||
).json()
|
||||
while origins:
|
||||
yield origins
|
||||
current_page += 1
|
||||
origins = self.session.get(
|
||||
self.url, params={"page": current_page, "per_page": self.per_page}
|
||||
).json()
|
||||
|
||||
def get_origins_from_page(
|
||||
self, origins: SaveBulkListerPage
|
||||
) -> Iterator[ListedOrigin]:
|
||||
assert self.lister_obj.id is not None
|
||||
|
||||
for origin in origins:
|
||||
origin_url = origin["origin_url"]
|
||||
visit_type = origin["visit_type"]
|
||||
|
||||
logger.info(
|
||||
"Checking origin URL %s for visit type %s.", origin_url, visit_type
|
||||
)
|
||||
|
||||
rejection_details = None
|
||||
rejection_exception = None
|
||||
|
||||
parsed_url = urlparse(origin_url)
|
||||
if rejection_details is None:
|
||||
if parsed_url.scheme in ("http", "https"):
|
||||
try:
|
||||
response = self.session.head(origin_url, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
except ConnectionError as e:
|
||||
logger.info(
|
||||
"A connection error occurred when requesting %s.",
|
||||
origin_url,
|
||||
)
|
||||
rejection_details = CONNECTION_ERROR
|
||||
rejection_exception = str(e)
|
||||
except RequestException as e:
|
||||
if e.response is not None:
|
||||
status = e.response.status_code
|
||||
status_str = f"{status} - {HTTPStatus(status).phrase}"
|
||||
logger.info(
|
||||
"An HTTP error occurred when requesting %s: %s",
|
||||
origin_url,
|
||||
status_str,
|
||||
)
|
||||
rejection_details = f"{HTTP_ERROR}: {status_str}"
|
||||
else:
|
||||
logger.info(
|
||||
"An HTTP error occurred when requesting %s.",
|
||||
origin_url,
|
||||
)
|
||||
rejection_details = f"{HTTP_ERROR}."
|
||||
rejection_exception = str(e)
|
||||
else:
|
||||
try:
|
||||
socket.getaddrinfo(parsed_url.netloc, port=None)
|
||||
except OSError as e:
|
||||
logger.info(
|
||||
"Host name %s could not be resolved.", parsed_url.netloc
|
||||
)
|
||||
rejection_details = HOSTNAME_ERROR
|
||||
rejection_exception = str(e)
|
||||
|
||||
if rejection_details is None:
|
||||
visit_type_check_url = globals().get(
|
||||
f"is_valid_{visit_type.split('-', 1)[0]}_url"
|
||||
)
|
||||
if visit_type_check_url:
|
||||
url_valid, rejection_exception = visit_type_check_url(origin_url)
|
||||
if not url_valid:
|
||||
rejection_details = VISIT_TYPE_ERROR[visit_type]
|
||||
else:
|
||||
rejection_details = (
|
||||
f"Visit type {visit_type} is not supported "
|
||||
"for bulk on-demand archival."
|
||||
)
|
||||
logger.info(
|
||||
"Visit type %s for origin URL %s is not supported",
|
||||
visit_type,
|
||||
origin_url,
|
||||
)
|
||||
|
||||
if rejection_details is None:
|
||||
yield ListedOrigin(
|
||||
lister_id=self.lister_obj.id,
|
||||
url=origin["origin_url"],
|
||||
visit_type=origin["visit_type"],
|
||||
extra_loader_arguments=(
|
||||
{"checksum_layout": "standard", "checksums": {}}
|
||||
if origin["visit_type"] == "tarball-directory"
|
||||
else {}
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.rejected_origins.add(
|
||||
RejectedOrigin(
|
||||
origin_url=origin_url,
|
||||
visit_type=visit_type,
|
||||
reason=rejection_details,
|
||||
exception=rejection_exception,
|
||||
)
|
||||
)
|
||||
# update scheduler state at each rejected origin to get feedback
|
||||
# using Web API before end of listing
|
||||
self.state.rejected_origins = list(self.rejected_origins)
|
||||
self.set_state_in_scheduler()
|
19
swh/lister/save_bulk/tasks.py
Normal file
19
swh/lister/save_bulk/tasks.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Copyright (C) 2024 The Software Heritage developers
|
||||
# See the AUTHORS file at the top-level directory of this distribution
|
||||
# License: GNU General Public License version 3, or any later version
|
||||
# See top-level LICENSE file for more information
|
||||
|
||||
from celery import shared_task
|
||||
|
||||
from swh.lister.save_bulk.lister import SaveBulkLister
|
||||
|
||||
|
||||
@shared_task(name=__name__ + ".SaveBulkListerTask")
|
||||
def list_save_bulk(**kwargs):
|
||||
"""Task for save-bulk lister"""
|
||||
return SaveBulkLister.from_configfile(**kwargs).run().dict()
|
||||
|
||||
|
||||
@shared_task(name=__name__ + ".ping")
|
||||
def _ping():
|
||||
return "OK"
|
0
swh/lister/save_bulk/tests/__init__.py
Normal file
0
swh/lister/save_bulk/tests/__init__.py
Normal file
263
swh/lister/save_bulk/tests/test_lister.py
Normal file
263
swh/lister/save_bulk/tests/test_lister.py
Normal file
|
@ -0,0 +1,263 @@
|
|||
# Copyright (C) 2024 The Software Heritage developers
|
||||
# See the AUTHORS file at the top-level directory of this distribution
|
||||
# License: GNU General Public License version 3, or any later version
|
||||
# See top-level LICENSE file for more information
|
||||
|
||||
from operator import attrgetter, itemgetter
|
||||
import re
|
||||
import string
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from swh.lister.pattern import ListerStats
|
||||
from swh.lister.save_bulk.lister import (
|
||||
CONNECTION_ERROR,
|
||||
HOSTNAME_ERROR,
|
||||
HTTP_ERROR,
|
||||
VISIT_TYPE_ERROR,
|
||||
RejectedOrigin,
|
||||
SaveBulkLister,
|
||||
SubmittedOrigin,
|
||||
is_valid_cvs_url,
|
||||
)
|
||||
|
||||
URL = "https://example.org/origins/list/"
|
||||
INSTANCE = "some-instance"
|
||||
|
||||
PER_PAGE = 2
|
||||
|
||||
SUBMITTED_ORIGINS = [
|
||||
SubmittedOrigin(origin_url=origin_url, visit_type=visit_type)
|
||||
for origin_url, visit_type in [
|
||||
("https://example.org/download/tarball.tar.gz", "tarball-directory"),
|
||||
("https://git.example.org/user/project.git", "git"),
|
||||
("https://svn.example.org/project/trunk", "svn"),
|
||||
("https://hg.example.org/projects/test", "hg"),
|
||||
("https://bzr.example.org/projects/test", "bzr"),
|
||||
("rsync://cvs.example.org/cvsroot/project/module", "cvs"),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def origins_list_requests_mock(requests_mock):
|
||||
nb_pages = len(SUBMITTED_ORIGINS) // PER_PAGE
|
||||
for i in range(nb_pages):
|
||||
requests_mock.get(
|
||||
f"{URL}?page={i+1}&per_page={PER_PAGE}",
|
||||
json=SUBMITTED_ORIGINS[i * PER_PAGE : (i + 1) * PER_PAGE],
|
||||
)
|
||||
requests_mock.get(
|
||||
f"{URL}?page={nb_pages+1}&per_page={PER_PAGE}",
|
||||
json=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"valid_cvs_url",
|
||||
[
|
||||
"rsync://cvs.example.org/project/module",
|
||||
"pserver://anonymous@cvs.example.org/project/module",
|
||||
],
|
||||
)
|
||||
def test_is_valid_cvs_url_success(valid_cvs_url):
|
||||
assert is_valid_cvs_url(valid_cvs_url) == (True, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_cvs_url",
|
||||
[
|
||||
"rsync://cvs.example.org/project",
|
||||
"pserver://anonymous@cvs.example.org/project",
|
||||
"pserver://cvs.example.org/project/module",
|
||||
"http://cvs.example.org/project/module",
|
||||
],
|
||||
)
|
||||
def test_is_valid_cvs_url_failure(invalid_cvs_url):
|
||||
err_msg_prefix = "The origin URL for the CVS repository is malformed"
|
||||
ret, err_msg = is_valid_cvs_url(invalid_cvs_url)
|
||||
assert not ret and err_msg.startswith(err_msg_prefix)
|
||||
|
||||
|
||||
def test_bulk_lister_valid_origins(swh_scheduler, requests_mock, mocker):
|
||||
requests_mock.head(re.compile(".*"), status_code=200)
|
||||
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [
|
||||
("125.25.14.15", 0)
|
||||
]
|
||||
for origin in SUBMITTED_ORIGINS:
|
||||
visit_type = origin["visit_type"].split("-", 1)[0]
|
||||
mocker.patch(
|
||||
f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url"
|
||||
).return_value = (True, None)
|
||||
|
||||
lister_bulk = SaveBulkLister(
|
||||
url=URL,
|
||||
instance=INSTANCE,
|
||||
scheduler=swh_scheduler,
|
||||
per_page=PER_PAGE,
|
||||
)
|
||||
stats = lister_bulk.run()
|
||||
|
||||
expected_nb_origins = len(SUBMITTED_ORIGINS)
|
||||
assert stats == ListerStats(
|
||||
pages=expected_nb_origins // PER_PAGE, origins=expected_nb_origins
|
||||
)
|
||||
|
||||
state = lister_bulk.get_state_from_scheduler()
|
||||
|
||||
assert sorted(
|
||||
[
|
||||
SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type)
|
||||
for origin in swh_scheduler.get_listed_origins(
|
||||
lister_bulk.lister_obj.id
|
||||
).results
|
||||
],
|
||||
key=itemgetter("visit_type"),
|
||||
) == sorted(SUBMITTED_ORIGINS, key=itemgetter("visit_type"))
|
||||
assert state.rejected_origins == []
|
||||
|
||||
|
||||
def test_bulk_lister_not_found_origins(swh_scheduler, requests_mock, mocker):
|
||||
requests_mock.head(re.compile(".*"), status_code=404)
|
||||
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = (
|
||||
OSError("Hostname not found")
|
||||
)
|
||||
|
||||
lister_bulk = SaveBulkLister(
|
||||
url=URL,
|
||||
instance=INSTANCE,
|
||||
scheduler=swh_scheduler,
|
||||
per_page=PER_PAGE,
|
||||
)
|
||||
stats = lister_bulk.run()
|
||||
|
||||
assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0)
|
||||
|
||||
state = lister_bulk.get_state_from_scheduler()
|
||||
|
||||
assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list(
|
||||
sorted(
|
||||
[
|
||||
RejectedOrigin(
|
||||
origin_url=o["origin_url"],
|
||||
visit_type=o["visit_type"],
|
||||
reason=(
|
||||
HTTP_ERROR + ": 404 - Not Found"
|
||||
if o["origin_url"].startswith("http")
|
||||
else HOSTNAME_ERROR
|
||||
),
|
||||
exception=(
|
||||
f"404 Client Error: None for url: {o['origin_url']}"
|
||||
if o["origin_url"].startswith("http")
|
||||
else "Hostname not found"
|
||||
),
|
||||
)
|
||||
for o in SUBMITTED_ORIGINS
|
||||
],
|
||||
key=attrgetter("origin_url"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_bulk_lister_connection_errors(swh_scheduler, requests_mock, mocker):
|
||||
requests_mock.head(
|
||||
re.compile(".*"),
|
||||
exc=requests.exceptions.ConnectionError("connection error"),
|
||||
)
|
||||
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = (
|
||||
OSError("Hostname not found")
|
||||
)
|
||||
|
||||
lister_bulk = SaveBulkLister(
|
||||
url=URL,
|
||||
instance=INSTANCE,
|
||||
scheduler=swh_scheduler,
|
||||
per_page=PER_PAGE,
|
||||
)
|
||||
stats = lister_bulk.run()
|
||||
|
||||
assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0)
|
||||
|
||||
state = lister_bulk.get_state_from_scheduler()
|
||||
|
||||
assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list(
|
||||
sorted(
|
||||
[
|
||||
RejectedOrigin(
|
||||
origin_url=o["origin_url"],
|
||||
visit_type=o["visit_type"],
|
||||
reason=(
|
||||
CONNECTION_ERROR
|
||||
if o["origin_url"].startswith("http")
|
||||
else HOSTNAME_ERROR
|
||||
),
|
||||
exception=(
|
||||
"connection error"
|
||||
if o["origin_url"].startswith("http")
|
||||
else "Hostname not found"
|
||||
),
|
||||
)
|
||||
for o in SUBMITTED_ORIGINS
|
||||
],
|
||||
key=attrgetter("origin_url"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_bulk_lister_invalid_origins(swh_scheduler, requests_mock, mocker):
|
||||
requests_mock.head(re.compile(".*"), status_code=200)
|
||||
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [
|
||||
("125.25.14.15", 0)
|
||||
]
|
||||
|
||||
exc_msg_template = string.Template(
|
||||
"error: the origin url does not target a public $visit_type repository."
|
||||
)
|
||||
for origin in SUBMITTED_ORIGINS:
|
||||
visit_type = origin["visit_type"].split("-", 1)[0]
|
||||
visit_type_check = mocker.patch(
|
||||
f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url"
|
||||
)
|
||||
if visit_type == "tarball":
|
||||
visit_type_check.return_value = (True, None)
|
||||
else:
|
||||
visit_type_check.return_value = (
|
||||
False,
|
||||
exc_msg_template.substitute(visit_type=visit_type),
|
||||
)
|
||||
|
||||
lister_bulk = SaveBulkLister(
|
||||
url=URL,
|
||||
instance=INSTANCE,
|
||||
scheduler=swh_scheduler,
|
||||
per_page=PER_PAGE,
|
||||
)
|
||||
stats = lister_bulk.run()
|
||||
|
||||
assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=1)
|
||||
|
||||
assert [
|
||||
SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type)
|
||||
for origin in swh_scheduler.get_listed_origins(
|
||||
lister_bulk.lister_obj.id
|
||||
).results
|
||||
] == [SUBMITTED_ORIGINS[0]]
|
||||
|
||||
state = lister_bulk.get_state_from_scheduler()
|
||||
|
||||
assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list(
|
||||
sorted(
|
||||
[
|
||||
RejectedOrigin(
|
||||
origin_url=o["origin_url"],
|
||||
visit_type=o["visit_type"],
|
||||
reason=VISIT_TYPE_ERROR[o["visit_type"]],
|
||||
exception=exc_msg_template.substitute(visit_type=o["visit_type"]),
|
||||
)
|
||||
for o in SUBMITTED_ORIGINS
|
||||
if o["visit_type"] != "tarball-directory"
|
||||
],
|
||||
key=attrgetter("origin_url"),
|
||||
)
|
||||
)
|
38
swh/lister/save_bulk/tests/test_tasks.py
Normal file
38
swh/lister/save_bulk/tests/test_tasks.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
# Copyright (C) 2024 The Software Heritage developers
|
||||
# See the AUTHORS file at the top-level directory of this distribution
|
||||
# License: GNU General Public License version 3, or any later version
|
||||
# See top-level LICENSE file for more information
|
||||
|
||||
from swh.lister.pattern import ListerStats
|
||||
|
||||
|
||||
def test_save_bulk_ping(swh_scheduler_celery_app, swh_scheduler_celery_worker):
|
||||
res = swh_scheduler_celery_app.send_task("swh.lister.save_bulk.tasks.ping")
|
||||
assert res
|
||||
res.wait()
|
||||
assert res.successful()
|
||||
assert res.result == "OK"
|
||||
|
||||
|
||||
def test_save_bulk_lister_task(
|
||||
swh_scheduler_celery_app, swh_scheduler_celery_worker, mocker
|
||||
):
|
||||
lister = mocker.patch("swh.lister.save_bulk.tasks.SaveBulkLister")
|
||||
lister.from_configfile.return_value = lister
|
||||
lister.run.return_value = ListerStats(pages=1, origins=2)
|
||||
|
||||
kwargs = dict(
|
||||
url="https://example.org/origins/list/",
|
||||
instance="some-instance",
|
||||
)
|
||||
|
||||
res = swh_scheduler_celery_app.send_task(
|
||||
"swh.lister.save_bulk.tasks.SaveBulkListerTask",
|
||||
kwargs=kwargs,
|
||||
)
|
||||
assert res
|
||||
res.wait()
|
||||
assert res.successful()
|
||||
|
||||
lister.from_configfile.assert_called_once_with(**kwargs)
|
||||
lister.run.assert_called_once()
|
|
@ -49,6 +49,10 @@ lister_args = {
|
|||
"stagit": {
|
||||
"url": "https://git.codemadness.org",
|
||||
},
|
||||
"save-bulk": {
|
||||
"url": "https://example.org/origins/list/",
|
||||
"instance": "example.org",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue