Skip to content

Commit

Permalink
Merge pull request #298 from Uninett/api/add-pm-subcommands
Browse files Browse the repository at this point in the history
Add PM subcommands to legacy API
  • Loading branch information
lunkwill42 authored Jul 9, 2024
2 parents 4c5afad + ad93dbc commit 66c22f0
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 5 deletions.
1 change: 1 addition & 0 deletions changelog.d/298.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added the PM family of API commands to manipulate planned maintenance
153 changes: 150 additions & 3 deletions src/zino/api/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import re
import textwrap
from datetime import datetime, timezone
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, NamedTuple, Optional, Union
Expand All @@ -17,7 +18,16 @@
from zino.api import auth
from zino.api.notify import Zino1NotificationProtocol
from zino.state import ZinoState, config
from zino.statemodels import ClosedEventError, Event, EventState
from zino.statemodels import (
ClosedEventError,
DeviceMaintenance,
Event,
EventState,
MatchType,
PlannedMaintenance,
PortStateMaintenance,
)
from zino.time import now

if TYPE_CHECKING:
from zino.api.server import ZinoServer
Expand Down Expand Up @@ -135,11 +145,18 @@ def _dispatch_command(self, message: str):
if getattr(responder.function, "requires_authentication", False) and not self.is_authenticated:
return self._respond_error("Not authenticated")

required_args = inspect.signature(responder.function).parameters
required_args = {
name: param
for name, param in inspect.signature(responder.function).parameters.items()
if param.kind == param.POSITIONAL_OR_KEYWORD
}
has_variable_args = any(
param.kind == param.VAR_POSITIONAL for param in inspect.signature(responder.function).parameters.values()
)
if len(args) < len(required_args):
arg_summary = " (" + ", ".join(required_args.keys()) + ")" if required_args else ""
return self._respond_error(f"{responder.name} needs {len(required_args)} parameters{arg_summary}")
elif len(args) > len(required_args):
elif not has_variable_args and len(args) > len(required_args):
garbage_args = args[len(required_args) :]
_logger.debug("client %s sent %r, ignoring garbage args at end: %r", self.peer_name, args, garbage_args)
args = args[: len(required_args)]
Expand Down Expand Up @@ -388,6 +405,136 @@ async def do_ntie(self, nonce: str):

return self._respond_ok()

def _translate_pm_id_to_pm(responder: callable): # noqa
"""Decorates any command that works with planned maintenance adding verification of the
incoming pm_id argument and translation to an actual PlannedMaintenance object.
"""

@wraps(responder)
def _verify(self, pm_id: Union[str, int], *args, **kwargs):
try:
pm_id = int(pm_id)
pm = self._state.planned_maintenances[pm_id]
except (ValueError, KeyError):
self._respond_error(f'pm "{pm_id}" does not exist')
response = asyncio.get_running_loop().create_future()
response.set_result(None)
return response
return responder(self, pm, *args, **kwargs)

return _verify

@requires_authentication
async def do_pm(self):
"""Implements the top-level PM command.
In the original Zino, this has its own dispatcher, and calling it without arguments only results an error.
"""
return self._respond_error("PM command requires a subcommand")

@requires_authentication
async def do_pm_help(self):
"""Lists all available PM sub-commands"""
responders = (responder for name, responder in self._responders.items() if responder.name.startswith("PM "))
commands = " ".join(sorted(responder.name.removeprefix("PM ") for responder in responders))
self._respond_multiline(200, ["PM subcommands are:"] + textwrap.wrap(commands, width=56))

@requires_authentication
async def do_pm_list(self):
self._respond(300, "PM event ids follows, terminated with '.'")
for id in self._state.planned_maintenances.planned_maintenances:
self._respond_raw(id)
self._respond_raw(".")

@requires_authentication
@_translate_pm_id_to_pm
async def do_pm_cancel(self, pm: PlannedMaintenance):
self._state.planned_maintenances.close_planned_maintenance(pm.id, "PM cancelled", self.user)
self._respond_ok()

