diff --git a/http_prompt/context/__init__.py b/http_prompt/context/__init__.py index f6445d0..cdd9e09 100644 --- a/http_prompt/context/__init__.py +++ b/http_prompt/context/__init__.py @@ -1,4 +1,5 @@ from http_prompt.tree import Node +from urllib.parse import urlparse class Context(object): @@ -15,10 +16,12 @@ def __init__(self, url=None, spec=None): # Create a tree for supporting API spec and ls command self.root = Node('root') if spec: + is_open_api = 'openapi' in spec + if not self.url: - self.url = spec.get('servers')[0].get('url') - if 'servers' in spec: - self.url = spec.get('servers')[0].get('url') + if is_open_api: + # In open api, the schemes are in the 'server' element, + self.url = spec.get('servers', [{'url': 'http://localhost:8000'}])[0].get('url') else: schemes = spec.get('schemes') scheme = schemes[0] if schemes else 'https' @@ -26,8 +29,15 @@ def __init__(self, url=None, spec=None): spec.get('host', 'http://localhost:8000') + spec.get('basePath', '')) - base_path_tokens = list(filter(lambda s: s, - spec.get('basePath', '').split('/'))) + # in open api, there is no 'basePath', we should extract that from the url + if is_open_api: + server = spec.get('servers', [{'url': 'http://localhost:8000'}])[0].get('url') + base_path_tokens = list(filter(lambda s: s, + urlparse(server).path.split('/'))) + else: + base_path_tokens = list(filter(lambda s: s, + spec.get('basePath', '').split('/'))) + paths = spec.get('paths') if paths: for path in paths: diff --git a/tests/test_cli.py b/tests/test_cli.py index d63b402..e82649b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -181,6 +181,28 @@ def test_spec_from_local_yml(self): self.assertEqual(set([n.name for n in context.root.children]), set(['users', 'orgs'])) + def test_spec_from_local_yml_openapi(self): + spec_filepath = self.make_tempfile(""" + openapi: "3.0.0" + servers: + - url: https://localhost:8080/ + paths: + /api/users: + get: + description: + /api/orgs: + get: + description: + """) + result, context = run_and_exit(['example.com/api', "--spec", + spec_filepath]) + self.assertEqual(result.exit_code, 0) + self.assertEqual(context.url, 'http://example.com/api') + self.assertEqual(set([n.name for n in context.root.children]), + set(['api'])) + self.assertEqual(set([n.name for n in context.root.ls('api')]), + set(['users', 'orgs'])) + def test_spec_basePath(self): spec_filepath = self.make_tempfile(json.dumps({ 'basePath': '/api/v1',