Skip to content

Commit

Permalink
feat: Narrow KeyError to MissingRequiredArgument
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmckinney committed Jul 16, 2024
1 parent f1cef60 commit 616272c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/news.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ API
^^^

- If the ``egg`` parameter to the :ref:`addversion.json`` webservice is not a ZIP file, use the error message, "egg is not a ZIP file (if using curl, use egg=@path not egg=path)".
- Clarify some error messages: for example, ``KeyError: 'project' (missing required parameter?)`` instead of ``'project'``, and ``exception class: message`` instead of ``message``.
- Clarify some error messages: for example, ``'project' parameter is required`` instead of ``'project'``, and ``exception class: message`` instead of ``message``.

CLI
^^^
Expand Down
6 changes: 6 additions & 0 deletions scrapyd/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class ScrapydError(Exception):
"""Base class for exceptions from within this package"""


class MissingRequiredArgument(ScrapydError):
"""Raised if a required argument is missing"""
43 changes: 29 additions & 14 deletions scrapyd/webservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,25 @@
from twisted.python import log
from twisted.web import http

from scrapyd.exceptions import MissingRequiredArgument
from scrapyd.jobstorage import job_items_url, job_log_url
from scrapyd.utils import JsonResource, UtilsCache, get_spider_list, native_stringify_dict


def _get_required_param(args, param):
try:
return args[param]
except KeyError as e:
raise MissingRequiredArgument(str(e))


def _pop_required_param(args, param):
try:
return args.pop(param)
except KeyError as e:
raise MissingRequiredArgument(str(e))


class WsResource(JsonResource):

def __init__(self, root):
Expand All @@ -25,8 +40,8 @@ def render(self, txrequest):
if self.root.debug:
return traceback.format_exc().encode('utf-8')
log.err()
if isinstance(e, KeyError):
message = f"KeyError: {e} (missing required parameter?)"
if isinstance(e, MissingRequiredArgument):
message = f"{e} parameter is required"
else:
message = f"{type(e).__name__}: {str(e)}"
r = {"node_name": self.root.nodename, "status": "error", "message": message}
Expand Down Expand Up @@ -66,8 +81,8 @@ def render_POST(self, txrequest):
settings = args.pop('setting', [])
settings = dict(x.split('=', 1) for x in settings)
args = {k: v[0] for k, v in args.items()}
project = args.pop('project')
spider = args.pop('spider')
project = _pop_required_param(args, 'project')
spider = _pop_required_param(args, 'spider')
version = args.get('_version', '')
priority = float(args.pop('priority', 0))
spiders = get_spider_list(project, version=version)
Expand All @@ -84,8 +99,8 @@ class Cancel(WsResource):

def render_POST(self, txrequest):
args = {k: v[0] for k, v in native_stringify_dict(copy(txrequest.args), keys_only=False).items()}
project = args['project']
jobid = args['job']
project = _get_required_param(args, 'project')
jobid = _get_required_param(args, 'job')
signal = args.get('signal', 'INT' if sys.platform != 'win32' else 'BREAK')
prevstate = None
queue = self.root.poller.queues[project]
Expand All @@ -103,12 +118,12 @@ def render_POST(self, txrequest):
class AddVersion(WsResource):

def render_POST(self, txrequest):
eggf = BytesIO(txrequest.args.pop(b'egg')[0])
eggf = BytesIO(_pop_required_param(txrequest.args, b'egg')[0])
if not zipfile.is_zipfile(eggf):
return {"status": "error", "message": "egg is not a ZIP file (if using curl, use egg=@path not egg=path)"}
args = native_stringify_dict(copy(txrequest.args), keys_only=False)
project = args['project'][0]
version = args['version'][0]
project = _get_required_param(args, 'project')[0]
version = _get_required_param(args, 'version')[0]
self.root.eggstorage.put(eggf, project, version)
spiders = get_spider_list(project, version=version)
self.root.update_projects()
Expand All @@ -128,7 +143,7 @@ class ListVersions(WsResource):

def render_GET(self, txrequest):
args = native_stringify_dict(copy(txrequest.args), keys_only=False)
project = args['project'][0]
project = _get_required_param(args, 'project')[0]
versions = self.root.eggstorage.list(project)
return {"node_name": self.root.nodename, "status": "ok", "versions": versions}

Expand All @@ -137,7 +152,7 @@ class ListSpiders(WsResource):

def render_GET(self, txrequest):
args = native_stringify_dict(copy(txrequest.args), keys_only=False)
project = args['project'][0]
project = _get_required_param(args, 'project')[0]
version = args.get('_version', [''])[0]
spiders = get_spider_list(project, runner=self.root.runner, version=version)
return {"node_name": self.root.nodename, "status": "ok", "spiders": spiders}
Expand Down Expand Up @@ -184,7 +199,7 @@ class DeleteProject(WsResource):

def render_POST(self, txrequest):
args = native_stringify_dict(copy(txrequest.args), keys_only=False)
project = args['project'][0]
project = _get_required_param(args, 'project')[0]
self._delete_version(project)
UtilsCache.invalid_cache(project)
return {"node_name": self.root.nodename, "status": "ok"}
Expand All @@ -198,8 +213,8 @@ class DeleteVersion(DeleteProject):

def render_POST(self, txrequest):
args = native_stringify_dict(copy(txrequest.args), keys_only=False)
project = args['project'][0]
version = args['version'][0]
project = _get_required_param(args, 'project')[0]
version = _get_required_param(args, 'version')[0]
self._delete_version(project, version)
UtilsCache.invalid_cache(project)
return {"node_name": self.root.nodename, "status": "ok"}

0 comments on commit 616272c

Please sign in to comment.