Commit 9b702154 authored by Lukas Jelonek's avatar Lukas Jelonek
Browse files

Refactor string handling into main method, away from resolve and retrieve

parent 24b3c3d5
......@@ -2,6 +2,7 @@
import argparse
import os
import logging
from dbxref import resolver
def main():
parser = argparse.ArgumentParser(description='Lookup locations of database cross references and retrieve them as json')
......@@ -38,13 +39,12 @@ def info(args, config):
print ('info')
def resolve(args, config):
from dbxref import resolver
import json
print(json.dumps(resolver.resolve(args.dbxrefs, check_existence=args.no_check)))
print(json.dumps(resolver.resolve(resolver.convert_to_dbxrefs(args.dbxrefs), check_existence=args.no_check)))
def retrieve(args, config):
from dbxref import retriever
retriever.retrieve(args.dbxrefs)
retriever.retrieve(resolver.convert_to_dbxrefs(args.dbxrefs))
if __name__ == "__main__":
main()
......@@ -18,13 +18,12 @@ STATUS_CHECK_NOT_SUPPORTED='check of status not supported'
STATUS_CHECK_TIMEOUT='status check timed out'
STATUS_UNSUPPORTED_DB='database unsupported'
def resolve(strings, check_existence=True):
def resolve(dbxrefs, check_existence=True):
results = []
for s in strings:
for dbxref in dbxrefs:
status = STATUS_NOT_CHECKED
if check_existence:
status = check_dbxref_exists(s)
dbxref = convert_string_to_dbxref(s)
status = check_dbxref_exists(dbxref)
if dbxref['db'] in providers:
provider = providers[dbxref['db']]
locations = {}
......@@ -38,8 +37,11 @@ def resolve(strings, check_existence=True):
results.append({'dbxref': dbxref['db'] + ':' + dbxref['id'], 'status': STATUS_UNSUPPORTED_DB})
return results
def check_dbxref_exists(string):
dbxref = convert_string_to_dbxref(string)
def convert_to_dbxrefs(strings):
'''convert a list of strings to dbxref maps with db and id attribute'''
return list(map(convert_string_to_dbxref, strings))
def check_dbxref_exists(dbxref):
if dbxref['db'] in providers:
provider = providers[dbxref['db']]
urls = []
......
......@@ -2,14 +2,12 @@ import logging
logger = logging.getLogger(__name__)
from dbxref.config import load_providers
from dbxref.resolver import convert_string_to_dbxref
from itertools import groupby
import json
providers = load_providers()
def retrieve(strings, location=''):
dbxrefs = list(map(convert_string_to_dbxref, strings))
def retrieve(dbxrefs, location=''):
sorted(dbxrefs, key=lambda x: x['db'])
results = []
for key, dbxrefs in groupby(dbxrefs, lambda x: x['db']):
......
......@@ -18,7 +18,7 @@ class TestDbxrefResolve(unittest.TestCase):
def test_resolve_enzyme(self):
self.assertNotEqual(resolver.resolve(["EC:1.1.1.1"]), [])
self.assertNotEqual(resolver.resolve(resolver.convert_to_dbxrefs(["EC:1.1.1.1"])), [])
def test_check_dbxref_exists(self):
import logging
......@@ -67,7 +67,7 @@ class TestDbxrefResolve(unittest.TestCase):
for d in data:
with self.subTest(d=d):
self.assertEqual(resolver.check_dbxref_exists(d[0]), d[1] )
self.assertEqual(resolver.check_dbxref_exists(resolver.convert_string_to_dbxref(d[0])), d[1] )
def test_check_urls(self):
import json
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment