From 8d85b2e4e8d58278f4fb94ec6b056f62c66b7f06 Mon Sep 17 00:00:00 2001 From: Antoine Lambert Date: Thu, 29 Sep 2022 11:14:08 +0200 Subject: [PATCH] pattern: Ensure accurate origin counts returned by run method Previously, the run method was returning the total count of ListedOrigin objects sent to scheduler database. However, some listers can send multiple ListedOrigin objects for a given origin URL during the listing process, for instance when an origin is contained in multiple pages (e.g. gogs listing) or when the listing is gathering multiple versions of an origin spread across multiple pages (e.g. maven listing). This changes ensures an accurate count of listed origins by maintaining a set of origin URLs associated to the sent ListedOrigin objects. --- swh/lister/arch/tests/test_lister.py | 3 ++- swh/lister/conda/tests/test_lister.py | 2 +- swh/lister/pattern.py | 19 ++++++++++++------- swh/lister/tests/test_pattern.py | 17 +++++++++++++++++ 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/swh/lister/arch/tests/test_lister.py b/swh/lister/arch/tests/test_lister.py index daa8712..fa644d3 100644 --- a/swh/lister/arch/tests/test_lister.py +++ b/swh/lister/arch/tests/test_lister.py @@ -2,6 +2,7 @@ # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information + from swh.lister.arch.lister import ArchLister expected_origins = [ @@ -1371,7 +1372,7 @@ def test_arch_lister(datadir, requests_mock_datadir, swh_scheduler): res = lister.run() assert res.pages == 9 - assert res.origins == 12 + assert res.origins == 11 scheduler_origins = swh_scheduler.get_listed_origins(lister.lister_obj.id).results diff --git a/swh/lister/conda/tests/test_lister.py b/swh/lister/conda/tests/test_lister.py index 0a67ce3..244d61a 100644 --- a/swh/lister/conda/tests/test_lister.py +++ b/swh/lister/conda/tests/test_lister.py @@ -13,7 +13,7 @@ def test_conda_lister_free_channel(datadir, requests_mock_datadir, swh_scheduler res = lister.run() assert res.pages == 3 - assert res.origins == 14 + assert res.origins == 11 def test_conda_lister_conda_forge_channel( diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py index d188896..7492683 100644 --- a/swh/lister/pattern.py +++ b/swh/lister/pattern.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass import logging -from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, TypeVar +from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Set, TypeVar from urllib.parse import urlparse import requests @@ -128,6 +128,8 @@ class Lister(Generic[StateType, PageType]): {"User-Agent": USER_AGENT_TEMPLATE % self.LISTER_NAME} ) + self.recorded_origins: Set[str] = set() + @http_retry(before_sleep=before_sleep_log(logger, logging.WARNING)) def http_request(self, url: str, method="GET", **kwargs) -> requests.Response: @@ -154,12 +156,15 @@ class Lister(Generic[StateType, PageType]): """ full_stats = ListerStats() + self.recorded_origins = set() try: for page in self.get_pages(): full_stats.pages += 1 origins = self.get_origins_from_page(page) - full_stats.origins += self.send_origins(origins) + sent_origins = self.send_origins(origins) + self.recorded_origins.update(sent_origins) + full_stats.origins = len(self.recorded_origins) self.commit_page(page) finally: self.finalize() @@ -255,18 +260,18 @@ class Lister(Generic[StateType, PageType]): """ pass - def send_origins(self, origins: Iterable[model.ListedOrigin]) -> int: + def send_origins(self, origins: Iterable[model.ListedOrigin]) -> List[str]: """Record a list of :class:`model.ListedOrigin` in the scheduler. Returns: - the number of listed origins recorded in the scheduler + the list of origin URLs recorded in scheduler database """ - count = 0 + recorded_origins = [] for batch_origins in grouper(origins, n=1000): ret = self.scheduler.record_listed_origins(batch_origins) - count += len(ret) + recorded_origins += [origin.url for origin in ret] - return count + return recorded_origins @classmethod def from_config(cls, scheduler: Dict[str, Any], **config: Any): diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py index 192f8f7..554a8d1 100644 --- a/swh/lister/tests/test_pattern.py +++ b/swh/lister/tests/test_pattern.py @@ -198,3 +198,20 @@ def test_stateless_run(swh_scheduler): # And that all origins are stored check_listed_origins(swh_scheduler, lister, stored_lister) + + +class ListerWithSameOriginInMultiplePages(RunnableStatelessLister): + def get_pages(self) -> Iterator[PageType]: + for _ in range(2): + yield [{"url": "https://example.org/user/project"}] + + +def test_listed_origins_count(swh_scheduler): + lister = ListerWithSameOriginInMultiplePages( + scheduler=swh_scheduler, url="https://example.org", instance="example.org" + ) + + run_result = lister.run() + + assert run_result.pages == 2 + assert run_result.origins == 1