Remove no longer used legacy Lister API and update CLI options

Legacy Lister classes from the swh.lister.core mdule are no longer
used in swh-lister codebase so it is time to remove them.

Also remove lister CLI options related to legacy Lister API.

As a consequence, the following requirements are no longer needed:
arrow, SQLAlchemy, sqlalchemy-stubs and testing.postgresql.

Closes T2442
This commit is contained in:
Antoine Lambert 2021-02-02 12:27:46 +01:00
parent ff05191b7d
commit 8933544521
18 changed files with 8 additions and 1730 deletions

View file

@ -1,10 +1,10 @@
# Copyright (C) 2020 The Software Heritage developers
# Copyright (C) 2020-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import os
pytest_plugins = ["swh.scheduler.pytest_plugin", "swh.lister.pytest_plugin"]
pytest_plugins = ["swh.scheduler.pytest_plugin"]
os.environ["LC_ALL"] = "C.UTF-8"

View file

@ -2,10 +2,6 @@
namespace_packages = True
warn_unused_ignores = True
# support for sqlalchemy magic: see https://github.com/dropbox/sqlalchemy-stubs
plugins = sqlmypy
# 3rd party libraries without stubs (yet)
[mypy-bs4.*]
@ -38,9 +34,6 @@ ignore_missing_imports = True
[mypy-requests_mock.*]
ignore_missing_imports = True
[mypy-testing.postgresql.*]
ignore_missing_imports = True
[mypy-urllib3.util.*]
ignore_missing_imports = True

View file

@ -1,5 +1,3 @@
pytest
pytest-mock
requests_mock
sqlalchemy-stubs
testing.postgresql

View file

@ -1,5 +1,3 @@
SQLAlchemy
arrow
python_debian
requests
setuptools

View file

