Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py-tx] Add a new match command line option (--rotations) #1672

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 91 additions & 15 deletions python-threatexchange/threatexchange/cli/match_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@
Match command for parsing simple data sources against the dataset.
"""

from dataclasses import dataclass, field
import argparse
import logging
import pathlib
import typing as t

import tempfile

from threatexchange import common
from threatexchange.cli.fetch_cmd import FetchCommand
from threatexchange.cli.helpers import FlexFilesInputAction
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata

from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex
from threatexchange.signal_type.index import (
IndexMatch,
SignalTypeIndex,
IndexMatchUntyped,
SignalSimilarityInfo,
T,
)
from threatexchange.cli.exceptions import CommandError
from threatexchange.signal_type.signal_base import BytesHasher, SignalType
from threatexchange.cli.cli_config import CLISettings
from threatexchange.content_type.content_base import ContentType
from threatexchange.content_type.content_base import ContentType, RotationType
from threatexchange.content_type.photo import PhotoContent

from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher
from threatexchange.cli import command_base
Expand All @@ -29,6 +37,19 @@
TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]]


@dataclass
class _IndexMatchWithRotation(t.Generic[T]):
    """
    An index match paired with the (optional) rotation it was found under.

    `rotation_type` stays None when no rotation is associated with the match.
    """

    match: IndexMatchUntyped[SignalSimilarityInfo, T]
    rotation_type: t.Optional[RotationType] = None

    def __str__(self):
        # pretty_str() is expected to contain no whitespace; strip any
        # that slips through so the result stays a single token.
        compact = "".join(self.match.similarity_info.pretty_str().split())
        if self.rotation_type is not None:
            # Tag the distance with the rotation's enum name.
            return f"{compact} [{self.rotation_type.name}]"
        return compact


class MatchCommand(command_base.Command):
"""
Match content to fetched signals
Expand Down Expand Up @@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No
action="store_true",
help="show all matches, not just one per collaboration",
)
ap.add_argument(
"--rotations",
"-R",
action="store_true",
help="for photos, generate and match all 8 simple rotations",
)

def __init__(
self,
Expand All @@ -136,6 +163,7 @@ def __init__(
show_false_positives: bool,
hide_disputed: bool,
all: bool,
rotations: bool = False,
) -> None:
self.content_type = content_type
self.only_signal = only_signal
Expand All @@ -144,6 +172,7 @@ def __init__(
self.hide_disputed = hide_disputed
self.files = files
self.all = all
self.rotations = rotations

if only_signal and content_type not in only_signal.get_content_types():
raise CommandError(
Expand All @@ -152,6 +181,11 @@ def __init__(
2,
)

if self.rotations and not issubclass(content_type, PhotoContent):
raise CommandError(
"--rotations flag is only available for Photo content type", 2
)

def execute(self, settings: CLISettings) -> None:
if not settings.index.list():
if not settings.in_demo_mode:
Expand Down Expand Up @@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None:
for s_type, index in indices:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before processing this list of files, if rotations is true here, generate new files that you iterate through.

Here's one way to do that

def handle_rotations() -> Iterator[t.Tuple[Path, t.Optional[RotationType]]]:
  for file in self.files:
    if not self.rotations:
      yield file, None
      continue
    for rot, dat in PhotoContent.blah():
      with NamedTemporary(...) as f:
        yield Path(f.name), rot

for path in handle_rotations(self.files):

seen = set() # TODO - maybe take the highest certainty?
if self.as_hashes:
results = _match_hashes(path, s_type, index)
results: t.Sequence[_IndexMatchWithRotation] = _match_hashes(
path, s_type, index
)
else:
results = _match_file(path, s_type, index)
results = _match_file(path, s_type, index, rotations=self.rotations)

for r in results:
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = (
r.match.metadata
)
distance_str = str(r)

for collab, fetched_data in metadatas:
if not self.all and collab in seen:
continue
seen.add(collab)
# Supposed to be without whitespace, but let's make sure
distance_str = "".join(r.similarity_info.pretty_str().split())

print(
s_type.get_name(),
distance_str,
Expand All @@ -217,18 +256,54 @@ def execute(self, settings: CLISettings) -> None:


def _match_file(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
) -> t.Sequence[IndexMatch]:
path: pathlib.Path,
s_type: t.Type[SignalType],
index: SignalTypeIndex,
rotations: bool = False,
) -> t.Sequence[_IndexMatchWithRotation]:
if issubclass(s_type, MatchesStr):
return index.query(path.read_text())
matches = index.query(path.read_text())
return [_IndexMatchWithRotation(match=match) for match in matches]

assert issubclass(s_type, FileHasher)
return index.query(s_type.hash_from_file(path))

if not rotations or s_type != PhotoContent:
matches = index.query(s_type.hash_from_file(path))
return [_IndexMatchWithRotation(match=match) for match in matches]

# Handle rotations for photos
with open(path, "rb") as f:
image_data = f.read()

rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations(
image_data
)
all_matches = []

for rotation_type, rotated_bytes in rotated_images.items():
# Create a temporary file to hold the image bytes
with tempfile.NamedTemporaryFile() as temp_file:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking: This is going to write a lot of files over the course of an execution! We are going to call this method for every single match type.

Another approach would be to refactor this so that the rotated images are generated higher up in the stack; then, rather than taking rotations: bool, this function could take the path of the rotated image and an optional enum value indicating which rotation it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of changing s_type to be a subclass of BytesHasher so that I can use the bytes directly with hash_from_bytes. All the tests pass and mypy doesn't complain. What do you think of this approach?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the last PR, I mentioned the downside - not all photo formats are knowable without the extension. There's likely a workaround, but let's stay with the current course and see if we can fix it in a followup instead.

temp_file.write(rotated_bytes)
temp_file_path = pathlib.Path(temp_file.name)
matches = index.query(s_type.hash_from_file(temp_file_path))
temp_file_path.unlink() # Clean up the temporary file
haianhng31 marked this conversation as resolved.
Show resolved Hide resolved

# Add rotation information if any matches were found
matches_with_rotations = []
for match in matches:
matches_with_rotations.append(
_IndexMatchWithRotation(match=match, rotation_type=rotation_type)
)

all_matches.extend(matches_with_rotations)

return all_matches


def _match_hashes(
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex
) -> t.Sequence[IndexMatch]:
ret: t.List[IndexMatch] = []
) -> t.Sequence[_IndexMatchWithRotation]:
ret: t.List[_IndexMatchWithRotation] = []
for hash in path.read_text().splitlines():
hash = hash.strip()
if not hash:
Expand All @@ -244,5 +319,6 @@ def _match_hashes(
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}",
2,
)
ret.extend(index.query(hash))
matches = index.query(hash)
ret.extend([_IndexMatchWithRotation(match=match) for match in matches])
return ret
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
This records all the valid signal types for a piece of content.
"""

from enum import Enum, auto
from enum import Enum
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

import typing as t

from threatexchange import common
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
from PIL import Image
import io
import typing as t

from .content_base import ContentType, RotationType

Expand Down Expand Up @@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes:
return buffer.getvalue()

@classmethod
def all_simple_rotations(cls, image_data: bytes):
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

"""
Generate the 8 naive rotations of an image.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import pickle
import typing as t


T = t.TypeVar("T")
S_Co = t.TypeVar("S_Co", covariant=True, bound="SignalSimilarityInfo")
CT = t.TypeVar("CT", bound="Comparable")
Expand Down