Reimplement the GitHub lister using the new pattern class

This replaces the recorded test data with synthetic, hand-generated API
responses, which lets us cover a few more cases when instantiating the lister.

This also expands test coverage to exercise the behavior on rate-limited
requests.
Nicolas Dandrimont 2020-12-09 18:15:28 +01:00
parent f1eabc5283
commit b63aa83b41
10 changed files with 710 additions and 13590 deletions
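
For orientation, this is the surface of the new pattern class that the lister
is being ported to, condensed into a minimal sketch from the hunks below. The
base class is swh.lister.pattern.Lister; the MyLister/MyListerState names and
the toy page contents are purely illustrative, not part of this diff.

    from dataclasses import asdict, dataclass
    from typing import Any, Dict, Iterator, List

    from swh.lister.pattern import Lister
    from swh.scheduler.model import ListedOrigin

    @dataclass
    class MyListerState:
        last_seen_id: int = 0  # persisted in the scheduler backend between runs

    class MyLister(Lister[MyListerState, List[Dict[str, Any]]]):
        LISTER_NAME = "my-lister"

        def state_from_dict(self, d: Dict[str, Any]) -> MyListerState:
            return MyListerState(**d)

        def state_to_dict(self, state: MyListerState) -> Dict[str, Any]:
            return asdict(state)

        def get_pages(self) -> Iterator[List[Dict[str, Any]]]:
            # one toy page; a real lister would iterate over API pages here
            yield [{"id": 1, "html_url": "https://example.com/repo"}]

        def get_origins_from_page(
            self, page: List[Dict[str, Any]]
        ) -> Iterator[ListedOrigin]:
            assert self.lister_obj.id is not None
            for repo in page:
                yield ListedOrigin(
                    lister_id=self.lister_obj.id,
                    url=repo["html_url"],
                    visit_type="git",
                )

        def commit_page(self, page: List[Dict[str, Any]]) -> None:
            self.state.last_seen_id = max(r["id"] for r in page)

        def finalize(self) -> None:
            self.updated = True  # ask the scheduler backend to store self.state

The base class drives these hooks from run(), which returns a
ListerStats(pages=..., origins=...) summary.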

swh/lister/github/__init__.py

@@ -5,10 +5,9 @@
def register():
from .lister import GitHubLister
from .models import GitHubModel
return {
"models": [GitHubModel],
"models": [],
"lister": GitHubLister,
"task_modules": ["%s.tasks" % __name__],
}

swh/lister/github/lister.py

@@ -1,74 +1,280 @@
# Copyright (C) 2017-2020 The Software Heritage developers
# Copyright (C) 2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import re
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import asdict, dataclass
import datetime
import logging
import random
import time
from typing import Any, Dict, Iterator, List, Optional
from urllib.parse import parse_qs, urlparse
from requests import Response
import iso8601
import requests
from swh.lister.core.indexing_lister import IndexingHttpLister
from swh.lister.github.models import GitHubModel
from swh.scheduler.interface import SchedulerInterface
from swh.scheduler.model import ListedOrigin
from .. import USER_AGENT
from ..pattern import CredentialsType, Lister
logger = logging.getLogger(__name__)
class GitHubLister(IndexingHttpLister):
PATH_TEMPLATE = "/repositories?since=%d"
MODEL = GitHubModel
DEFAULT_URL = "https://api.github.com"
API_URL_INDEX_RE = re.compile(r"^.*/repositories\?since=(\d+)")
@dataclass
class GitHubListerState:
"""State of the GitHub lister"""
last_seen_id: int = 0
"""Numeric id of the last repository listed on an incremental pass"""
class GitHubLister(Lister[GitHubListerState, List[Dict[str, Any]]]):
"""List origins from GitHub.
By default, the lister runs in incremental mode: it lists all repositories,
starting with the `last_seen_id` stored in the scheduler backend.
Providing the `first_id` and `last_id` arguments enables the "relisting" mode: in
that mode, the lister finds the origins present in the range **excluding**
`first_id` and **including** `last_id`. In this mode, the lister can overrun the
`last_id`: it will always record all the origins seen in a given page. As the lister
is fully idempotent, this is not a practical problem. Once relisting completes, the
lister state in the scheduler backend is not updated.
When the config contains a set of credentials, we shuffle this list at the beginning
of the listing. To follow GitHub's `abuse rate limit policy`_, we keep using the
same token over and over again, until its rate limit runs out. Once that happens, we
switch to the next token over in our shuffled list.
When a request fails with a rate limit exception for all tokens, we pause the
listing until the largest X-Ratelimit-Reset value across all tokens has passed.
When the credentials aren't set in the lister config, the lister can run in
anonymous mode too (e.g. for testing purposes).
.. _abuse rate limit policy: https://developer.github.com/v3/guides/best-practices-for-integrators/#dealing-with-abuse-rate-limits
Args:
first_id: the id of the first repo to list
last_id: stop listing after seeing a repo with an id higher than this value.
""" # noqa: E501
LISTER_NAME = "github"
instance = "github" # There is only 1 instance of such lister
default_min_bound = 0 # type: Any
def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]:
return {
"uid": repo["id"],
"indexable": repo["id"],
"name": repo["name"],
"full_name": repo["full_name"],
"html_url": repo["html_url"],
"origin_url": repo["html_url"],
"origin_type": "git",
"fork": repo["fork"],
}
API_URL = "https://api.github.com/repositories"
PAGE_SIZE = 1000
def transport_quota_check(self, response: Response) -> Tuple[bool, int]:
x_rate_limit_remaining = response.headers.get("X-RateLimit-Remaining")
if not x_rate_limit_remaining:
return False, 0
reqs_remaining = int(x_rate_limit_remaining)
if response.status_code == 403 and reqs_remaining == 0:
delay = int(response.headers["Retry-After"])
return True, delay
return False, 0
def __init__(
self,
scheduler: SchedulerInterface,
credentials: CredentialsType = None,
first_id: Optional[int] = None,
last_id: Optional[int] = None,
):
super().__init__(
scheduler=scheduler,
credentials=credentials,
url=self.API_URL,
instance="github",
)
def get_next_target_from_response(self, response: Response) -> Optional[int]:
if "next" in response.links:
self.first_id = first_id
self.last_id = last_id
self.relisting = self.first_id is not None or self.last_id is not None
self.session = requests.Session()
self.session.headers.update(
{"Accept": "application/vnd.github.v3+json", "User-Agent": USER_AGENT}
)
random.shuffle(self.credentials)
self.anonymous = not self.credentials
if self.anonymous:
logger.warning("No tokens set in configuration, using anonymous mode")
self.token_index = -1
self.current_user: Optional[str] = None
if not self.anonymous:
# Initialize the first token value in the session headers
self.set_next_session_token()
def set_next_session_token(self) -> None:
"""Update the current authentication token with the next one in line."""
self.token_index = (self.token_index + 1) % len(self.credentials)
auth = self.credentials[self.token_index]
if "password" in auth:
token = auth["password"]
else:
token = auth["token"]
self.current_user = auth["username"]
logger.debug("Using authentication token for user %s", self.current_user)
self.session.headers.update({"Authorization": f"token {token}"})
def state_from_dict(self, d: Dict[str, Any]) -> GitHubListerState:
return GitHubListerState(**d)
def state_to_dict(self, state: GitHubListerState) -> Dict[str, Any]:
return asdict(state)
def get_pages(self) -> Iterator[List[Dict[str, Any]]]:
current_id = 0
if self.first_id is not None:
current_id = self.first_id
elif self.state is not None:
current_id = self.state.last_seen_id
current_url = f"{self.API_URL}?since={current_id}&per_page={self.PAGE_SIZE}"
while self.last_id is None or current_id < self.last_id:
logger.debug("Getting page %s", current_url)
# The following for/else loop handles rate limiting; if successful,
# it provides the rest of the function with a `response` object.
#
# If all tokens are rate-limited, we sleep until the reset time,
# then `continue` into another iteration of the outer while loop,
# attempting to get data from the same URL again.
max_attempts = 1 if self.anonymous else len(self.credentials)
reset_times: Dict[int, int] = {} # token index -> time
for attempt in range(max_attempts):
response = self.session.get(current_url)
if not (
# GitHub returns inconsistent status codes for unauthenticated and
# authenticated rate-limit responses. Handle both.
response.status_code == 429
or (self.anonymous and response.status_code == 403)
):
# Not rate limited, exit this loop.
break
ratelimit_reset = response.headers.get("X-Ratelimit-Reset")
if ratelimit_reset is None:
logger.warning(
"Rate-limit reached and X-Ratelimit-Reset value not found. "
"Response content: %s",
response.content,
)
else:
reset_times[self.token_index] = int(ratelimit_reset)
if not self.anonymous:
logger.info(
"Rate limit exhausted for current user %s (resetting at %s)",
self.current_user,
ratelimit_reset,
)
# Use next token in line
self.set_next_session_token()
# Wait one second to avoid triggering GitHub's abuse rate limits.
time.sleep(1)
else:
# All tokens have been rate-limited. What do we do?
if not reset_times:
logger.warning(
"No X-Ratelimit-Reset value found in responses for any token; "
"Giving up."
)
break
sleep_time = max(reset_times.values()) - time.time() + 1
logger.info(
"Rate limits exhausted for all tokens. Sleeping for %f seconds.",
sleep_time,
)
time.sleep(sleep_time)
# This goes back to the outer page-by-page loop, doing one more
# iteration on the same page
continue
# We've successfully retrieved a (non-ratelimited) `response`. We
# still need to check it for validity.
if response.status_code != 200:
logger.warning(
"Got unexpected status_code %s: %s",
response.status_code,
response.content,
)
break
yield response.json()
if "next" not in response.links:
# No `next` link, we've reached the end of the world
logger.debug(
"No next link found in the response headers, all caught up"
)
break
# GitHub strongly advises to use the next link directly. We still
# parse it to get the id of the last repository we've reached so
# far.
next_url = response.links["next"]["url"]
return int(self.API_URL_INDEX_RE.match(next_url).group(1)) # type: ignore
return None
parsed_url = urlparse(next_url)
if not parsed_url.query:
logger.warning("Failed to parse url %s", next_url)
break
def transport_response_simplified(self, response: Response) -> List[Dict[str, Any]]:
repos = response.json()
return [
self.get_model_from_repo(repo) for repo in repos if repo and "id" in repo
]
parsed_query = parse_qs(parsed_url.query)
current_id = int(parsed_query["since"][0])
current_url = next_url
def request_headers(self) -> Dict[str, Any]:
"""(Override) Set requests headers to send when querying the GitHub API
def get_origins_from_page(
self, page: List[Dict[str, Any]]
) -> Iterator[ListedOrigin]:
"""Convert a page of GitHub repositories into a list of ListedOrigins.
This records the html_url, as well as the pushed_at value if it exists.
"""
headers = super().request_headers()
headers["Accept"] = "application/vnd.github.v3+json"
return headers
assert self.lister_obj.id is not None
def disable_deleted_repo_tasks(self, index: int, next_index: int, keep_these: int):
""" (Overrides) Fix provided index value to avoid erroneously disabling
some scheduler tasks
"""
# Next listed repository ids are strictly greater than the 'since'
# parameter, so increment the index to avoid disabling the latest
# created task when processing a new repositories page returned by
# the Github API
return super().disable_deleted_repo_tasks(index + 1, next_index, keep_these)
for repo in page:
pushed_at_str = repo.get("pushed_at")
pushed_at: Optional[datetime.datetime] = None
if pushed_at_str:
pushed_at = iso8601.parse_date(pushed_at_str)
yield ListedOrigin(
lister_id=self.lister_obj.id,
url=repo["html_url"],
visit_type="git",
last_update=pushed_at,
)
def commit_page(self, page: List[Dict[str, Any]]):
"""Update the currently stored state using the latest listed page"""
if self.relisting:
# Don't update internal state when relisting
return
last_id = page[-1]["id"]
if last_id > self.state.last_seen_id:
self.state.last_seen_id = last_id
def finalize(self):
if self.relisting:
return
# Pull fresh lister state from the scheduler backend
scheduler_state = self.get_state_from_scheduler()
# Update the lister state in the backend only if the last seen id of
# the current run is higher than that stored in the database.
if self.state.last_seen_id > scheduler_state.last_seen_id:
self.updated = True
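
Putting the new lister.py together, here is a minimal usage sketch of the two
modes described in the class docstring. Only GitHubLister, its constructor
arguments, and run() come from this diff; the scheduler setup is an assumed
example (any SchedulerInterface implementation works).

    from swh.lister.github.lister import GitHubLister
    from swh.scheduler import get_scheduler

    # Hypothetical scheduler backend, for illustration only.
    scheduler = get_scheduler(cls="remote", url="http://localhost:5008/")

    # Incremental mode: starts from the last_seen_id stored in the backend,
    # and updates it once the run finishes.
    stats = GitHubLister(scheduler=scheduler).run()
    print(f"listed {stats.origins} origins over {stats.pages} pages")

    # Relisting mode: lists ids in (10, 1011] and leaves the stored state alone.
    GitHubLister(scheduler=scheduler, first_id=10, last_id=1011).run()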

swh/lister/github/models.py

@@ -1,17 +0,0 @@
# Copyright (C) 2017-2019 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from sqlalchemy import Boolean, Column, Integer
from swh.lister.core.models import IndexingModelBase
class GitHubModel(IndexingModelBase):
"""a GitHub repository"""
__tablename__ = "github_repo"
uid = Column(Integer, primary_key=True)
indexable = Column(Integer, index=True)
fork = Column(Boolean, default=False)

swh/lister/github/tasks.py

@@ -3,42 +3,46 @@
# See top-level LICENSE file for more information
import random
from typing import Dict, Optional
from celery import group, shared_task
from swh.lister.github.lister import GitHubLister
GROUP_SPLIT = 10000
GROUP_SPLIT = 100000
@shared_task(name=__name__ + ".IncrementalGitHubLister")
def list_github_incremental(**lister_args):
def list_github_incremental() -> Dict[str, int]:
"Incremental update of GitHub"
lister = GitHubLister(**lister_args)
return lister.run(min_bound=lister.db_last_index(), max_bound=None)
lister = GitHubLister.from_configfile()
return lister.run().dict()
@shared_task(name=__name__ + ".RangeGitHubLister")
def _range_github_lister(start, end, **lister_args):
lister = GitHubLister(**lister_args)
return lister.run(min_bound=start, max_bound=end)
def _range_github_lister(first_id: int, last_id: int) -> Dict[str, int]:
lister = GitHubLister.from_configfile(first_id=first_id, last_id=last_id)
return lister.run().dict()
@shared_task(name=__name__ + ".FullGitHubRelister", bind=True)
def list_github_full(self, split=None, **lister_args):
def list_github_full(self, split: Optional[int] = None) -> str:
"""Full update of GitHub
It's not to be called for an initial listing.
"""
lister = GitHubLister(**lister_args)
ranges = lister.db_partition_indices(split or GROUP_SPLIT)
if not ranges:
self.log.info("Nothing to list")
return
lister = GitHubLister.from_configfile()
last_index = lister.state.last_seen_id
bounds = list(range(0, last_index + 1, split or GROUP_SPLIT))
if bounds[-1] != last_index:
bounds.append(last_index)
ranges = list(zip(bounds[:-1], bounds[1:]))
random.shuffle(ranges)
promise = group(
_range_github_lister.s(minv, maxv, **lister_args) for minv, maxv in ranges
_range_github_lister.s(first_id=minv, last_id=maxv) for minv, maxv in ranges
)()
self.log.debug("%s OK (spawned %s subtasks)" % (self.name, len(ranges)))
try:
@@ -50,5 +54,5 @@ def list_github_full(self, split=None, **lister_args):
@shared_task(name=__name__ + ".ping")
def _ping():
def _ping() -> str:
return "OK"

swh/lister/github/tests/data/https_api.github.com/ (deleted test fixture)

@@ -1 +0,0 @@
repositories,since=0

swh/lister/github/tests/test_lister.py

@@ -1,81 +1,417 @@
# Copyright (C) 2017-2020 The Software Heritage developers
# Copyright (C) 2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import re
import unittest
import datetime
import logging
from typing import Any, Dict, Iterator, List, Optional, Union
import pytest
import requests_mock
from swh.lister.core.tests.test_lister import HttpListerTester
from swh.lister.github.lister import GitHubLister
from swh.lister.github.lister import GitHubLister, time
from swh.lister.pattern import CredentialsType, ListerStats
from swh.scheduler.interface import SchedulerInterface
from swh.scheduler.model import Lister
NUM_PAGES = 10
ORIGIN_COUNT = GitHubLister.PAGE_SIZE * NUM_PAGES
class GitHubListerTester(HttpListerTester, unittest.TestCase):
Lister = GitHubLister
test_re = re.compile(r"/repositories\?since=([^?&]+)")
lister_subdir = "github"
good_api_response_file = "data/https_api.github.com/first_response.json"
bad_api_response_file = "data/https_api.github.com/empty_response.json"
first_index = 0
last_index = 369
entries_per_page = 100
convert_type = int
def github_repo(i: int) -> Dict[str, Union[int, str]]:
"""Basic repository information returned by the GitHub API"""
def response_headers(self, request):
headers = {"X-RateLimit-Remaining": "1"}
if self.request_index(request) == self.first_index:
headers.update(
{
"Link": "<https://api.github.com/repositories?since=%s>;"
' rel="next",'
"<https://api.github.com/repositories{?since}>;"
' rel="first"' % self.last_index
}
)
else:
headers.update(
{
"Link": "<https://api.github.com/repositories{?since}>;"
' rel="first"'
}
)
return headers
repo: Dict[str, Union[int, str]] = {
"id": i,
"html_url": f"https://github.com/origin/{i}",
}
def mock_rate_quota(self, n, request, context):
self.rate_limit += 1
context.status_code = 403
context.headers["X-RateLimit-Remaining"] = "0"
context.headers["Retry-After"] = "1" # 1 second
return '{"error":"dummy"}'
# Set the pushed_at date on one of the origins
if i == 4321:
repo["pushed_at"] = "2018-11-08T13:16:24Z"
@requests_mock.Mocker()
def test_scheduled_tasks(self, http_mocker):
self.scheduled_tasks_test(
"data/https_api.github.com/next_response.json", 876, http_mocker
return repo
def github_response_callback(
request: requests_mock.request._RequestObjectProxy,
context: requests_mock.response._Context,
) -> List[Dict[str, Union[str, int]]]:
"""Return minimal GitHub API responses for the common case where the loader
hasn't been rate-limited"""
# Check request headers
assert request.headers["Accept"] == "application/vnd.github.v3+json"
assert "Software Heritage Lister" in request.headers["User-Agent"]
# Check request parameters: per_page == 1000, since = last_repo_id
assert "per_page" in request.qs
assert request.qs["per_page"] == [str(GitHubLister.PAGE_SIZE)]
assert "since" in request.qs
since = int(request.qs["since"][0])
next_page = since + GitHubLister.PAGE_SIZE
if next_page < ORIGIN_COUNT:
# the first id for the next page is within our origin count; add a Link
# header to the response
next_url = (
GitHubLister.API_URL
+ f"?per_page={GitHubLister.PAGE_SIZE}&since={next_page}"
)
context.headers["Link"] = f"<{next_url}>; rel=next"
return [github_repo(i) for i in range(since + 1, min(next_page, ORIGIN_COUNT) + 1)]
def test_lister_github(lister_github, requests_mock_datadir):
"""Simple github listing should create scheduled tasks
@pytest.fixture()
def requests_mocker() -> Iterator[requests_mock.Mocker]:
with requests_mock.Mocker() as mock:
mock.get(GitHubLister.API_URL, json=github_response_callback)
yield mock
def get_lister_data(swh_scheduler: SchedulerInterface) -> Lister:
"""Retrieve the data for the GitHub Lister"""
return swh_scheduler.get_or_create_lister(name="github", instance_name="github")
def set_lister_state(swh_scheduler: SchedulerInterface, state: Dict[str, Any]) -> None:
"""Set the state of the lister in database"""
lister = swh_scheduler.get_or_create_lister(name="github", instance_name="github")
lister.current_state = state
swh_scheduler.update_lister(lister)
def check_origin_4321(swh_scheduler: SchedulerInterface, lister: Lister) -> None:
"""Check that origin 4321 exists and has the proper last_update timestamp"""
origin_4321_req = swh_scheduler.get_listed_origins(
url="https://github.com/origin/4321"
)
assert len(origin_4321_req.origins) == 1
origin_4321 = origin_4321_req.origins[0]
assert origin_4321.lister_id == lister.id
assert origin_4321.visit_type == "git"
assert origin_4321.last_update == datetime.datetime(
2018, 11, 8, 13, 16, 24, tzinfo=datetime.timezone.utc
)
def check_origin_5555(swh_scheduler: SchedulerInterface, lister: Lister) -> None:
"""Check that origin 5555 exists and has no last_update timestamp"""
origin_5555_req = swh_scheduler.get_listed_origins(
url="https://github.com/origin/5555"
)
assert len(origin_5555_req.origins) == 1
origin_5555 = origin_5555_req.origins[0]
assert origin_5555.lister_id == lister.id
assert origin_5555.visit_type == "git"
assert origin_5555.last_update is None
def test_from_empty_state(
swh_scheduler, caplog, requests_mocker: requests_mock.Mocker
) -> None:
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
# Run the lister in incremental mode
lister = GitHubLister(scheduler=swh_scheduler)
res = lister.run()
assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT)
listed_origins = swh_scheduler.get_listed_origins(limit=ORIGIN_COUNT + 1)
assert len(listed_origins.origins) == ORIGIN_COUNT
assert listed_origins.next_page_token is None
lister_data = get_lister_data(swh_scheduler)
assert lister_data.current_state == {"last_seen_id": ORIGIN_COUNT}
check_origin_4321(swh_scheduler, lister_data)
check_origin_5555(swh_scheduler, lister_data)
def test_incremental(swh_scheduler, caplog, requests_mocker) -> None:
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
# Number of origins to skip
skip_origins = 2000
expected_origins = ORIGIN_COUNT - skip_origins
# Bump the last_seen_id in the scheduler backend
set_lister_state(swh_scheduler, {"last_seen_id": skip_origins})
# Run the lister in incremental mode
lister = GitHubLister(scheduler=swh_scheduler)
res = lister.run()
# add 1 page to the number of full_pages if partial_page_len is not 0
full_pages, partial_page_len = divmod(expected_origins, GitHubLister.PAGE_SIZE)
expected_pages = full_pages + bool(partial_page_len)
assert res == ListerStats(pages=expected_pages, origins=expected_origins)
listed_origins = swh_scheduler.get_listed_origins(limit=expected_origins + 1)
assert len(listed_origins.origins) == expected_origins
assert listed_origins.next_page_token is None
lister_data = get_lister_data(swh_scheduler)
assert lister_data.current_state == {"last_seen_id": ORIGIN_COUNT}
check_origin_4321(swh_scheduler, lister_data)
check_origin_5555(swh_scheduler, lister_data)
def test_relister(swh_scheduler, caplog, requests_mocker) -> None:
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
# Only set this state as a canary: in the currently tested mode, the lister
# should not be touching it.
set_lister_state(swh_scheduler, {"last_seen_id": 123})
# Use "relisting" mode to list origins between id 10 and 1011
lister = GitHubLister(scheduler=swh_scheduler, first_id=10, last_id=1011)
res = lister.run()
# Make sure we got two full pages of results
assert res == ListerStats(pages=2, origins=2000)
# Check that the relisting mode hasn't touched the stored state.
lister_data = get_lister_data(swh_scheduler)
assert lister_data.current_state == {"last_seen_id": 123}
def github_ratelimit_callback(
request: requests_mock.request._RequestObjectProxy,
context: requests_mock.response._Context,
ratelimit_reset: Optional[int],
) -> Dict[str, str]:
"""Return a rate-limited GitHub API response."""
# Check request headers
assert request.headers["Accept"] == "application/vnd.github.v3+json"
assert "Software Heritage Lister" in request.headers["User-Agent"]
if "Authorization" in request.headers:
context.status_code = 429
else:
context.status_code = 403
if ratelimit_reset is not None:
context.headers["X-Ratelimit-Reset"] = str(ratelimit_reset)
return {
"message": "API rate limit exceeded for <IP>.",
"documentation_url": "https://developer.github.com/v3/#rate-limiting",
}
@pytest.fixture()
def num_before_ratelimit() -> int:
"""Number of successful requests before the ratelimit hits"""
return 0
@pytest.fixture()
def num_ratelimit() -> Optional[int]:
"""Number of rate-limited requests; None means infinity"""
return None
@pytest.fixture()
def ratelimit_reset() -> Optional[int]:
"""Value of the X-Ratelimit-Reset header on ratelimited responses"""
return None
@pytest.fixture()
def requests_ratelimited(
num_before_ratelimit: int,
num_ratelimit: Optional[int],
ratelimit_reset: Optional[int],
) -> Iterator[requests_mock.Mocker]:
"""Mock requests to the GitHub API, returning a rate-limiting status code
after `num_before_ratelimit` requests.
GitHub does inconsistent rate-limiting:
- Anonymous requests return a 403 status code
- Authenticated requests return a 429 status code, with an
X-Ratelimit-Reset header.
This fixture takes multiple arguments (which can be overridden with a
:func:`pytest.mark.parametrize` parameter):
- num_before_ratelimit: the global number of requests until the
ratelimit triggers
- num_ratelimit: the number of requests that return a
rate-limited response.
- ratelimit_reset: the timestamp returned in X-Ratelimit-Reset if the
request is authenticated.
The default values set in the previous fixtures make all requests return a rate
limit response.
"""
lister_github.run()
current_request = 0
r = lister_github.scheduler.search_tasks(task_type="load-git")
assert len(r) == 100
def response_callback(request, context):
nonlocal current_request
current_request += 1
if num_before_ratelimit < current_request and (
num_ratelimit is None
or current_request < num_before_ratelimit + num_ratelimit + 1
):
return github_ratelimit_callback(request, context, ratelimit_reset)
else:
return github_response_callback(request, context)
for row in r:
assert row["type"] == "load-git"
# arguments check
args = row["arguments"]["args"]
assert len(args) == 0
with requests_mock.Mocker() as mock:
mock.get(GitHubLister.API_URL, json=response_callback)
yield mock
# kwargs
kwargs = row["arguments"]["kwargs"]
url = kwargs["url"]
assert url.startswith("https://github.com")
assert row["policy"] == "recurring"
assert row["priority"] is None
def test_anonymous_ratelimit(swh_scheduler, caplog, requests_ratelimited) -> None:
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
lister = GitHubLister(scheduler=swh_scheduler)
assert lister.anonymous
assert "using anonymous mode" in caplog.records[-1].message
caplog.clear()
res = lister.run()
assert res == ListerStats(pages=0, origins=0)
last_log = caplog.records[-1]
assert last_log.levelname == "WARNING"
assert "No X-Ratelimit-Reset value found in responses" in last_log.message
@pytest.fixture
def github_credentials() -> List[Dict[str, str]]:
"""Return a static list of GitHub credentials"""
return sorted(
[{"username": f"swh{i:d}", "token": f"token-{i:d}"} for i in range(3)]
+ [
{"username": f"swh-legacy{i:d}", "password": f"token-legacy-{i:d}"}
for i in range(3)
],
key=lambda c: c["username"],
)
@pytest.fixture
def all_tokens(github_credentials) -> List[str]:
"""Return the list of tokens matching the static credential"""
return [t.get("token", t.get("password")) for t in github_credentials]
@pytest.fixture
def lister_credentials(github_credentials: List[Dict[str, str]]) -> CredentialsType:
"""Return the credentials formatted for use by the lister"""
return {"github": {"github": github_credentials}}
def test_authenticated_credentials(
swh_scheduler, caplog, github_credentials, lister_credentials, all_tokens
):
"""Test credentials management when the lister is authenticated"""
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials)
assert lister.token_index == 0
assert sorted(lister.credentials, key=lambda t: t["username"]) == github_credentials
assert lister.session.headers["Authorization"] in [
"token %s" % t for t in all_tokens
]
def fake_time_sleep(duration: float, sleep_calls: Optional[List[float]] = None):
"""Record calls to time.sleep in the sleep_calls list"""
if duration < 0:
raise ValueError("Can't sleep for a negative amount of time!")
if sleep_calls is not None:
sleep_calls.append(duration)
def fake_time_time():
"""Return 0 when running time.time()"""
return 0
@pytest.fixture
def monkeypatch_sleep_calls(monkeypatch) -> Iterator[List[float]]:
"""Monkeypatch `time.time` and `time.sleep`. Returns a list cumulating the arguments
passed to time.sleep()."""
sleeps: List[float] = []
monkeypatch.setattr(time, "sleep", lambda d: fake_time_sleep(d, sleeps))
monkeypatch.setattr(time, "time", fake_time_time)
yield sleeps
@pytest.mark.parametrize(
"num_ratelimit", [1]
) # return a single rate-limit response, then continue
def test_ratelimit_once_recovery(
swh_scheduler,
caplog,
requests_ratelimited,
num_ratelimit,
monkeypatch_sleep_calls,
lister_credentials,
):
"""Check that the lister recovers from hitting the rate-limit once"""
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials)
res = lister.run()
# check that we used all the pages
assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT)
token_users = []
for record in caplog.records:
if "Using authentication token" in record.message:
token_users.append(record.args[0])
# check that we used one more token than we saw rate limited requests
assert len(token_users) == 1 + num_ratelimit
# check that we slept for one second between our token uses
assert monkeypatch_sleep_calls == [1]
@pytest.mark.parametrize(
# Do 5 successful requests, return 6 ratelimits (to exhaust the credentials) with a
# set value for X-Ratelimit-Reset, then resume listing successfully.
"num_before_ratelimit, num_ratelimit, ratelimit_reset",
[(5, 6, 123456)],
)
def test_ratelimit_reset_sleep(
swh_scheduler,
caplog,
requests_ratelimited,
monkeypatch_sleep_calls,
num_before_ratelimit,
ratelimit_reset,
github_credentials,
lister_credentials,
):
"""Check that the lister properly handles rate-limiting when providing it with
authentication tokens"""
caplog.set_level(logging.DEBUG, "swh.lister.github.lister")
lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials)
res = lister.run()
assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT)
# We sleep 1 second every time we change credentials, then we sleep until
# ratelimit_reset + 1
expected_sleep_calls = len(github_credentials) * [1] + [ratelimit_reset + 1]
assert monkeypatch_sleep_calls == expected_sleep_calls
found_exhaustion_message = False
for record in caplog.records:
if record.levelname == "INFO":
if "Rate limits exhausted for all tokens" in record.message:
found_exhaustion_message = True
break
assert found_exhaustion_message
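
The rate-limiting fixtures above are designed to be recombined through
pytest.mark.parametrize, as the requests_ratelimited docstring notes. A
hypothetical extra scenario in the same style (the test name and parameter
values are illustrative, not part of this diff): two successful pages, one
rate-limited response, then listing resumes on the next token.

    @pytest.mark.parametrize("num_before_ratelimit, num_ratelimit", [(2, 1)])
    def test_ratelimit_after_two_pages(
        swh_scheduler,
        requests_ratelimited,
        monkeypatch_sleep_calls,
        lister_credentials,
    ):
        """Hypothetical scenario: the rate limit hits on the third request only"""
        lister = GitHubLister(scheduler=swh_scheduler, credentials=lister_credentials)
        res = lister.run()
        # the run still lists everything: one token rotation absorbs the limit
        assert res == ListerStats(pages=NUM_PAGES, origins=ORIGIN_COUNT)
        # a single one-second pause, from the single token switch
        assert monkeypatch_sleep_calls == [1]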

swh/lister/github/tests/test_tasks.py

@@ -1,8 +1,11 @@
from time import sleep
from unittest.mock import patch
from unittest.mock import call, patch
from celery.result import GroupResult
from swh.lister.github.lister import GitHubListerState
from swh.lister.pattern import ListerStats
def test_ping(swh_scheduler_celery_app, swh_scheduler_celery_worker):
res = swh_scheduler_celery_app.send_task("swh.lister.github.tasks.ping")
@@ -15,9 +18,9 @@ def test_ping(swh_scheduler_celery_app, swh_scheduler_celery_worker):
@patch("swh.lister.github.tasks.GitHubLister")
def test_incremental(lister, swh_scheduler_celery_app, swh_scheduler_celery_worker):
# setup the mocked GitHubLister
lister.return_value = lister
lister.db_last_index.return_value = 42
lister.run.return_value = None
lister.from_configfile.return_value = lister
lister.state = GitHubListerState()
lister.run.return_value = ListerStats(pages=5, origins=5000)
res = swh_scheduler_celery_app.send_task(
"swh.lister.github.tasks.IncrementalGitHubLister"
@@ -26,35 +29,39 @@ def test_incremental(lister, swh_scheduler_celery_app, swh_scheduler_celery_work
res.wait()
assert res.successful()
lister.assert_called_once_with()
lister.db_last_index.assert_called_once_with()
lister.run.assert_called_once_with(min_bound=42, max_bound=None)
lister.from_configfile.assert_called_once_with()
@patch("swh.lister.github.tasks.GitHubLister")
def test_range(lister, swh_scheduler_celery_app, swh_scheduler_celery_worker):
# setup the mocked GitHubLister
lister.return_value = lister
lister.run.return_value = None
lister.from_configfile.return_value = lister
lister.run.return_value = ListerStats(pages=5, origins=5000)
res = swh_scheduler_celery_app.send_task(
"swh.lister.github.tasks.RangeGitHubLister", kwargs=dict(start=12, end=42)
"swh.lister.github.tasks.RangeGitHubLister",
kwargs=dict(first_id=12, last_id=42),
)
assert res
res.wait()
assert res.successful()
lister.assert_called_once_with()
lister.db_last_index.assert_not_called()
lister.run.assert_called_once_with(min_bound=12, max_bound=42)
lister.from_configfile.assert_called_once_with(first_id=12, last_id=42)
lister.run.assert_called_once_with()
@patch("swh.lister.github.tasks.GitHubLister")
def test_relister(lister, swh_scheduler_celery_app, swh_scheduler_celery_worker):
def test_lister_full(lister, swh_scheduler_celery_app, swh_scheduler_celery_worker):
last_index = 1000000
expected_bounds = list(range(0, last_index + 1, 100000))
if expected_bounds[-1] != last_index:
expected_bounds.append(last_index)
# setup the mocked GitHubLister
lister.return_value = lister
lister.run.return_value = None
lister.db_partition_indices.return_value = [(i, i + 9) for i in range(0, 50, 10)]
lister.state = GitHubListerState(last_seen_id=last_index)
lister.from_configfile.return_value = lister
lister.run.return_value = ListerStats(pages=10, origins=10000)
res = swh_scheduler_celery_app.send_task(
"swh.lister.github.tasks.FullGitHubRelister"
@@ -74,18 +81,13 @@ def test_relister(lister, swh_scheduler_celery_app, swh_scheduler_celery_worker)
break
sleep(1)
lister.assert_called_with()
# pulling the state out of the database
assert lister.from_configfile.call_args_list[0] == call()
# one by the FullGitHubRelister task
# + 5 for the RangeGitHubLister subtasks
assert lister.call_count == 6
lister.db_last_index.assert_not_called()
lister.db_partition_indices.assert_called_once_with(10000)
# lister.run should have been called once per partition interval
for i in range(5):
# XXX inconsistent behavior: max_bound is INCLUDED here
assert (
dict(min_bound=10 * i, max_bound=10 * i + 9),
) in lister.run.call_args_list
# Calls for each of the ranges
range_calls = lister.from_configfile.call_args_list[1:]
# Check exhaustivity of the range calls
assert sorted(range_calls, key=lambda c: c[1]["first_id"]) == [
call(first_id=f, last_id=l)
for f, l in zip(expected_bounds[:-1], expected_bounds[1:])
]