Commit 9b702154 authored by Lukas Jelonek's avatar Lukas Jelonek
Browse files

Refactor string handling into main method, away from resolve and retrieve

parent 24b3c3d5
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import argparse import argparse
import os import os
import logging import logging
from dbxref import resolver
def main(): def main():
parser = argparse.ArgumentParser(description='Lookup locations of database cross references and retrieve them as json') parser = argparse.ArgumentParser(description='Lookup locations of database cross references and retrieve them as json')
...@@ -38,13 +39,12 @@ def info(args, config): ...@@ -38,13 +39,12 @@ def info(args, config):
print ('info') print ('info')
def resolve(args, config): def resolve(args, config):
from dbxref import resolver
import json import json
print(json.dumps(resolver.resolve(args.dbxrefs, check_existence=args.no_check))) print(json.dumps(resolver.resolve(resolver.convert_to_dbxrefs(args.dbxrefs), check_existence=args.no_check)))
def retrieve(args, config): def retrieve(args, config):
from dbxref import retriever from dbxref import retriever
retriever.retrieve(args.dbxrefs) retriever.retrieve(resolver.convert_to_dbxrefs(args.dbxrefs))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -18,13 +18,12 @@ STATUS_CHECK_NOT_SUPPORTED='check of status not supported' ...@@ -18,13 +18,12 @@ STATUS_CHECK_NOT_SUPPORTED='check of status not supported'
STATUS_CHECK_TIMEOUT='status check timed out' STATUS_CHECK_TIMEOUT='status check timed out'
STATUS_UNSUPPORTED_DB='database unsupported' STATUS_UNSUPPORTED_DB='database unsupported'
def resolve(strings, check_existence=True): def resolve(dbxrefs, check_existence=True):
results = [] results = []
for s in strings: for dbxref in dbxrefs:
status = STATUS_NOT_CHECKED status = STATUS_NOT_CHECKED
if check_existence: if check_existence:
status = check_dbxref_exists(s) status = check_dbxref_exists(dbxref)
dbxref = convert_string_to_dbxref(s)
if dbxref['db'] in providers: if dbxref['db'] in providers:
provider = providers[dbxref['db']] provider = providers[dbxref['db']]
locations = {} locations = {}
...@@ -38,8 +37,11 @@ def resolve(strings, check_existence=True): ...@@ -38,8 +37,11 @@ def resolve(strings, check_existence=True):
results.append({'dbxref': dbxref['db'] + ':' + dbxref['id'], 'status': STATUS_UNSUPPORTED_DB}) results.append({'dbxref': dbxref['db'] + ':' + dbxref['id'], 'status': STATUS_UNSUPPORTED_DB})
return results return results
def check_dbxref_exists(string): def convert_to_dbxrefs(strings):
dbxref = convert_string_to_dbxref(string) '''convert a list of strings to dbxref maps with db and id attribute'''
return list(map(convert_string_to_dbxref, strings))
def check_dbxref_exists(dbxref):
if dbxref['db'] in providers: if dbxref['db'] in providers:
provider = providers[dbxref['db']] provider = providers[dbxref['db']]
urls = [] urls = []
......
...@@ -2,14 +2,12 @@ import logging ...@@ -2,14 +2,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from dbxref.config import load_providers from dbxref.config import load_providers
from dbxref.resolver import convert_string_to_dbxref
from itertools import groupby from itertools import groupby
import json import json
providers = load_providers() providers = load_providers()
def retrieve(strings, location=''): def retrieve(dbxrefs, location=''):
dbxrefs = list(map(convert_string_to_dbxref, strings))
sorted(dbxrefs, key=lambda x: x['db']) sorted(dbxrefs, key=lambda x: x['db'])
results = [] results = []
for key, dbxrefs in groupby(dbxrefs, lambda x: x['db']): for key, dbxrefs in groupby(dbxrefs, lambda x: x['db']):
......
...@@ -18,7 +18,7 @@ class TestDbxrefResolve(unittest.TestCase): ...@@ -18,7 +18,7 @@ class TestDbxrefResolve(unittest.TestCase):
def test_resolve_enzyme(self): def test_resolve_enzyme(self):
self.assertNotEqual(resolver.resolve(["EC:1.1.1.1"]), []) self.assertNotEqual(resolver.resolve(resolver.convert_to_dbxrefs(["EC:1.1.1.1"])), [])
def test_check_dbxref_exists(self): def test_check_dbxref_exists(self):
import logging import logging
...@@ -67,7 +67,7 @@ class TestDbxrefResolve(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestDbxrefResolve(unittest.TestCase):
for d in data: for d in data:
with self.subTest(d=d): with self.subTest(d=d):
self.assertEqual(resolver.check_dbxref_exists(d[0]), d[1] ) self.assertEqual(resolver.check_dbxref_exists(resolver.convert_string_to_dbxref(d[0])), d[1] )
def test_check_urls(self): def test_check_urls(self):
import json import json
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment