Add a flag to not enable origins listed by a lister

This cuts down one more manual step in the add forge now validation
process: we can add the relevant origins to the staging scheduler
without enabling them at all.
This commit is contained in:
Nicolas Dandrimont 2022-12-05 14:20:31 +01:00
parent b815737054
commit 64267f8f50
2 changed files with 43 additions and 0 deletions

View file

@ -10,6 +10,7 @@ import logging
from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Set, TypeVar
from urllib.parse import urlparse
import attr
import requests
from tenacity.before_sleep import before_sleep_log
@ -86,6 +87,7 @@ class Lister(Generic[StateType, PageType]):
expected credentials for the given instance of that lister.
max_pages: the maximum number of pages listed in a full listing operation
max_origins_per_page: the maximum number of origins processed per page
enable_origins: whether the created origins should be enabled or not
Generic types:
- *StateType*: concrete lister type; should usually be a :class:`dataclass` for
@ -106,6 +108,7 @@ class Lister(Generic[StateType, PageType]):
credentials: CredentialsType = None,
max_origins_per_page: Optional[int] = None,
max_pages: Optional[int] = None,
enable_origins: bool = True,
with_github_session: bool = False,
):
if not self.LISTER_NAME:
@ -146,6 +149,7 @@ class Lister(Generic[StateType, PageType]):
self.recorded_origins: Set[str] = set()
self.max_pages = max_pages
self.max_origins_per_page = max_origins_per_page
self.enable_origins = enable_origins
@http_retry(before_sleep=before_sleep_log(logger, logging.WARNING))
def http_request(self, url: str, method="GET", **kwargs) -> requests.Response:
@ -189,6 +193,11 @@ class Lister(Generic[StateType, PageType]):
self.max_origins_per_page,
)
origins = origins[: self.max_origins_per_page]
if not self.enable_origins:
logger.info(
"Disabling origins before sending them to the scheduler"
)
origins = [attr.evolve(origin, enabled=False) for origin in origins]
sent_origins = self.send_origins(origins)
self.recorded_origins.update(sent_origins)
full_stats.origins = len(self.recorded_origins)