@requires_authentication
@_translate_pm_id_to_pm
async def do_pm_addlog(self, pm: PlannedMaintenance):
self._respond(302, "please provide new PM log entry, terminate with '.'")
data = await self._read_multiline()
message = f"{self.user}\n" + "\n".join(line.strip() for line in data)
pm.add_log(message)
self._respond_ok()

@requires_authentication
@_translate_pm_id_to_pm
async def do_pm_log(self, pm: PlannedMaintenance):
self._respond(300, "log follows, terminated with '.'")
for log in pm.log:
for line in log.model_dump_legacy():
self._respond_raw(line)
self._respond_raw(".")

@requires_authentication
@_translate_pm_id_to_pm
async def do_pm_details(self, pm: PlannedMaintenance):
self._respond(200, pm.details())

@requires_authentication
async def do_pm_add(self, from_t: Union[str, int], to_t: Union[str, int], pm_type: str, m_type: str, *args: str):
try:
start_time = datetime.fromtimestamp(int(from_t), tz=timezone.utc)
except ValueError:
return self._respond_error("illegal from_t (param 1), must be only digits")
try:
end_time = datetime.fromtimestamp(int(to_t), tz=timezone.utc)
except ValueError:
return self._respond_error("illegal to_t (param 2), must be only digits")
if end_time < start_time:
return self._respond_error("ending time is before starting time")
if start_time < now():
return self._respond_error("starting time is in the past")

if pm_type == "device":
pm_class = DeviceMaintenance
elif pm_type == "portstate":
pm_class = PortStateMaintenance
else:
return self._respond_error(f"unknown PM event type: {pm_type}")

try:
match_type = MatchType(m_type)
except ValueError:
return self._respond_error(f"unknown match type: {m_type}")

if match_type == MatchType.INTF_REGEXP:
if len(args) < 2:
return self._respond_error(
"{m_type} match type requires two extra arguments: match_device and match_expression"
)
match_device = args[0]
match_expression = args[1]
else:
if len(args) < 1:
return self._respond_error(f"{m_type} match type requires one extra argument: match_expression")
match_device = None
match_expression = args[0]

pm = self._state.planned_maintenances.create_planned_maintenance(
start_time,
end_time,
pm_class,
match_type,
match_expression,
match_device,
)
self._respond(200, f"PM id {pm.id} successfully added")

@requires_authentication
@_translate_pm_id_to_pm
async def do_pm_matching(self, pm: PlannedMaintenance):
matches = pm.get_matching(self._state)
self._respond(300, "Matching ports/devices follows, terminated with '.'")
for match in matches:
output = " ".join(str(i) for i in match)
self._respond_raw(output)
self._respond_raw(".")


class ZinoTestProtocol(Zino1ServerProtocol):
"""Extended Zino 1 server protocol with test commands added in"""
Expand Down
59 changes: 58 additions & 1 deletion src/zino/statemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@
from collections.abc import Generator
from enum import Enum
from ipaddress import IPv4Address, IPv6Address
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
TypeVar,
Union,
)

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -428,6 +439,25 @@ def add_log(self, message: str) -> LogEntry:
self.log.append(entry)
return entry

def details(self) -> str:
"""Returns a string with the details of the object.
Format from zino1: $id $from_t $to_t $type $match_type [$match_dev] $match_expr
"""
details = [
str(int(attr.timestamp())) if isinstance(attr, datetime.datetime) else str(attr)
for attr in [
self.id,
self.start_time,
self.end_time,
self.type,
self.match_type,
self.match_device,
self.match_expression,
]
if attr
]
return " ".join(details)

