diff --git a/bin/ghlister b/bin/ghlister
index 20665fe..594f138 100755
--- a/bin/ghlister
+++ b/bin/ghlister
@@ -5,17 +5,11 @@
 # See top-level LICENSE file for more information
 
 import argparse
-import configparser
 import logging
-import os
 import sys
 
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-
-from swh.lister.github import lister, models
-from swh.lister.github.db_utils import session_scope
-
+from swh.lister.github import models
+from swh.lister.github.lister import GitHubLister
 
 DEFAULT_CONF = {
     'cache_dir': './cache',
@@ -24,13 +18,6 @@ DEFAULT_CONF = {
 }
 
 
-def db_connect(db_url):
-    engine = create_engine(db_url)
-    session = sessionmaker(bind=engine)
-
-    return (engine, session)
-
-
 def int_interval(s):
     """parse an "N-M" string as an interval.
 
@@ -84,67 +71,31 @@ def parse_args():
 
     return args
 
 
-def read_conf(args):
-    config = configparser.ConfigParser(defaults=DEFAULT_CONF)
-    config.read(os.path.expanduser('~/.config/swh/lister-github.ini'))
-
-    conf = config._sections['main']
-
-    # overrides
-    if args.db_url:
-        conf['db_url'] = args.db_url
-
-    # typing
-    if 'cache_json' in conf and conf['cache_json'].lower() == 'true':
-        conf['cache_json'] = True
-    else:
-        conf['cache_json'] = False
-
-    if 'credentials' in conf:
-        credentials = conf['credentials'].split()
-        conf['credentials'] = []
-        for user_pair in credentials:
-            username, password = user_pair.split(':')
-            conf['credentials'].append({
-                'username': username,
-                'password': password,
-            })
-    else:
-        conf['credentials'] = [{
-            'username': conf['username'],
-            'password': conf['password'],
-        }]
-
-    return conf
-
-
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO)  # XXX
     args = parse_args()
-    conf = read_conf(args)
 
-    db_engine, mk_session = db_connect(conf['db_url'])
+    override_conf = {}
+    if args.db_url:
+        override_conf['lister_db_url'] = args.db_url
+
+    lister = GitHubLister(override_conf)
 
     if args.action == 'createdb':
-        models.SQLBase.metadata.create_all(db_engine)
+        models.SQLBase.metadata.create_all(lister.db_engine)
    elif args.action == 'dropdb':
-        models.SQLBase.metadata.drop_all(db_engine)
+        models.SQLBase.metadata.drop_all(lister.db_engine)
     elif args.action == 'list':
-        lister.fetch(conf,
-                     mk_session,
-                     min_id=args.interval[0],
+        lister.fetch(min_id=args.interval[0],
                      max_id=args.interval[1])
     elif args.action == 'catchup':
-        with session_scope(mk_session) as db_session:
-            last_known_id = lister.last_repo_id(db_session)
-            if last_known_id is not None:
-                logging.info('catching up from last known repo id: %d' %
-                             last_known_id)
-                lister.fetch(conf,
-                             mk_session,
-                             min_id=last_known_id + 1,
-                             max_id=None)
-            else:
-                logging.error('Cannot catchup: no last known id found. Abort.')
-                sys.exit(2)
+        last_known_id = lister.last_repo_id()
+        if last_known_id is not None:
+            logging.info('catching up from last known repo id: %d' %
+                         last_known_id)
+            lister.fetch(min_id=last_known_id + 1,
+                         max_id=None)
+        else:
+            logging.error('Cannot catchup: no last known id found. Abort.')
+            sys.exit(2)
diff --git a/swh/lister/github/base.py b/swh/lister/github/base.py
index a8f974c..ac2fbd9 100644
--- a/swh/lister/github/base.py
+++ b/swh/lister/github/base.py
@@ -7,6 +7,8 @@ from swh.storage import get_storage
 from swh.scheduler.backend import SchedulerBackend
 
 
+# TODO: split this into a lister-agnostic module
+
 class SWHLister(config.SWHConfig):
     CONFIG_BASE_FILENAME = None
 
@@ -14,7 +16,7 @@ class SWHLister(config.SWHConfig):
         'storage_class': ('str', 'remote_storage'),
         'storage_args': ('list[str]', ['http://localhost:5000/']),
 
-        'scheduling_db': ('str', 'dbname=swh-scheduler'),
+        'scheduling_db': ('str', 'dbname=softwareheritage-scheduler'),
     }
 
     ADDITIONAL_CONFIG = {}
diff --git a/swh/lister/github/lister.py b/swh/lister/github/lister.py
index f441bd0..204584d 100644
--- a/swh/lister/github/lister.py
+++ b/swh/lister/github/lister.py
@@ -14,8 +14,13 @@
 import requests
 import time
 
 from pprint import pformat
 
-from sqlalchemy import func
+from sqlalchemy import create_engine, func
+from sqlalchemy.orm import sessionmaker
+
+from swh.core import config
+from swh.lister.github.base import SWHLister
+from swh.lister.github.db_utils import session_scope
 
 from swh.lister.github.models import Repository
@@ -27,6 +32,15 @@ CONN_SLEEP = 10
 REPO_API_URL_RE = re.compile(r'^.*/repositories\?since=(\d+)')
 
 
+class FetchError(RuntimeError):
+
+    def __init__(self, response):
+        self.response = response
+
+    def __str__(self):
+        return repr(self.response)
+
+
 def save_http_response(r, cache_dir):
     def escape_url_path(p):
         return p.replace('/', '__')
@@ -97,86 +111,114 @@ def gh_api_request(path, username=None, password=None, session=None,
     return r
 
 
-def lookup_repo(db_session, repo_id):
-    return db_session.query(Repository) \
-                     .filter(Repository.id == repo_id) \
-                     .first()
+class GitHubLister(SWHLister):
+    CONFIG_BASE_FILENAME = 'lister-github'
+    ADDITIONAL_CONFIG = {
+        'lister_db_url': ('str', 'postgresql:///lister-github'),
+        'credentials': ('list[dict]', []),
+        'cache_json': ('bool', False),
+        'cache_dir': ('str', '~/.cache/swh/lister/github'),
+    }
 
+    def __init__(self, override_config=None):
+        super().__init__()
+        if override_config:
+            self.config.update(override_config)
 
-def last_repo_id(db_session):
-    t = db_session.query(func.max(Repository.id)) \
-                  .first()
-    if t is not None:
-        return t[0]
-    # else: return None
+        self.config['cache_dir'] = os.path.expanduser(self.config['cache_dir'])
+        if self.config['cache_json']:
+            config.prepare_folders(self.config, ['cache_dir'])
 
+        if not self.config['credentials']:
+            raise ValueError('The GitHub lister needs credentials for API')
 
-INJECT_KEYS = ['id', 'name', 'full_name', 'html_url', 'description', 'fork']
+        self.db_engine = create_engine(self.config['lister_db_url'])
+        self.mk_session = sessionmaker(bind=self.db_engine)
 
+    def lookup_repo(self, repo_id, db_session=None):
+        if not db_session:
+            with session_scope(self.mk_session) as db_session:
+                return self.lookup_repo(repo_id, db_session=db_session)
 
-def inject_repo(db_session, repo):
-    logging.debug('injecting repo %d' % repo['id'])
-    sql_repo = lookup_repo(db_session, repo['id'])
-    if not sql_repo:
-        kwargs = {k: repo[k] for k in INJECT_KEYS if k in repo}
-        sql_repo = Repository(**kwargs)
-        db_session.add(sql_repo)
-    else:
-        for k in INJECT_KEYS:
-            if k in repo:
-                setattr(sql_repo, k, repo[k])
-    sql_repo.last_seen = datetime.datetime.now()
+        return db_session.query(Repository) \
+                         .filter(Repository.id == repo_id) \
+                         .first()
 
+    def last_repo_id(self, db_session=None):
+        if not db_session:
+            with session_scope(self.mk_session) as db_session:
+                return self.last_repo_id(db_session=db_session)
 
-class FetchError(RuntimeError):
+        t = db_session.query(func.max(Repository.id)).first()
 
-    def __init__(self, response):
-        self.response = response
+        if t is not None:
+            return t[0]
 
-    def __str__(self):
-        return repr(self.response)
+    INJECT_KEYS = ['id', 'name', 'full_name', 'html_url', 'description',
+                   'fork']
 
-
-def fetch(conf, mk_session, min_id=None, max_id=None):
-    if min_id is None:
-        min_id = 1
-    if max_id is None:
-        max_id = float('inf')
-    next_id = min_id
-
-    session = requests.Session()
-    db_session = mk_session()
-    loop_count = 0
-    while min_id <= next_id <= max_id:
-        logging.info('listing repos starting at %d' % next_id)
-        since = next_id - 1  # github API ?since=... is '>' strict, not '>='
-
-        cred = random.choice(conf['credentials'])
-        repos_res = gh_api_request('/repositories?since=%d' % since,
-                                   session=session, **cred)
-
-        if 'cache_dir' in conf and conf['cache_json']:
-            save_http_response(repos_res, conf['cache_dir'])
-        if not repos_res.ok:
-            raise FetchError(repos_res)
-
-        repos = repos_res.json()
-        for repo in repos:
-            if repo['id'] > max_id:  # do not overstep max_id
-                break
-            inject_repo(db_session, repo)
-
-        if 'next' in repos_res.links:
-            next_url = repos_res.links['next']['url']
-            m = REPO_API_URL_RE.match(next_url)  # parse next_id
-            next_id = int(m.group(1)) + 1
+    def inject_repo(self, db_session, repo):
+        logging.debug('injecting repo %d' % repo['id'])
+        sql_repo = self.lookup_repo(repo['id'], db_session)
+        if not sql_repo:
+            kwargs = {k: repo[k] for k in self.INJECT_KEYS if k in repo}
+            sql_repo = Repository(**kwargs)
+            db_session.add(sql_repo)
         else:
-            logging.info('stopping after id %d, no next link found' % next_id)
-            break
-        loop_count += 1
-        if loop_count == 20:
-            logging.info('flushing updates')
-            loop_count = 0
-            db_session.commit()
-            db_session = mk_session()
-    db_session.commit()
+            for k in self.INJECT_KEYS:
+                if k in repo:
+                    setattr(sql_repo, k, repo[k])
+        sql_repo.last_seen = datetime.datetime.now()
+
+    def fetch(self, min_id=None, max_id=None):
+        if min_id is None:
+            min_id = 1
+        if max_id is None:
+            max_id = float('inf')
+        next_id = min_id
+
+        do_cache = self.config['cache_json']
+        cache_dir = self.config['cache_dir']
+
+        session = requests.Session()
+        db_session = self.mk_session()
+        loop_count = 0
+        while min_id <= next_id <= max_id:
+            logging.info('listing repos starting at %d' % next_id)
+
+            # github API ?since=... is '>' strict, not '>='
+            since = next_id - 1
+
+            cred = random.choice(self.config['credentials'])
+            repos_res = gh_api_request('/repositories?since=%d' % since,
+                                       session=session, **cred)
+
+            if do_cache:
+                save_http_response(repos_res, cache_dir)
+
+            if not repos_res.ok:
+                raise FetchError(repos_res)
+
+            repos = repos_res.json()
+            for repo in repos:
+                if repo['id'] > max_id:  # do not overstep max_id
+                    break
+                self.inject_repo(db_session, repo)
+
+            if 'next' in repos_res.links:
+                next_url = repos_res.links['next']['url']
+                m = REPO_API_URL_RE.match(next_url)  # parse next_id
+                next_id = int(m.group(1)) + 1
+            else:
+                logging.info('stopping after id %d, no next link found' %
+                             next_id)
+                break
+
+            loop_count += 1
+            if loop_count == 20:
+                logging.info('flushing updates')
+                loop_count = 0
+                db_session.commit()
+                db_session = self.mk_session()
+
+        db_session.commit()
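Usage sketch (not part of the patch): how the refactored GitHubLister is driven, mirroring the createdb/list/catchup actions handled in bin/ghlister above. The database URL below is just the ADDITIONAL_CONFIG default, the interval bounds are arbitrary, and the lister-github configuration file is assumed to provide the 'credentials' list (otherwise __init__ raises ValueError):

    # Hypothetical driver snippet; mirrors bin/ghlister, not part of this diff.
    from swh.lister.github import models
    from swh.lister.github.lister import GitHubLister

    lister = GitHubLister({'lister_db_url': 'postgresql:///lister-github'})

    # 'createdb' action: create the lister schema on the configured database
    models.SQLBase.metadata.create_all(lister.db_engine)

    # 'list' action: list a bounded interval of GitHub repository ids
    lister.fetch(min_id=1, max_id=1000)

    # 'catchup' action: resume right after the last repository id already stored
    last_known_id = lister.last_repo_id()
    if last_known_id is not None:
        lister.fetch(min_id=last_known_id + 1, max_id=None)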