Skip to content

Commit

Permalink
new version of panther_base_helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
nskobov committed Feb 13, 2024
1 parent 2fd19c9 commit 24dd8c0
Showing 1 changed file with 94 additions and 19 deletions.
113 changes: 94 additions & 19 deletions global_helpers/panther_base_helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
import re
from collections import OrderedDict
from collections.abc import Mapping
from datetime import datetime
from fnmatch import fnmatch
from functools import reduce
from ipaddress import ip_address, ip_network
from typing import Sequence
from typing import Any, List, Optional, Sequence, Union

from panther_config import config

# # # # # # # # # # # # # #
# Exceptions #
Expand Down Expand Up @@ -33,47 +37,36 @@ def in_pci_scope_tags(resource):
return resource["Tags"].get(CDE_TAG_KEY) == CDE_TAG_VALUE


PCI_NETWORKS = config.PCI_NETWORKS
# Expects a string in cidr notation (e.g. '10.0.0.0/24') indicating the ip range being checked
# Returns True if any ip in the range is marked as in scope
PCI_NETWORKS = [
ip_network("10.0.0.0/24"),
]


def is_pci_scope_cidr(ip_range):
return any(ip_network(ip_range).overlaps(pci_network) for pci_network in PCI_NETWORKS)


DMZ_NETWORKS = config.DMZ_NETWORKS
# Expects a string in cidr notation (e.g. '10.0.0.0/24') indicating the ip range being checked
# Returns True if any ip in the range is marked as DMZ space.
DMZ_NETWORKS = [
ip_network("10.1.0.0/24"),
ip_network("100.1.0.0/24"),
]


def is_dmz_cidr(ip_range):
"""This function determines whether a given IP range is within the defined DMZ IP range."""
return any(ip_network(ip_range).overlaps(dmz_network) for dmz_network in DMZ_NETWORKS)


DMZ_TAG_KEY = "environment"
DMZ_TAG_VALUE = "dmz"


# Defaults to False to assume something is not a DMZ if it is not tagged
def is_dmz_tags(resource):
def is_dmz_tags(resource, dmz_tags):
"""This function determines whether a given resource is tagged as existing in a DMZ."""
if resource["Tags"] is None:
return False
return resource["Tags"].get(DMZ_TAG_KEY) == DMZ_TAG_VALUE
for key, value in dmz_tags:
if resource["Tags"].get(key) == value:
return True
return False


# Function variables here so that implementation details of these functions can be changed without
# having to rename the function in all locations its used, or having an outdated name on the actual
# function being used, etc.
IN_PCI_SCOPE = in_pci_scope_tags
IS_DMZ = is_dmz_tags

# # # # # # # # # # # # # #
# GSuite Helpers #
Expand Down Expand Up @@ -217,6 +210,7 @@ def okta_alert_context(event: dict):
def crowdstrike_detection_alert_context(event: dict):
"""Returns common context for Crowdstrike detections"""
return {
"aid": get_crowdstrike_field(event, "aid", default=""),
"user": get_crowdstrike_field(event, "UserName", default=""),
"console-link": get_crowdstrike_field(event, "FalconHostLink", default=""),
"commandline": get_crowdstrike_field(event, "CommandLine", default=""),
Expand Down Expand Up @@ -309,6 +303,68 @@ def deep_get(dictionary: dict, *keys, default=None):
return out


# pylint: disable=too-complex,too-many-return-statements
def deep_walk(
obj: Optional[Any], *keys: str, default: Optional[str] = None, return_val: str = "all"
) -> Union[Optional[Any], Optional[List[Any]]]:
"""Safely retrieve a value stored in complex dictionary structure
Similar to deep_get but supports accessing dictionary keys within nested lists as well
Parameters:
obj (any): the original log event passed to rule(event)
and nested objects retrieved recursively
keys (str): comma-separated list of keys used to traverse the event object
default (str): the default value to return if the desired key's value is not present
return_val (str): string specifying which value to return
possible values are "first", "last", or "all"
Returns:
any | list[any]: A single value if return_val is "first", "last",
or if "all" is a list containing one element,
otherwise a list of values
"""

def _empty_list(sub_obj: Any):
return (
all(_empty_list(next_obj) for next_obj in sub_obj)
if isinstance(sub_obj, Sequence) and not isinstance(sub_obj, str)
else False
)

if not keys:
return default if _empty_list(obj) else obj

current_key = keys[0]
found: OrderedDict = OrderedDict()

if isinstance(obj, Mapping):
next_key = obj.get(current_key, None)
return (
deep_walk(next_key, *keys[1:], default=default, return_val=return_val)
if next_key is not None
else default
)
if isinstance(obj, Sequence) and not isinstance(obj, str):
for item in obj:
value = deep_walk(item, *keys, default=default, return_val=return_val)
if value is not None:
if isinstance(value, Sequence) and not isinstance(value, str):
for sub_item in value:
found[sub_item] = None
else:
found[value] = None

found_list: list[Any] = list(found.keys())
if not found_list:
return default
return {
"first": found_list[0],
"last": found_list[-1],
"all": found_list[0] if len(found_list) == 1 else found_list,
}.get(return_val, "all")


def get_val_from_list(list_of_dicts, return_field_key, field_cmp_key, field_cmp_val):
"""Return a specific field in a list of Python dictionaries.
We return the empty set if the comparison key is not found"""
Expand Down Expand Up @@ -431,3 +487,22 @@ def defang_ioc(ioc):
"""return defanged IOC from 1.1.1.1 to 1[.]1[.]1[.]1"""
return ioc.replace(".", "[.]")


def panther_nanotime_to_python_datetime(panther_time: str) -> datetime:
panther_time_micros = re.search(r"\.(\d+)", panther_time).group(1)
panther_time_micros_rounded = panther_time_micros[0:6]
panther_time_rounded = re.sub(r"\.\d+", f".{panther_time_micros_rounded}", panther_time)
panther_time_format = r"%Y-%m-%d %H:%M:%S.%f"
return datetime.strptime(panther_time_rounded, panther_time_format)


def golang_nanotime_to_python_datetime(golang_time: str) -> datetime:
golang_time_format = r"%Y-%m-%dT%H:%M:%S.%fZ"
# Golang fractional seconds include a mix of microseconds and
# nanoseconds, which doesn't play well with Python's microseconds datetimes.
# This rounds the fractional seconds to a microsecond-size.
golang_time_micros = re.search(r"\.(\d+)Z", golang_time).group(1)
golang_time_micros_rounded = golang_time_micros[0:6]
golang_time_rounded = re.sub(r"\.\d+Z", f".{golang_time_micros_rounded}Z", golang_time)
return datetime.strptime(golang_time_rounded, golang_time_format)

0 comments on commit 24dd8c0

Please sign in to comment.