def matches_event(self, event: Event, state: "ZinoState") -> bool:
"""Returns true if `event` will be affected by this planned maintenance"""
raise NotImplementedError
Expand All @@ -438,6 +468,14 @@ def _get_or_create_events(self, state: "ZinoState") -> list[Event]:
"""
raise NotImplementedError

def get_matching(self, state: "ZinoState") -> Iterator[Sequence[Union[str, int]]]:
"""Returns a list of matching devices or ports from Zino state.
The number of elements of each sequence of the return value depends on the type of planned maintenance
objects, but each entry should be suitable to join on space and output to the legacy API.
"""
raise NotImplementedError


class DeviceMaintenance(PlannedMaintenance):
type: Literal[PmType.DEVICE] = PmType.DEVICE
Expand All @@ -461,6 +499,15 @@ def matches_device(self, device: DeviceState) -> bool:
return self.match_expression == device.name
return False

def get_matching(self, state: "ZinoState") -> Iterator[Sequence[Union[str, int]]]:
"""Returns a list of matching devices from Zino state.
Each element is a sequence of (pm_id, "device", device_name)
"""
for device in state.devices.devices.values():
if self.matches_device(device):
yield self.id, self.type, device.name

def _get_or_create_events(self, state: "ZinoState") -> list[Event]:
"""Creates/gets events that are affected by the given starting planned
maintenance
Expand Down Expand Up @@ -503,6 +550,16 @@ def matches_portstate(self, device: DeviceState, port: Port) -> bool:
return regex_match(self.match_expression, port.ifdescr)
return False

def get_matching(self, state: "ZinoState") -> Iterator[Sequence[Union[str, int]]]:
"""Returns a list of matching devices from Zino state.
Each element is a sequence of (pm_id, "portstate", device_name, ifIndex, ifDescr, f"({ifAlias})")
"""
for device in state.devices.devices.values():
for port in device.ports.values():
if self.matches_portstate(device, port):
yield self.id, self.type, device.name, port.ifindex, port.ifdescr, f"({port.ifalias})"

def _get_or_create_events(self, state: "ZinoState") -> list[Event]:
events = []
for device, port in self._get_matching_ports(state):
Expand Down
58 changes: 57 additions & 1 deletion tests/api/legacy_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re
from datetime import timedelta
from io import BytesIO
from unittest.mock import Mock, patch

Expand All @@ -14,7 +16,14 @@
from zino.api.server import ZinoServer
from zino.config.models import PollDevice
from zino.state import ZinoState
from zino.statemodels import Event, EventState, ReachabilityEvent
from zino.statemodels import (
DeviceMaintenance,
Event,
EventState,
MatchType,
ReachabilityEvent,
)
from zino.time import now


class TestZino1BaseServerProtocol:
Expand Down Expand Up @@ -696,6 +705,53 @@ def _read_multiline(self):
assert b"200 ok" in buffered_fake_transport.data_buffer.getvalue()


class TestZino1ServerProtocolPmCommand:
@pytest.mark.asyncio
async def test_it_should_always_return_a_500_error(self, authenticated_protocol):
await authenticated_protocol.message_received("PM")

assert b"500 " in authenticated_protocol.transport.data_buffer.getvalue()


class TestZino1ServerProtocolPmHelpCommand:
@pytest.mark.asyncio
async def test_when_authenticated_pm_help_is_issued_then_all_pm_subcommands_should_be_listed(
self, authenticated_protocol
):
await authenticated_protocol.message_received("PM HELP")

all_command_names = set(
responder.name.removeprefix("PM ")
for responder in authenticated_protocol._responders.values()
if responder.name.startswith("PM ")
)
for command_name in all_command_names:
assert (
command_name.encode() in authenticated_protocol.transport.data_buffer.getvalue()
), f"{command_name} is not listed in PM HELP"


class TestZino1ServerProtocolPmListCommand:
@pytest.mark.asyncio
async def test_when_authenticated_should_list_all_pm_ids(self, authenticated_protocol):
pms = authenticated_protocol._state.planned_maintenances
pms.create_planned_maintenance(
now() - timedelta(hours=1),
now() + timedelta(hours=1),
DeviceMaintenance,
MatchType.REGEXP,
"expr",
)
await authenticated_protocol.message_received("PM LIST")
response = authenticated_protocol.transport.data_buffer.getvalue().decode("utf-8")

assert re.search(r"\b300 \b", response), "Expected response to contain status code 300"

pattern_string = r"\b{}\b"
for id in pms.planned_maintenances:
assert re.search(pattern_string.format(id), response), f"Expected response to contain id {id}"


def test_requires_authentication_should_set_function_attribute():
@requires_authentication
def throwaway():
Expand Down

0 comments on commit 66c22f0

Please sign in to comment.