@ -1,4 +1,4 @@
# Copyright (C) 2018-2020 The Software Heritage developers
# Copyright (C) 2018-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
@ -14,31 +14,11 @@ import click
from swh.core.cli import CONTEXT_SETTINGS
from swh.core.cli import swh as swh_cli_group
from swh.lister import LISTERS, SUPPORTED_LISTERS, get_lister
from swh.lister import SUPPORTED_LISTERS, get_lister
logger = logging.getLogger(__name__)
# the key in this dict is the suffix used to match new task-type to be added.
# For example for a task which function name is "list_gitlab_full', the default
# value used when inserting a new task-type in the scheduler db will be the one
# under the 'full' key below (because it matches xxx_full).
DEFAULT_TASK_TYPE = {
"full": { # for tasks like 'list_xxx_full()'
"default_interval": "90 days",
"min_interval": "90 days",
"max_interval": "90 days",
"backoff_factor": 1,
},
"*": { # value if not suffix matches
"default_interval": "1 day",
"min_interval": "1 day",
"max_interval": "1 day",
"backoff_factor": 1,
},
}
@swh_cli_group.group(name="lister", context_settings=CONTEXT_SETTINGS)
@click.option(
"--config-file",
@ -47,15 +27,8 @@ DEFAULT_TASK_TYPE = {
type=click.Path(exists=True, dir_okay=False,),
help="Configuration file.",
)
@click.option(
"--db-url",
"-d",
default=None,
help="SQLAlchemy DB URL; see "
"<http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>",
) # noqa
@click.pass_context
def lister(ctx, config_file, db_url):
def lister(ctx, config_file):
"""Software Heritage Lister tools."""
from swh.core import config
@ -64,53 +37,10 @@ def lister(ctx, config_file, db_url):
if not config_file:
config_file = os.environ.get("SWH_CONFIG_FILENAME")
conf = config.read(config_file)
if db_url:
conf["lister"] = {"cls": "local", "args": {"db": db_url}}
ctx.obj["config"] = conf
@lister.command(name="db-init", context_settings=CONTEXT_SETTINGS)
@click.option(
"--drop-tables",
"-D",
is_flag=True,
default=False,
help="Drop tables before creating the database schema",
)
@click.pass_context
def db_init(ctx, drop_tables):
"""Initialize the database model for given listers.
"""
from sqlalchemy import create_engine
from swh.lister.core.models import initialize
cfg = ctx.obj["config"]
lister_cfg = cfg["lister"]
if lister_cfg["cls"] != "local":
click.echo("A local lister configuration is required")
ctx.exit(1)
db_url = lister_cfg["args"]["db"]
db_engine = create_engine(db_url)
registry = {}
for lister, entrypoint in LISTERS.items():
logger.info("Loading lister %s", lister)
registry[lister] = entrypoint.load()()
logger.info("Initializing database")
initialize(db_engine, drop_tables)
for lister, entrypoint in LISTERS.items():
registry_entry = registry[lister]
init_hook = registry_entry.get("init")
if callable(init_hook):
logger.info("Calling init hook for %s", lister)
init_hook(db_engine)
@lister.command(
name="run",
context_settings=CONTEXT_SETTINGS,
@ -122,17 +52,9 @@ def db_init(ctx, drop_tables):
@click.option(
"--lister", "-l", help="Lister to run", type=click.Choice(SUPPORTED_LISTERS)
)
@click.option(
"--priority",
"-p",
default="high",
type=click.Choice(["high", "medium", "low"]),
help="Task priority for the listed repositories to ingest",
)
@click.option("--legacy", help="Allow unported lister to run with such flag")
@click.argument("options", nargs=-1)
@click.pass_context
def run(ctx, lister, priority, options, legacy):
def run(ctx, lister, options):
from swh.scheduler.cli.utils import parse_options
config = deepcopy(ctx.obj["config"])
@ -140,10 +62,6 @@ def run(ctx, lister, priority, options, legacy):
if options:
config.update(parse_options(options)[1])
if legacy:
config["priority"] = priority
config["policy"] = "oneshot"
get_lister(lister, **config).run()

View file

@ -1,28 +0,0 @@
# Copyright (C) 2017 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
class AbstractAttribute:
"""AbstractAttributes in a base class must be overridden by the subclass.
It's like the :func:`abc.abstractmethod` decorator, but for things that
are explicitly attributes/properties, not methods, without the need for
empty method def boilerplate. Like abc.abstractmethod, the class containing
AbstractAttributes must inherit from :class:`abc.ABC` or use the
:class:`abc.ABCMeta` metaclass.
Usage example::
import abc
class ClassContainingAnAbstractAttribute(abc.ABC):
foo: Union[AbstractAttribute, Any] = \
AbstractAttribute('docstring for foo')
"""
__isabstractmethod__ = True
def __init__(self, docstring=None):
if docstring is not None:
self.__doc__ = "AbstractAttribute: " + docstring

View file

@ -1,508 +0,0 @@
# Copyright (C) 2015-2020 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
import datetime
import gzip
import json
import logging
import os
import re
import time
from typing import Any, Dict, List, Optional, Type, Union
from requests import Response
from sqlalchemy import create_engine, func
from sqlalchemy.orm import sessionmaker
from swh.core import config
from swh.core.utils import grouper
from swh.scheduler import get_scheduler, utils
from .abstractattribute import AbstractAttribute
logger = logging.getLogger(__name__)
def utcnow():
return datetime.datetime.now(tz=datetime.timezone.utc)
class FetchError(RuntimeError):
def __init__(self, response):
self.response = response
def __str__(self):
return repr(self.response)
DEFAULT_CONFIG = {
"scheduler": {"cls": "memory"},
"lister": {"cls": "local", "args": {"db": "postgresql:///lister",},},
"credentials": {},
"cache_responses": False,
}
class ListerBase(abc.ABC):
"""Lister core base class.
Generally a source code hosting service provides an API endpoint
for listing the set of stored repositories. A Lister is the discovery
service responsible for finding this list, all at once or sequentially
by parts, and queueing local tasks to fetch and ingest the referenced
repositories.
The core method in this class is ingest_data. Any subclasses should be
calling this method one or more times to fetch and ingest data from API
endpoints. See swh.lister.core.lister_base.IndexingLister for
example usage.
This class cannot be instantiated. Any instantiable Lister descending
from ListerBase must provide at least the required overrides.
(see member docstrings for details):
Required Overrides:
MODEL
def transport_request
def transport_response_to_string
def transport_response_simplified
def transport_quota_check
Optional Overrides:
def filter_before_inject
def is_within_bounds
"""
MODEL = AbstractAttribute(
"Subclass type (not instance) of swh.lister.core.models.ModelBase "
"customized for a specific service."
) # type: Union[AbstractAttribute, Type[Any]]
LISTER_NAME = AbstractAttribute(
"Lister's name"
) # type: Union[AbstractAttribute, str]
def transport_request(self, identifier):
"""Given a target endpoint identifier to query, try once to request it.
Implementation of this method determines the network request protocol.
Args:
identifier (string): unique identifier for an endpoint query.
e.g. If the service indexes lists of repositories by date and
time of creation, this might be that as a formatted string. Or
it might be an integer UID. Or it might be nothing.
It depends on what the service needs.
Returns:
the entire request response
Raises:
Will catch internal transport-dependent connection exceptions and
raise swh.lister.core.lister_base.FetchError instead. Other
non-connection exceptions should propagate unchanged.
"""
pass
def transport_response_to_string(self, response):
"""Convert the server response into a formatted string for logging.
Implementation of this method depends on the shape of the network
response object returned by the transport_request method.
Args:
response: the server response
Returns:
a pretty string of the response
"""
pass
def transport_response_simplified(self, response):
"""Convert the server response into list of a dict for each repo in the
response, mapping columns in the lister's MODEL class to repo data.
Implementation of this method depends on the server API spec and the
shape of the network response object returned by the transport_request
method.
Args:
response: response object from the server.
Returns:
list of repo MODEL dicts
( eg. [{'uid': r['id'], etc.} for r in response.json()] )
"""
pass
def transport_quota_check(self, response):
"""Check server response to see if we're hitting request rate limits.
Implementation of this method depends on the server communication
protocol and API spec and the shape of the network response object
returned by the transport_request method.
Args:
response (session response): complete API query response
Returns:
1) must retry request? True/False
2) seconds to delay if True
"""
pass
def filter_before_inject(self, models_list: List[Dict]) -> List[Dict]:
"""Filter models_list entries prior to injection in the db.
This is ran directly after `transport_response_simplified`.
Default implementation is to have no filtering.
Args:
models_list: list of dicts returned by
transport_response_simplified.
Returns:
models_list with entries changed according to custom logic.
"""
return models_list
def do_additional_checks(self, models_list: List[Dict]) -> List[Dict]:
"""Execute some additional checks on the model list (after the
filtering).
Default implementation is to run no check at all and to return
the input as is.
Args:
models_list: list of dicts returned by
transport_response_simplified.
Returns:
models_list with entries if checks ok, False otherwise
"""
return models_list
def is_within_bounds(
self, inner: int, lower: Optional[int] = None, upper: Optional[int] = None
) -> bool:
"""See if a sortable value is inside the range [lower,upper].
MAY BE OVERRIDDEN, for example if the server indexable* key is
technically sortable but not automatically so.
* - ( see: swh.lister.core.indexing_lister.IndexingLister )
Args:
inner (sortable type): the value being checked
lower (sortable type): optional lower bound
upper (sortable type): optional upper bound
Returns:
whether inner is confined by the optional lower and upper bounds
"""
try:
if lower is None and upper is None:
return True
elif lower is None:
ret = inner <= upper # type: ignore
elif upper is None:
ret = inner >= lower
else:
ret = lower <= inner <= upper
self.string_pattern_check(inner, lower, upper)
except Exception as e:
logger.error(
str(e)
+ ": %s, %s, %s"
% (
("inner=%s%s" % (type(inner), inner)),
("lower=%s%s" % (type(lower), lower)),
("upper=%s%s" % (type(upper), upper)),
)
)
raise
return ret
# You probably don't need to override anything below this line.
INITIAL_BACKOFF = 10
MAX_RETRIES = 7
CONN_SLEEP = 10
def __init__(self, override_config=None):
self.backoff = self.INITIAL_BACKOFF
self.config = config.load_from_envvar(DEFAULT_CONFIG)
if self.config["cache_responses"]:
cache_dir = self.config.get(
"cache_dir", f"~/.cache/swh/lister/{self.LISTER_NAME}"
)
self.config["cache_dir"] = os.path.expanduser(cache_dir)
config.prepare_folders(self.config, "cache_dir")
if override_config:
self.config.update(override_config)
logger.debug("%s CONFIG=%s" % (self, self.config))
self.scheduler = get_scheduler(**self.config["scheduler"])
self.db_engine = create_engine(self.config["lister"]["args"]["db"])
self.mk_session = sessionmaker(bind=self.db_engine)
self.db_session = self.mk_session()
def reset_backoff(self):
"""Reset exponential backoff timeout to initial level."""
self.backoff = self.INITIAL_BACKOFF
def back_off(self) -> int:
"""Get next exponential backoff timeout."""
ret = self.backoff
self.backoff *= 10
return ret
def safely_issue_request(self, identifier: int) -> Optional[Response]:
"""Make network request with retries, rate quotas, and response logs.
Protocol is handled by the implementation of the transport_request
method.
Args:
identifier: resource identifier
Returns:
server response
"""
retries_left = self.MAX_RETRIES
do_cache = self.config["cache_responses"]
r = None
while retries_left > 0:
try:
r = self.transport_request(identifier)
except FetchError:
# network-level connection error, try again
logger.warning(
"connection error on %s: sleep for %d seconds"
% (identifier, self.CONN_SLEEP)
)
time.sleep(self.CONN_SLEEP)
retries_left -= 1
continue
if do_cache:
self.save_response(r)
# detect throttling
must_retry, delay = self.transport_quota_check(r)
if must_retry:
logger.warning(
"rate limited on %s: sleep for %f seconds" % (identifier, delay)
)
time.sleep(delay)
else: # request ok
break
retries_left -= 1
if not retries_left:
logger.warning("giving up on %s: max retries exceeded" % identifier)
return r
def db_query_equal(self, key: Any, value: Any):
"""Look in the db for a row with key == value
Args:
key: column key to look at
value: value to look for in that column
Returns:
sqlalchemy.ext.declarative.declarative_base object
with the given key == value
"""
if isinstance(key, str):
key = self.MODEL.__dict__[key]
return self.db_session.query(self.MODEL).filter(key == value).first()
def winnow_models(self, mlist, key, to_remove):
"""Given a list of models, remove any with <key> matching
some member of a list of values.
Args:
mlist (list of model rows): the initial list of models
key (column): the column to filter on
to_remove (list): if anything in mlist has column <key> equal to
one of the values in to_remove, it will be removed from the
result
Returns:
A list of model rows starting from mlist minus any matching rows
"""
if isinstance(key, str):
key = self.MODEL.__dict__[key]
if to_remove:
return mlist.filter(~key.in_(to_remove)).all()
else:
return mlist.all()
def db_num_entries(self):
"""Return the known number of entries in the lister db"""
return self.db_session.query(func.count("*")).select_from(self.MODEL).scalar()
def db_inject_repo(self, model_dict):
"""Add/update a new repo to the db and mark it last_seen now.
Args:
model_dict: dictionary mapping model keys to values
Returns:
new or updated sqlalchemy.ext.declarative.declarative_base
object associated with the injection
"""
sql_repo = self.db_query_equal("uid", model_dict["uid"])
if not sql_repo:
sql_repo = self.MODEL(**model_dict)
self.db_session.add(sql_repo)
else:
for k in model_dict:
setattr(sql_repo, k, model_dict[k])
sql_repo.last_seen = utcnow()
return sql_repo
def task_dict(self, origin_type: str, origin_url: str, **kwargs) -> Dict[str, Any]:
"""Return special dict format for the tasks list
Args:
origin_type (string)
origin_url (string)
Returns:
the same information in a different form
"""
logger.debug("origin-url: %s, type: %s", origin_url, origin_type)
_type = "load-%s" % origin_type
_policy = kwargs.get("policy", "recurring")
priority = kwargs.get("priority")
kw = {"priority": priority} if priority else {}
return utils.create_task_dict(_type, _policy, url=origin_url, **kw)
def string_pattern_check(self, a, b, c=None):
"""When comparing indexable types in is_within_bounds, complex strings
may not be allowed to differ in basic structure. If they do, it
could be a sign of not understanding the data well. For instance,
an ISO 8601 time string cannot be compared against its urlencoded
equivalent, but this is an easy mistake to accidentally make. This
method acts as a friendly sanity check.
Args:
a (string): inner component of the is_within_bounds method
b (string): lower component of the is_within_bounds method
c (string): upper component of the is_within_bounds method
Returns:
nothing
Raises:
TypeError if strings a, b, and c don't conform to the same basic
pattern.
"""
if isinstance(a, str):
a_pattern = re.sub("[a-zA-Z0-9]", "[a-zA-Z0-9]", re.escape(a))
if (
isinstance(b, str)
and (re.match(a_pattern, b) is None)
or isinstance(c, str)
and (re.match(a_pattern, c) is None)
):
logger.debug(a_pattern)
raise TypeError("incomparable string patterns detected")
def inject_repo_data_into_db(self, models_list: List[Dict]) -> Dict:
"""Inject data into the db.
Args:
models_list: list of dicts mapping keys from the db model
for each repo to be injected
Returns:
dict of uid:sql_repo pairs
"""
injected_repos = {}
for m in models_list:
injected_repos[m["uid"]] = self.db_inject_repo(m)
return injected_repos
def schedule_missing_tasks(
self, models_list: List[Dict], injected_repos: Dict
) -> None:
"""Schedule any newly created db entries that do not have been
scheduled yet.
Args:
models_list: List of dicts mapping keys in the db model
for each repo
injected_repos: Dict of uid:sql_repo pairs that have just
been created
Returns:
Nothing. (Note that it Modifies injected_repos to set the new
task_id).
"""
tasks = {}
def _task_key(m):
return "%s-%s" % (m["type"], json.dumps(m["arguments"], sort_keys=True))
for m in models_list:
ir = injected_repos[m["uid"]]
if not ir.task_id:
# Patching the model instance to add the policy/priority task
# scheduling
if "policy" in self.config:
m["policy"] = self.config["policy"]
if "priority" in self.config:
m["priority"] = self.config["priority"]
task_dict = self.task_dict(**m)
task_dict.setdefault("retries_left", 3)
tasks[_task_key(task_dict)] = (ir, m, task_dict)
gen_tasks = (task_dicts for (_, _, task_dicts) in tasks.values())
for grouped_tasks in grouper(gen_tasks, n=1000):
new_tasks = self.scheduler.create_tasks(list(grouped_tasks))
for task in new_tasks:
ir, m, _ = tasks[_task_key(task)]
ir.task_id = task["id"]
def ingest_data(self, identifier: int, checks: bool = False):
"""The core data fetch sequence. Request server endpoint. Simplify and
filter response list of repositories. Inject repo information into
local db. Queue loader tasks for linked repositories.
Args:
identifier: Resource identifier.
checks (bool): Additional checks required
"""
# Request (partial?) list of repositories info
response = self.safely_issue_request(identifier)
if not response:
return response, []
models_list = self.transport_response_simplified(response)
models_list = self.filter_before_inject(models_list)
if checks:
models_list = self.do_additional_checks(models_list)
if not models_list:
return response, []
# inject into local db
injected = self.inject_repo_data_into_db(models_list)
# queue workers
self.schedule_missing_tasks(models_list, injected)
return response, injected
def save_response(self, response):
"""Log the response from a server request to a cache dir.
Args:
response: full server response
cache_dir: system path for cache dir
Returns:
nothing
"""
datepath = utcnow().isoformat()
fname = os.path.join(self.config["cache_dir"], datepath + ".gz",)
with gzip.open(fname, "w") as f:
f.write(bytes(self.transport_response_to_string(response), "UTF-8"))

View file

@ -1,233 +0,0 @@
# Copyright (C) 2017-2018 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
from datetime import datetime
from email.utils import parsedate
import logging
from pprint import pformat
import random
from typing import Any, Dict, List, Optional, Union
import requests
from requests import Response
import xmltodict
from swh.lister import USER_AGENT_TEMPLATE, __version__
from .abstractattribute import AbstractAttribute
from .lister_base import FetchError
logger = logging.getLogger(__name__)
class ListerHttpTransport(abc.ABC):
"""Use the Requests library for making Lister endpoint requests.
To be used in conjunction with ListerBase or a subclass of it.
"""
DEFAULT_URL = None # type: Optional[str]
PATH_TEMPLATE = AbstractAttribute(
"string containing a python string format pattern that produces"
" the API endpoint path for listing stored repositories when given"
' an index, e.g., "/repositories?after=%s". To be implemented in'
" the API-specific class inheriting this."
) # type: Union[AbstractAttribute, Optional[str]]
EXPECTED_STATUS_CODES = (200, 429, 403, 404)
def request_headers(self) -> Dict[str, Any]:
"""Returns dictionary of any request headers needed by the server.
MAY BE OVERRIDDEN if request headers are needed.
"""
return {"User-Agent": USER_AGENT_TEMPLATE % self.lister_version}
def request_instance_credentials(self) -> List[Dict[str, Any]]:
"""Returns dictionary of any credentials configuration needed by the
forge instance to list.
The 'credentials' configuration is expected to be a dict of multiple
levels. The first level is the lister's name, the second is the
lister's instance name, which value is expected to be a list of
credential structures (typically a couple username/password).
For example::
credentials:
github: # github lister
github: # has only one instance (so far)
- username: some
password: somekey
- username: one
password: onekey
- ...
gitlab: # gitlab lister
riseup: # has many instances
- username: someone
password: ...
- ...
gitlab:
- username: someone
password: ...
- ...
Returns:
list of credential dicts for the current lister.
"""
all_creds = self.config.get("credentials") # type: ignore
if not all_creds:
return []
lister_creds = all_creds.get(self.LISTER_NAME, {}) # type: ignore
creds = lister_creds.get(self.instance, []) # type: ignore
return creds
def request_uri(self, identifier: str) -> str:
"""Get the full request URI given the transport_request identifier.
MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is
required.
"""
path = self.PATH_TEMPLATE % identifier # type: ignore
return self.url + path
def request_params(self, identifier: str) -> Dict[str, Any]:
"""Get the full parameters passed to requests given the
transport_request identifier.
This uses credentials if any are provided (see
request_instance_credentials).
MAY BE OVERRIDDEN if something more complex than the request headers
is needed.
"""
params = {}
params["headers"] = self.request_headers() or {}
creds = self.request_instance_credentials()
if not creds:
return params
auth = random.choice(creds) if creds else None
if auth:
params["auth"] = (
auth["username"], # type: ignore
auth["password"],
)
return params
def transport_quota_check(self, response):
"""Implements ListerBase.transport_quota_check with standard 429
code check for HTTP with Requests library.
MAY BE OVERRIDDEN if the server notifies about rate limits in a
non-standard way that doesn't use HTTP 429 and the Retry-After
response header. ( https://tools.ietf.org/html/rfc6585#section-4 )
"""
if response.status_code == 429: # HTTP too many requests
retry_after = response.headers.get("Retry-After", self.back_off())
try:
# might be seconds
return True, float(retry_after)
except Exception:
# might be http-date
at_date = datetime(*parsedate(retry_after)[:6])
from_now = (at_date - datetime.today()).total_seconds() + 5
return True, max(0, from_now)
else: # response ok
self.reset_backoff()
return False, 0
def __init__(self, url=None):
if not url:
url = self.config.get("url")
if not url:
url = self.DEFAULT_URL
if not url:
raise NameError("HTTP Lister Transport requires an url.")
self.url = url # eg. 'https://api.github.com'
self.session = requests.Session()
self.lister_version = __version__
def _transport_action(self, identifier: str, method: str = "get") -> Response:
"""Permit to ask information to the api prior to actually executing
query.
"""
path = self.request_uri(identifier)
params = self.request_params(identifier)
logger.debug("path: %s", path)
logger.debug("params: %s", params)
logger.debug("method: %s", method)
try:
if method == "head":
response = self.session.head(path, **params)
else:
response = self.session.get(path, **params)
except requests.exceptions.ConnectionError as e:
logger.warning("Failed to fetch %s: %s", path, e)
raise FetchError(e)
else:
if response.status_code not in self.EXPECTED_STATUS_CODES:
raise FetchError(response)
return response
def transport_head(self, identifier: str) -> Response:
"""Retrieve head information on api.
"""
return self._transport_action(identifier, method="head")
def transport_request(self, identifier: str) -> Response:
"""Implements ListerBase.transport_request for HTTP using Requests.
Retrieve get information on api.
"""
return self._transport_action(identifier)
def transport_response_to_string(self, response: Response) -> str:
"""Implements ListerBase.transport_response_to_string for HTTP given
Requests responses.
"""
s = pformat(response.request.path_url)
s += "\n#\n" + pformat(response.request.headers)
s += "\n#\n" + pformat(response.status_code)
s += "\n#\n" + pformat(response.headers)
s += "\n#\n"
try: # json?
s += pformat(response.json())
except Exception: # not json
try: # xml?
s += pformat(xmltodict.parse(response.text))
except Exception: # not xml
s += pformat(response.text)
return s
class ListerOnePageApiTransport(ListerHttpTransport):
"""Leverage requests library to retrieve basic html page and parse
result.
To be used in conjunction with ListerBase or a subclass of it.
"""
PAGE = AbstractAttribute(
"URL of the API's unique page to retrieve and parse " "for information"
) # type: Union[AbstractAttribute, str]
PATH_TEMPLATE = None # we do not use it
def __init__(self, url=None):
self.session = requests.Session()
self.lister_version = __version__
def request_uri(self, _):
"""Get the full request URI given the transport_request identifier.
"""
return self.PAGE

View file

@ -1,77 +0,0 @@
# Copyright (C) 2015-2019 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
from datetime import datetime
import logging
from typing import Type, Union
from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from .abstractattribute import AbstractAttribute
SQLBase = declarative_base()
logger = logging.getLogger(__name__)
class ABCSQLMeta(abc.ABCMeta, DeclarativeMeta):
pass
class ModelBase(SQLBase, metaclass=ABCSQLMeta):
"""a common repository"""
__abstract__ = True
__tablename__ = AbstractAttribute # type: Union[Type[AbstractAttribute], str]
uid = AbstractAttribute(
"Column(<uid_type>, primary_key=True)"
) # type: Union[AbstractAttribute, Column]
name = Column(String, index=True)
full_name = Column(String, index=True)
html_url = Column(String)
origin_url = Column(String)
origin_type = Column(String)
last_seen = Column(DateTime, nullable=False)
task_id = Column(Integer)
def __init__(self, **kw):
kw["last_seen"] = datetime.now()
super().__init__(**kw)
class IndexingModelBase(ModelBase, metaclass=ABCSQLMeta):
__abstract__ = True
__tablename__ = AbstractAttribute # type: Union[Type[AbstractAttribute], str]
# The value used for sorting, segmenting, or api query paging,
# because uids aren't always sequential.
indexable = AbstractAttribute(
"Column(<indexable_type>, index=True)"
) # type: Union[AbstractAttribute, Column]
def initialize(db_engine, drop_tables=False, **kwargs):
"""Default database initialization function for a lister.
Typically called from the lister's initialization hook.
Args:
models (list): list of SQLAlchemy tables/models to drop/create.
db_engine (): the SQLAlchemy DB engine.
drop_tables (bool): if True, tables will be dropped before
(re)creating them.
"""
if drop_tables:
logger.info("Dropping tables")
SQLBase.metadata.drop_all(db_engine, checkfirst=True)
logger.info("Creating tables")
SQLBase.metadata.create_all(db_engine, checkfirst=True)

View file

@ -1,96 +0,0 @@
# Copyright (C) 2018-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
from typing import Any, List
from swh.core import utils
from .lister_base import ListerBase
logger = logging.getLogger(__name__)
class SimpleLister(ListerBase):
"""Lister* intermediate class for any service that follows the simple,
'list in oneshot information' pattern.
- Client sends a request to list repositories in oneshot
- Client receives structured (json/xml/etc) response with
information and stores those in db
"""
flush_packet_db = 2
"""Number of iterations in-between write flushes of lister repositories to
db (see fn:`ingest_data`).
"""
def list_packages(self, response: Any) -> List[Any]:
"""Listing packages method.
"""
pass
def ingest_data(self, identifier, checks=False):
"""Rework the base ingest_data.
Request server endpoint which gives all in one go.
Simplify and filter response list of repositories. Inject
repo information into local db. Queue loader tasks for
linked repositories.
Args:
identifier: Resource identifier (unused)
checks (bool): Additional checks required (unused)
"""
response = self.safely_issue_request(identifier)
response = self.list_packages(response)
if not response:
return response, []
models_list = self.transport_response_simplified(response)
models_list = self.filter_before_inject(models_list)
all_injected = []
for i, models in enumerate(utils.grouper(models_list, n=100), start=1):
models = list(models)
logging.debug("models: %s" % len(models))
# inject into local db
injected = self.inject_repo_data_into_db(models)
# queue workers
self.schedule_missing_tasks(models, injected)
all_injected.append(injected)
if (i % self.flush_packet_db) == 0:
logger.debug("Flushing updates at index %s", i)
self.db_session.commit()
self.db_session = self.mk_session()
return response, all_injected
def transport_response_simplified(self, response):
"""Transform response to list for model manipulation
"""
return [self.get_model_from_repo(repo_name) for repo_name in response]
def run(self):
"""Query the server which answers in one query. Stores the
information, dropping actual redundant information we
already have.
Returns:
nothing
"""
dump_not_used_identifier = 0
response, injected_repos = self.ingest_data(dump_not_used_identifier)
if not response and not injected_repos:
logging.info("No response from api server, stopping")
status = "uneventful"
else:
status = "eventful"
return {"status": status}

View file

@ -1,64 +0,0 @@
# Copyright (C) 2017 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
from typing import Any
import unittest
from swh.lister.core.abstractattribute import AbstractAttribute
class BaseClass(abc.ABC):
v1 = AbstractAttribute # type: Any
v2 = AbstractAttribute() # type: Any
v3 = AbstractAttribute("changed docstring") # type: Any
v4 = "qux"
class BadSubclass1(BaseClass):
pass
class BadSubclass2(BaseClass):
v1 = "foo"
v2 = "bar"
class BadSubclass3(BaseClass):
v2 = "bar"
v3 = "baz"
class GoodSubclass(BaseClass):
v1 = "foo"
v2 = "bar"
v3 = "baz"
class TestAbstractAttributes(unittest.TestCase):
def test_aa(self):
with self.assertRaises(TypeError):
BaseClass()
with self.assertRaises(TypeError):
BadSubclass1()
with self.assertRaises(TypeError):
BadSubclass2()
with self.assertRaises(TypeError):
BadSubclass3()
self.assertIsInstance(GoodSubclass(), GoodSubclass)
gsc = GoodSubclass()
self.assertEqual(gsc.v1, "foo")
self.assertEqual(gsc.v2, "bar")
self.assertEqual(gsc.v3, "baz")
self.assertEqual(gsc.v4, "qux")
def test_aa_docstrings(self):
self.assertEqual(BaseClass.v1.__doc__, AbstractAttribute.__doc__)
self.assertEqual(BaseClass.v2.__doc__, AbstractAttribute.__doc__)
self.assertEqual(BaseClass.v3.__doc__, "AbstractAttribute: changed docstring")

View file

@ -1,453 +0,0 @@
# Copyright (C) 2019 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
import datetime
import time
from typing import Any, Callable, Optional, Pattern, Type, Union
from unittest import TestCase
from unittest.mock import Mock, patch
import requests_mock
from sqlalchemy import create_engine
import swh.lister
from swh.lister.core.abstractattribute import AbstractAttribute
from swh.lister.tests.test_utils import init_db
def noop(*args, **kwargs):
pass
def test_version_generation():
assert (
swh.lister.__version__ != "devel"
), "Make sure swh.lister is installed (e.g. pip install -e .)"
class HttpListerTesterBase(abc.ABC):
"""Testing base class for listers.
This contains methods for both :class:`HttpSimpleListerTester` and
:class:`HttpListerTester`.
See :class:`swh.lister.gitlab.tests.test_lister` for an example of how
to customize for a specific listing service.
"""
Lister = AbstractAttribute(
"Lister class to test"
) # type: Union[AbstractAttribute, Type[Any]]
lister_subdir = AbstractAttribute(
"bitbucket, github, etc."
) # type: Union[AbstractAttribute, str]
good_api_response_file = AbstractAttribute(
"Example good response body"
) # type: Union[AbstractAttribute, str]
LISTER_NAME = "fake-lister"
# May need to override this if the headers are used for something
def response_headers(self, request):
return {}
# May need to override this if the server uses non-standard rate limiting
# method.
# Please keep the requested retry delay reasonably low.
def mock_rate_quota(self, n, request, context):
self.rate_limit += 1
context.status_code = 429
context.headers["Retry-After"] = "1"
return '{"error":"dummy"}'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.rate_limit = 1
self.response = None
self.fl = None
self.helper = None
self.scheduler_tasks = []
if self.__class__ != HttpListerTesterBase:
self.run = TestCase.run.__get__(self, self.__class__)
else:
self.run = noop
def mock_limit_n_response(self, n, request, context):
self.fl.reset_backoff()
if self.rate_limit <= n:
return self.mock_rate_quota(n, request, context)
else:
return self.mock_response(request, context)
def mock_limit_twice_response(self, request, context):
return self.mock_limit_n_response(2, request, context)
def get_api_response(self, identifier):
fl = self.get_fl()
if self.response is None:
self.response = fl.safely_issue_request(identifier)
return self.response
def get_fl(self, override_config=None):
"""Retrieve an instance of fake lister (fl).
"""
if override_config or self.fl is None:
self.fl = self.Lister(
url="https://fakeurl", override_config=override_config
)
self.fl.INITIAL_BACKOFF = 1
self.fl.reset_backoff()
self.scheduler_tasks = []
return self.fl
def disable_scheduler(self, fl):
fl.schedule_missing_tasks = Mock(return_value=None)
def mock_scheduler(self, fl):
def _create_tasks(tasks):
task_id = 0
current_nb_tasks = len(self.scheduler_tasks)
if current_nb_tasks > 0:
task_id = self.scheduler_tasks[-1]["id"] + 1
for task in tasks:
scheduler_task = dict(task)
scheduler_task.update(
{
"status": "next_run_not_scheduled",
"retries_left": 0,
"priority": None,
"id": task_id,
"current_interval": datetime.timedelta(days=64),
}
)
self.scheduler_tasks.append(scheduler_task)
task_id = task_id + 1
return self.scheduler_tasks[current_nb_tasks:]
def _disable_tasks(task_ids):
for task_id in task_ids:
self.scheduler_tasks[task_id]["status"] = "disabled"
fl.scheduler.create_tasks = Mock(wraps=_create_tasks)
fl.scheduler.disable_tasks = Mock(wraps=_disable_tasks)
def disable_db(self, fl):
fl.winnow_models = Mock(return_value=[])
fl.db_inject_repo = Mock(return_value=fl.MODEL())
fl.disable_deleted_repo_tasks = Mock(return_value=None)
def init_db(self, db, model):
engine = create_engine(db.url())
model.metadata.create_all(engine)
@requests_mock.Mocker()
def test_is_within_bounds(self, http_mocker):
fl = self.get_fl()
self.assertFalse(fl.is_within_bounds(1, 2, 3))
self.assertTrue(fl.is_within_bounds(2, 1, 3))
self.assertTrue(fl.is_within_bounds(1, 1, 1))
self.assertTrue(fl.is_within_bounds(1, None, None))
self.assertTrue(fl.is_within_bounds(1, None, 2))
self.assertTrue(fl.is_within_bounds(1, 0, None))
self.assertTrue(fl.is_within_bounds("b", "a", "c"))
self.assertFalse(fl.is_within_bounds("a", "b", "c"))
self.assertTrue(fl.is_within_bounds("a", None, "c"))
self.assertTrue(fl.is_within_bounds("a", None, None))
self.assertTrue(fl.is_within_bounds("b", "a", None))
self.assertFalse(fl.is_within_bounds("a", "b", None))
self.assertTrue(fl.is_within_bounds("aa:02", "aa:01", "aa:03"))
self.assertFalse(fl.is_within_bounds("aa:12", None, "aa:03"))
with self.assertRaises(TypeError):
fl.is_within_bounds(1.0, "b", None)
with self.assertRaises(TypeError):
fl.is_within_bounds("A:B", "A::B", None)
class HttpListerTester(HttpListerTesterBase, abc.ABC):
"""Base testing class for subclass of
:class:`swh.lister.core.indexing_lister.IndexingHttpLister`
See :class:`swh.lister.github.tests.test_gh_lister` for an example of how
to customize for a specific listing service.
"""
last_index = AbstractAttribute(
"Last index " "in good_api_response"
) # type: Union[AbstractAttribute, int]
first_index = AbstractAttribute(
"First index in " " good_api_response"
) # type: Union[AbstractAttribute, Optional[int]]
bad_api_response_file = AbstractAttribute(
"Example bad response body"
) # type: Union[AbstractAttribute, str]
entries_per_page = AbstractAttribute(
"Number of results in " "good response"
) # type: Union[AbstractAttribute, int]
test_re = AbstractAttribute(
"Compiled regex matching the server url. Must capture the " "index value."
) # type: Union[AbstractAttribute, Pattern]
convert_type = str # type: Callable[..., Any]
"""static method used to convert the "request_index" to its right type (for
indexing listers for example, this is in accordance with the model's
"indexable" column).
"""
def mock_response(self, request, context):
self.fl.reset_backoff()
self.rate_limit = 1
context.status_code = 200
custom_headers = self.response_headers(request)
context.headers.update(custom_headers)
req_index = self.request_index(request)
if req_index == self.first_index:
response_file = self.good_api_response_file
else:
response_file = self.bad_api_response_file
with open(
"swh/lister/%s/tests/%s" % (self.lister_subdir, response_file),
"r",
encoding="utf-8",
) as r:
return r.read()
def request_index(self, request):
m = self.test_re.search(request.path_url)
if m and (len(m.groups()) > 0):
return self.convert_type(m.group(1))
def create_fl_with_db(self, http_mocker):
http_mocker.get(self.test_re, text=self.mock_response)
db = init_db()
fl = self.get_fl(
override_config={"lister": {"cls": "local", "args": {"db": db.url()}}}
)
fl.db = db
self.init_db(db, fl.MODEL)
self.mock_scheduler(fl)
return fl
@requests_mock.Mocker()
def test_fetch_no_bounds_yesdb(self, http_mocker):
fl = self.create_fl_with_db(http_mocker)
fl.run()
self.assertEqual(fl.db_last_index(), self.last_index)
ingested_repos = list(fl.db_query_range(self.first_index, self.last_index))
self.assertEqual(len(ingested_repos), self.entries_per_page)
@requests_mock.Mocker()
def test_fetch_multiple_pages_yesdb(self, http_mocker):
fl = self.create_fl_with_db(http_mocker)
fl.run(min_bound=self.first_index)
self.assertEqual(fl.db_last_index(), self.last_index)
partitions = fl.db_partition_indices(5)
self.assertGreater(len(partitions), 0)
for k in partitions:
self.assertLessEqual(len(k), 5)
self.assertGreater(len(k), 0)
@requests_mock.Mocker()
def test_fetch_none_nodb(self, http_mocker):
http_mocker.get(self.test_re, text=self.mock_response)
fl = self.get_fl()
self.disable_scheduler(fl)
self.disable_db(fl)
fl.run(min_bound=1, max_bound=1) # stores no results
# FIXME: Determine what this method tries to test and add checks to
# actually test
@requests_mock.Mocker()
def test_fetch_one_nodb(self, http_mocker):
http_mocker.get(self.test_re, text=self.mock_response)
fl = self.get_fl()
self.disable_scheduler(fl)
self.disable_db(fl)
fl.run(min_bound=self.first_index, max_bound=self.first_index)
# FIXME: Determine what this method tries to test and add checks to
# actually test
@requests_mock.Mocker()
def test_fetch_multiple_pages_nodb(self, http_mocker):
http_mocker.get(self.test_re, text=self.mock_response)
fl = self.get_fl()
self.disable_scheduler(fl)
self.disable_db(fl)
fl.run(min_bound=self.first_index)
# FIXME: Determine what this method tries to test and add checks to
# actually test
@requests_mock.Mocker()
def test_repos_list(self, http_mocker):
"""Test the number of repos listed by the lister
"""
http_mocker.get(self.test_re, text=self.mock_response)
li = self.get_fl().transport_response_simplified(
self.get_api_response(self.first_index)
)
self.assertIsInstance(li, list)
self.assertEqual(len(li), self.entries_per_page)
@requests_mock.Mocker()
def test_model_map(self, http_mocker):
"""Check if all the keys of model are present in the model created by
the `transport_response_simplified`
"""
http_mocker.get(self.test_re, text=self.mock_response)
fl = self.get_fl()
li = fl.transport_response_simplified(self.get_api_response(self.first_index))
di = li[0]
self.assertIsInstance(di, dict)
pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith("_")]
for k in pubs:
if k not in ["last_seen", "task_id", "id"]:
self.assertIn(k, di)
@requests_mock.Mocker()
def test_api_request(self, http_mocker):
"""Test API request for rate limit handling
"""
http_mocker.get(self.test_re, text=self.mock_limit_twice_response)
with patch.object(time, "sleep", wraps=time.sleep) as sleepmock:
self.get_api_response(self.first_index)
self.assertEqual(sleepmock.call_count, 2)
@requests_mock.Mocker()
def test_request_headers(self, http_mocker):
fl = self.create_fl_with_db(http_mocker)
fl.run()
self.assertNotEqual(len(http_mocker.request_history), 0)
for request in http_mocker.request_history:
assert "User-Agent" in request.headers
user_agent = request.headers["User-Agent"]
assert "Software Heritage Lister" in user_agent
assert swh.lister.__version__ in user_agent
def scheduled_tasks_test(
self, next_api_response_file, next_last_index, http_mocker
):
"""Check that no loading tasks get disabled when processing a new
page of repositories returned by a forge API
"""
fl = self.create_fl_with_db(http_mocker)
# process first page of repositories listing
fl.run()
# process second page of repositories listing
prev_last_index = self.last_index
self.first_index = self.last_index
self.last_index = next_last_index
self.good_api_response_file = next_api_response_file
fl.run(min_bound=prev_last_index)
# check expected number of ingested repos and loading tasks
ingested_repos = list(fl.db_query_range(0, self.last_index))
self.assertEqual(len(ingested_repos), len(self.scheduler_tasks))
self.assertEqual(len(ingested_repos), 2 * self.entries_per_page)
# check tasks are not disabled
for task in self.scheduler_tasks:
self.assertTrue(task["status"] != "disabled")
class HttpSimpleListerTester(HttpListerTesterBase, abc.ABC):
"""Base testing class for subclass of
:class:`swh.lister.core.simple)_lister.SimpleLister`
See :class:`swh.lister.pypi.tests.test_lister` for an example of how
to customize for a specific listing service.
"""
entries = AbstractAttribute(
"Number of results " "in good response"
) # type: Union[AbstractAttribute, int]
PAGE = AbstractAttribute(
"URL of the server api's unique page to retrieve and " "parse for information"
) # type: Union[AbstractAttribute, str]
def get_fl(self, override_config=None):
"""Retrieve an instance of fake lister (fl).
"""
if override_config or self.fl is None:
self.fl = self.Lister(override_config=override_config)
self.fl.INITIAL_BACKOFF = 1
self.fl.reset_backoff()
return self.fl
def mock_response(self, request, context):
self.fl.reset_backoff()
self.rate_limit = 1
context.status_code = 200
custom_headers = self.response_headers(request)
context.headers.update(custom_headers)
response_file = self.good_api_response_file
with open(
"swh/lister/%s/tests/%s" % (self.lister_subdir, response_file),
"r",
encoding="utf-8",
) as r:
return r.read()
@requests_mock.Mocker()
def test_api_request(self, http_mocker):
"""Test API request for rate limit handling
"""
http_mocker.get(self.PAGE, text=self.mock_limit_twice_response)
with patch.object(time, "sleep", wraps=time.sleep) as sleepmock:
self.get_api_response(0)
self.assertEqual(sleepmock.call_count, 2)
@requests_mock.Mocker()
def test_model_map(self, http_mocker):
"""Check if all the keys of model are present in the model created by
the `transport_response_simplified`
"""
http_mocker.get(self.PAGE, text=self.mock_response)
fl = self.get_fl()
li = fl.list_packages(self.get_api_response(0))
li = fl.transport_response_simplified(li)
di = li[0]
self.assertIsInstance(di, dict)
pubs = [k for k in vars(fl.MODEL).keys() if not k.startswith("_")]
for k in pubs:
if k not in ["last_seen", "task_id", "id"]:
self.assertIn(k, di)
@requests_mock.Mocker()
def test_repos_list(self, http_mocker):
"""Test the number of packages listed by the lister
"""
http_mocker.get(self.PAGE, text=self.mock_response)
li = self.get_fl().list_packages(self.get_api_response(0))
self.assertIsInstance(li, list)
self.assertEqual(len(li), self.entries)

View file

@ -1,91 +0,0 @@
# Copyright (C) 2017 the Software Heritage developers
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import unittest
from sqlalchemy import Column, Integer
from swh.lister.core.models import IndexingModelBase, ModelBase
class BadSubclass1(ModelBase):
__abstract__ = True
pass
class BadSubclass2(ModelBase):
__abstract__ = True
__tablename__ = "foo"
class BadSubclass3(BadSubclass2):
__abstract__ = True
pass
class GoodSubclass(BadSubclass2):
uid = Column(Integer, primary_key=True)
indexable = Column(Integer, index=True)
class IndexingBadSubclass(IndexingModelBase):
__abstract__ = True
pass
class IndexingBadSubclass2(IndexingModelBase):
__abstract__ = True
__tablename__ = "foo"
class IndexingBadSubclass3(IndexingBadSubclass2):
__abstract__ = True
pass
class IndexingGoodSubclass(IndexingModelBase):
uid = Column(Integer, primary_key=True)
indexable = Column(Integer, index=True)
__tablename__ = "bar"
class TestModel(unittest.TestCase):
def test_model_instancing(self):
with self.assertRaises(TypeError):
ModelBase()
with self.assertRaises(TypeError):
BadSubclass1()
with self.assertRaises(TypeError):
BadSubclass2()
with self.assertRaises(TypeError):
BadSubclass3()
self.assertIsInstance(GoodSubclass(), GoodSubclass)
gsc = GoodSubclass(uid="uid")
self.assertEqual(gsc.__tablename__, "foo")
self.assertEqual(gsc.uid, "uid")
def test_indexing_model_instancing(self):
with self.assertRaises(TypeError):
IndexingModelBase()
with self.assertRaises(TypeError):
IndexingBadSubclass()
with self.assertRaises(TypeError):
IndexingBadSubclass2()
with self.assertRaises(TypeError):
IndexingBadSubclass3()
self.assertIsInstance(IndexingGoodSubclass(), IndexingGoodSubclass)
gsc = IndexingGoodSubclass(uid="uid", indexable="indexable")
self.assertEqual(gsc.__tablename__, "bar")
self.assertEqual(gsc.uid, "uid")
self.assertEqual(gsc.indexable, "indexable")

View file

@ -1,62 +0,0 @@
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
import os
import pytest
from sqlalchemy import create_engine
import yaml
from swh.lister import SUPPORTED_LISTERS, get_lister
from swh.lister.core.models import initialize
logger = logging.getLogger(__name__)
@pytest.fixture
def lister_db_url(postgresql):
db_params = postgresql.get_dsn_parameters()
db_url = "postgresql://{user}@{host}:{port}/{dbname}".format(**db_params)
logger.debug("lister db_url: %s", db_url)
return db_url
@pytest.fixture
def lister_under_test():
"""Fixture to determine which lister to test"""
return "core"
@pytest.fixture
def swh_lister_config(lister_db_url, swh_scheduler_config):
return {
"scheduler": {"cls": "local", **swh_scheduler_config},
"lister": {"cls": "local", "args": {"db": lister_db_url},},
"credentials": {},
"cache_responses": False,
}
@pytest.fixture(autouse=True)
def swh_config(swh_lister_config, monkeypatch, tmp_path):
conf_path = os.path.join(str(tmp_path), "lister.yml")
with open(conf_path, "w") as f:
f.write(yaml.dump(swh_lister_config))
monkeypatch.setenv("SWH_CONFIG_FILENAME", conf_path)
return conf_path
@pytest.fixture
def engine(lister_db_url):
engine = create_engine(lister_db_url)
initialize(engine, drop_tables=True)
return engine
@pytest.fixture
def swh_lister(engine, lister_db_url, lister_under_test, swh_config):
assert lister_under_test in SUPPORTED_LISTERS
return get_lister(lister_under_test, db_url=lister_db_url)

View file

@ -1,4 +1,4 @@
# Copyright (C) 2019-2020 The Software Heritage developers
# Copyright (C) 2019-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
@ -7,8 +7,6 @@ import pytest
from swh.lister.cli import SUPPORTED_LISTERS, get_lister
from .test_utils import init_db
lister_args = {
"cgit": {"url": "https://git.eclipse.org/c/",},
"phabricator": {
@ -33,13 +31,11 @@ def test_get_lister(swh_scheduler_config):
"""Instantiating a supported lister should be ok
"""
db_url = init_db().url()
# Drop launchpad lister from the lister to check, its test setup is more involved
# than the other listers and it's not currently done here
for lister_name in SUPPORTED_LISTERS:
lst = get_lister(
lister_name,
db_url,
scheduler={"cls": "local", **swh_scheduler_config},
**lister_args.get(lister_name, {}),
)

View file

@ -6,7 +6,6 @@ import pytest
import requests
from requests.status_codes import codes
from tenacity.wait import wait_fixed
from testing.postgresql import Postgresql
from swh.lister.utils import (
MAX_NUMBER_ATTEMPTS,
@ -37,18 +36,6 @@ def test_split_range_errors(total_pages, nb_pages):
next(split_range(total_pages, nb_pages))
def init_db():
"""Factorize the db_url instantiation
Returns:
db object to ease db manipulation
"""
initdb_args = Postgresql.DEFAULT_SETTINGS["initdb_args"]
initdb_args = " ".join([initdb_args, "-E UTF-8"])
return Postgresql(initdb_args=initdb_args)
TEST_URL = "https://example.og/api/repositories"