lister.debian: Move run method parameters to constructor

This commit is contained in:
Antoine R. Dumont (@ardumont) 2019-11-05 11:56:40 +01:00
parent b745c5a735
commit e0dbca759c
No known key found for this signature in database
GPG key ID: 52E2E9840D10C3B8
5 changed files with 38 additions and 20 deletions

View file

@ -3,11 +3,11 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from typing import Any, List, Mapping, Optional
from typing import Any, List, Mapping
def debian_init(db_engine, lister=None,
override_conf: Optional[Mapping[str, Any]] = None,
override_conf: Mapping[str, Any] = {},
distributions: List[str] = ['stretch', 'buster'],
area_names: List[str] = ['main', 'contrib', 'non-free']):
"""Initialize the debian data model.
@ -15,26 +15,27 @@ def debian_init(db_engine, lister=None,
Args:
db_engine: SQLAlchemy manipulation database object
lister: Debian lister instance. None by default.
override_conf: Override conf to pass to instantiate a lister.
None by default
override_conf: Override conf to pass to instantiate a lister
distributions: Default distribution to build
"""
distribution_name = 'Debian'
from swh.storage.schemata.distribution import (
Distribution, Area)
if lister is None:
from .lister import DebianLister
lister = DebianLister(override_config=override_conf)
lister = DebianLister(distribution=distribution_name,
override_config=override_conf)
if not lister.db_session\
.query(Distribution)\
.filter(Distribution.name == 'Debian')\
.filter(Distribution.name == distribution_name)\
.one_or_none():
d = Distribution(
name='Debian',
name=distribution_name,
type='deb',
mirror_uri='http://deb.debian.org/debian/')
lister.db_session.add(d)

View file

@ -13,6 +13,7 @@ import logging
from debian.deb822 import Sources
from sqlalchemy.orm import joinedload, load_only
from sqlalchemy.schema import CreateTable, DropTable
from typing import Mapping, Optional
from swh.lister.debian.models import (
AreaSnapshot, Distribution, DistributionSnapshot, Package,
@ -38,9 +39,24 @@ class DebianLister(ListerHttpTransport, ListerBase):
LISTER_NAME = 'debian'
instance = 'debian'
def __init__(self, override_config=None):
def __init__(self, distribution: str = 'Debian',
date: Optional[datetime.datetime] = None,
override_config: Mapping = {}):
"""Initialize the debian lister for a given distribution at a given
date.
Args:
distribution: name of the distribution (e.g. "Debian")
date: date the snapshot is taken (defaults to now if empty)
override_config: Override configuration (which takes precedence
over the parameters if provided)
"""
ListerHttpTransport.__init__(self, url="notused")
ListerBase.__init__(self, override_config=override_config)
self.distribution = override_config.get('distribution', distribution)
self.date = override_config.get('date', date) or datetime.datetime.now(
tz=datetime.timezone.utc)
def transport_request(self, identifier):
"""Subvert ListerHttpTransport.transport_request, to try several
@ -189,29 +205,25 @@ class DebianLister(ListerHttpTransport, ListerBase):
return self.scheduler.create_tasks(tasks)
def run(self, distribution='Debian', date=None):
def run(self):
"""Run the lister for a given (distribution, area) tuple.
Args:
distribution (str): name of the distribution (e.g. "Debian")
date (datetime.datetime): date the snapshot is taken (defaults to
now)
"""
distribution = self.db_session\
.query(Distribution)\
.options(joinedload(Distribution.areas))\
.filter(Distribution.name == distribution)\
.filter(Distribution.name == self.distribution)\
.one_or_none()
if not distribution:
raise ValueError("Distribution %s is not registered" %
distribution)
self.distribution)
if not distribution.type == 'deb':
raise ValueError("Distribution %s is not a Debian derivative" %
distribution)
date = date or datetime.datetime.now(tz=datetime.timezone.utc)
date = self.date
logger.debug('Creating snapshot for distribution %s on date %s' %
(distribution, date))

View file

@ -10,7 +10,7 @@ from .lister import DebianLister
@shared_task(name=__name__ + '.DebianListerTask')
def list_debian_distribution(distribution, **lister_args):
'''List a Debian distribution'''
DebianLister(**lister_args).run(distribution)
DebianLister(distribution=distribution, **lister_args).run()
@shared_task(name=__name__ + '.ping')

View file

@ -14,7 +14,7 @@ def test_lister_debian(lister_debian, datadir, requests_mock_datadir):
"""
# Run the lister
lister_debian.run(distribution="Debian")
lister_debian.run()
r = lister_debian.scheduler.search_tasks(task_type='load-deb-package')
assert len(r) == 151

View file

@ -1,3 +1,8 @@
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from unittest.mock import patch
@ -22,5 +27,5 @@ def test_lister(lister, swh_app, celery_session_worker):
res.wait()
assert res.successful()
lister.assert_called_once_with()
lister.run.assert_called_once_with('stretch')
lister.assert_called_once_with(distribution='stretch')
lister.run.assert_called_once_with()