-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[py-tx] Add a new match command line option (--rotations) #1672
base: main
Are you sure you want to change the base?
Changes from 16 commits
26fe246
931b81a
dfe04d9
0e44874
bd532ae
13bf6e7
594782b
f8a5891
e570ef9
02ff144
a9db9d9
ad9a165
a1a39d3
6a6cea0
d841f60
3b83a41
dff8c1b
e367f61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,22 +5,30 @@ | |
Match command for parsing simple data sources against the dataset. | ||
""" | ||
|
||
from dataclasses import dataclass, field | ||
import argparse | ||
import logging | ||
import pathlib | ||
import typing as t | ||
|
||
import tempfile | ||
|
||
from threatexchange import common | ||
from threatexchange.cli.fetch_cmd import FetchCommand | ||
from threatexchange.cli.helpers import FlexFilesInputAction | ||
from threatexchange.exchanges.fetch_state import FetchedSignalMetadata | ||
|
||
from threatexchange.signal_type.index import IndexMatch, SignalTypeIndex | ||
from threatexchange.signal_type.index import ( | ||
IndexMatch, | ||
SignalTypeIndex, | ||
IndexMatchUntyped, | ||
SignalSimilarityInfo, | ||
T, | ||
) | ||
from threatexchange.cli.exceptions import CommandError | ||
from threatexchange.signal_type.signal_base import BytesHasher, SignalType | ||
from threatexchange.cli.cli_config import CLISettings | ||
from threatexchange.content_type.content_base import ContentType | ||
from threatexchange.content_type.content_base import ContentType, RotationType | ||
from threatexchange.content_type.photo import PhotoContent | ||
|
||
from threatexchange.signal_type.signal_base import MatchesStr, TextHasher, FileHasher | ||
from threatexchange.cli import command_base | ||
|
@@ -29,6 +37,19 @@ | |
TMatcher = t.Callable[[pathlib.Path], t.List[IndexMatch]] | ||
|
||
|
||
@dataclass | ||
class _IndexMatchWithRotation(t.Generic[T]): | ||
match: IndexMatchUntyped[SignalSimilarityInfo, T] | ||
rotation_type: t.Optional[RotationType] = field(default=None) | ||
|
||
def __str__(self): | ||
# Supposed to be without whitespace, but let's make sure | ||
distance_str = "".join(self.match.similarity_info.pretty_str().split()) | ||
if self.rotation_type is None: | ||
return distance_str | ||
return f"{distance_str} [{self.rotation_type.name}]" | ||
|
||
|
||
class MatchCommand(command_base.Command): | ||
""" | ||
Match content to fetched signals | ||
|
@@ -126,6 +147,12 @@ def init_argparse(cls, settings: CLISettings, ap: argparse.ArgumentParser) -> No | |
action="store_true", | ||
help="show all matches, not just one per collaboration", | ||
) | ||
ap.add_argument( | ||
"--rotations", | ||
"-R", | ||
action="store_true", | ||
help="for photos, generate and match all 8 simple rotations", | ||
) | ||
|
||
def __init__( | ||
self, | ||
|
@@ -136,6 +163,7 @@ def __init__( | |
show_false_positives: bool, | ||
hide_disputed: bool, | ||
all: bool, | ||
rotations: bool = False, | ||
) -> None: | ||
self.content_type = content_type | ||
self.only_signal = only_signal | ||
|
@@ -144,6 +172,7 @@ def __init__( | |
self.hide_disputed = hide_disputed | ||
self.files = files | ||
self.all = all | ||
self.rotations = rotations | ||
|
||
if only_signal and content_type not in only_signal.get_content_types(): | ||
raise CommandError( | ||
|
@@ -152,6 +181,11 @@ def __init__( | |
2, | ||
) | ||
|
||
if self.rotations and not issubclass(content_type, PhotoContent): | ||
raise CommandError( | ||
"--rotations flag is only available for Photo content type", 2 | ||
) | ||
|
||
def execute(self, settings: CLISettings) -> None: | ||
if not settings.index.list(): | ||
if not settings.in_demo_mode: | ||
|
@@ -196,18 +230,23 @@ def execute(self, settings: CLISettings) -> None: | |
for s_type, index in indices: | ||
seen = set() # TODO - maybe take the highest certainty? | ||
if self.as_hashes: | ||
results = _match_hashes(path, s_type, index) | ||
results: t.Sequence[_IndexMatchWithRotation] = _match_hashes( | ||
path, s_type, index | ||
) | ||
else: | ||
results = _match_file(path, s_type, index) | ||
results = _match_file(path, s_type, index, rotations=self.rotations) | ||
|
||
for r in results: | ||
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = r.metadata | ||
metadatas: t.List[t.Tuple[str, FetchedSignalMetadata]] = ( | ||
r.match.metadata | ||
) | ||
distance_str = str(r) | ||
|
||
for collab, fetched_data in metadatas: | ||
if not self.all and collab in seen: | ||
continue | ||
seen.add(collab) | ||
# Supposed to be without whitespace, but let's make sure | ||
distance_str = "".join(r.similarity_info.pretty_str().split()) | ||
|
||
print( | ||
s_type.get_name(), | ||
distance_str, | ||
|
@@ -217,18 +256,54 @@ def execute(self, settings: CLISettings) -> None: | |
|
||
|
||
def _match_file( | ||
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex | ||
) -> t.Sequence[IndexMatch]: | ||
path: pathlib.Path, | ||
s_type: t.Type[SignalType], | ||
index: SignalTypeIndex, | ||
rotations: bool = False, | ||
) -> t.Sequence[_IndexMatchWithRotation]: | ||
if issubclass(s_type, MatchesStr): | ||
return index.query(path.read_text()) | ||
matches = index.query(path.read_text()) | ||
return [_IndexMatchWithRotation(match=match) for match in matches] | ||
|
||
assert issubclass(s_type, FileHasher) | ||
return index.query(s_type.hash_from_file(path)) | ||
|
||
if not rotations or s_type != PhotoContent: | ||
matches = index.query(s_type.hash_from_file(path)) | ||
return [_IndexMatchWithRotation(match=match) for match in matches] | ||
|
||
# Handle rotations for photos | ||
with open(path, "rb") as f: | ||
image_data = f.read() | ||
|
||
rotated_images: t.Dict[RotationType, bytes] = PhotoContent.all_simple_rotations( | ||
image_data | ||
) | ||
all_matches = [] | ||
|
||
for rotation_type, rotated_bytes in rotated_images.items(): | ||
# Create a temporary file to hold the image bytes | ||
with tempfile.NamedTemporaryFile() as temp_file: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. blocking: This is going to write a lot of files over the course of an execution! We are going to call this method for every single match type. Another approach would be to refactor this so that the rotated images are inserted higher up in the stack, and then, rather than taking rotations: bool, you could pass in the path of the rotated image and an optional enum value identifying which rotation it is. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm thinking of changing s_type to be a subclass of BytesHasher so that I can use bytes directly with […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the last PR, I mentioned the downside - not all photo formats are knowable without the extension. There's likely a workaround, but let's stay with the current course and see if we can fix it in a followup instead. |
||
temp_file.write(rotated_bytes) | ||
temp_file_path = pathlib.Path(temp_file.name) | ||
matches = index.query(s_type.hash_from_file(temp_file_path)) | ||
temp_file_path.unlink() # Clean up the temporary file | ||
haianhng31 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Add rotation information if any matches were found | ||
matches_with_rotations = [] | ||
for match in matches: | ||
matches_with_rotations.append( | ||
_IndexMatchWithRotation(match=match, rotation_type=rotation_type) | ||
) | ||
|
||
all_matches.extend(matches_with_rotations) | ||
|
||
return all_matches | ||
|
||
|
||
def _match_hashes( | ||
path: pathlib.Path, s_type: t.Type[SignalType], index: SignalTypeIndex | ||
) -> t.Sequence[IndexMatch]: | ||
ret: t.List[IndexMatch] = [] | ||
) -> t.Sequence[_IndexMatchWithRotation]: | ||
ret: t.List[_IndexMatchWithRotation] = [] | ||
for hash in path.read_text().splitlines(): | ||
hash = hash.strip() | ||
if not hash: | ||
|
@@ -244,5 +319,6 @@ def _match_hashes( | |
f"{hash_repr} from {path} is not a valid hash for {s_type.get_name()}", | ||
2, | ||
) | ||
ret.extend(index.query(hash)) | ||
matches = index.query(hash) | ||
ret.extend([_IndexMatchWithRotation(match=match) for match in matches]) | ||
return ret |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
This records all the valid signal types for a piece of content. | ||
""" | ||
|
||
from enum import Enum, auto | ||
from enum import Enum | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. remove |
||
import typing as t | ||
|
||
from threatexchange import common | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
""" | ||
from PIL import Image | ||
import io | ||
import typing as t | ||
|
||
from .content_base import ContentType, RotationType | ||
|
||
|
@@ -82,7 +83,7 @@ def flip_minus1(cls, image_data: bytes) -> bytes: | |
return buffer.getvalue() | ||
|
||
@classmethod | ||
def all_simple_rotations(cls, image_data: bytes): | ||
def all_simple_rotations(cls, image_data: bytes) -> t.Dict[RotationType, bytes]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch |
||
""" | ||
Generate the 8 naive rotations of an image. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before processing this list of files, if rotations is true here, generate new files that you iterate through.
Here's one way to do that: