Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Aimstack crashing on FIPS enabled servers. #3217

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions aim/storage/hashing/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Our implementation is less prone to manually designed collisions.
"""

import _hashlib
import hashlib

from typing import Tuple, Union
Expand Down Expand Up @@ -34,6 +35,44 @@
_HASH_STR_SALT = encode_int64(7540324813251503183)
_HASH_BYTES_SALT = encode_int64(-6836296829636613855)

# Select the hashlib algorithm based on the security mode.
# In normal mode we use the original blake2b-based hashing.
# In the restrictive FIPS mode (e.g. on RHEL), hashlib functions such as
# blake2 are restricted to the OpenSSL implementations, which limit the
# accepted parameters and hence do not allow customising the digest size.
# So in FIPS mode we use shake_256 as an alternative: it supports
# variable-length digests and is an acceptable SHA-3 algorithm.
# This class is a wrapper, as the `digest` signature differs between the two.
class aim_hash_algorithm:
    """Incremental hasher that picks its backend from the FIPS mode.

    Outside FIPS mode the state is a salted `hashlib.blake2b` object with a
    `digest_size`-byte digest. In FIPS mode the state is a
    `hashlib.shake_256` object, from which `digest_size` bytes are read.

    `shake_256` has no salt parameter, so in FIPS mode the salt is absorbed
    as a fixed-width, zero-padded prefix. This keeps differently salted
    hashers producing different digests for the same payload.

    Args:
        digest_size: Number of digest bytes to produce. A falsy value
            (``None`` or ``0``) keeps the class default, `_HASH_SIZE`.
        salt: Bytes-like salt; ``None`` is treated as an empty salt.

    Raises:
        ValueError: In FIPS mode, if the salt is longer than 16 bytes.
    """

    digest_size: int = _HASH_SIZE
    salt: bytes
    is_fips_mode_enabled: bool
    # The underlying `hashlib` object (blake2b or shake_256).
    hashlib_state: object

    # Width the salt is zero-padded to before being fed into shake_256.
    # A fixed width keeps the salt/payload boundary unambiguous.
    _SALT_SIZE = 16

    def _invoke_hashlib(self):
        """Create the underlying hashlib state for the detected mode."""
        if not self.is_fips_mode_enabled:
            return hashlib.blake2b(digest_size=self.digest_size, salt=self.salt)

        if len(self.salt) > self._SALT_SIZE:
            raise ValueError(
                f'salt must be at most {self._SALT_SIZE} bytes, got {len(self.salt)}'
            )
        state = hashlib.shake_256()
        # shake_256 takes no salt argument: absorb the padded salt first so
        # that the salt still separates the hashing domains in FIPS mode.
        state.update(self.salt.ljust(self._SALT_SIZE, b'\x00'))
        return state

    def __init__(self, digest_size=None, salt=None):
        if digest_size:
            self.digest_size = digest_size
        # blake2b rejects `salt=None`, so normalise a missing salt to empty.
        self.salt = b'' if salt is None else bytes(salt)
        # `get_fips_mode` is a private helper that may be absent from
        # `_hashlib`; when it is missing we assume FIPS mode is off.
        get_fips_mode = getattr(_hashlib, 'get_fips_mode', None)
        self.is_fips_mode_enabled = get_fips_mode is not None and get_fips_mode() == 1
        self.hashlib_state = self._invoke_hashlib()

    def update(self, obj: bytes):
        """Feed `obj` into the underlying hash state."""
        self.hashlib_state.update(obj)

    def digest(self):
        """Return the digest of everything fed so far as `bytes`."""
        if not self.is_fips_mode_enabled:
            # blake2 digest signature: the size was fixed at construction.
            return self.hashlib_state.digest()
        # shake_256 digest signature: the length is chosen at read time.
        # Passed positionally, which does not depend on the parameter name.
        return self.hashlib_state.digest(self.digest_size)

def hash_none(obj: NoneType = None) -> int:
"""Hash None values."""
Expand All @@ -47,7 +86,8 @@ def hash_uniform(bad_hash):
in real applications) craft / find such examples that `a != b` but
`hash(a) == hash(b)`
"""
state = hashlib.blake2b(encode_int64(bad_hash), digest_size=_HASH_SIZE, salt=_HASH_UNIFORM_SALT)
state = aim_hash_algorithm(salt=_HASH_UNIFORM_SALT)
state.update(encode_int64(bad_hash))
return decode_int64(state.digest())


Expand Down Expand Up @@ -75,17 +115,18 @@ def hash_bool(obj: bool) -> int:
def hash_bytes(obj: bytes) -> int:
    """Hash a `bytes` buffer.

    The buffer is fed into an `aim_hash_algorithm` state salted with
    `_HASH_BYTES_SALT`, and the resulting digest is converted to an
    integer with `decode_int64`.
    """
    # Always go through the wrapper: calling `hashlib.blake2b` directly with
    # `digest_size` / `salt` is the call that breaks on FIPS-enabled hosts.
    state = aim_hash_algorithm(salt=_HASH_BYTES_SALT)
    state.update(obj)
    return decode_int64(state.digest())


def hash_string(obj: str) -> int:
    """Hash a string object.

    The string is encoded to `utf-8` and fed into an `aim_hash_algorithm`
    state, and the resulting digest is converted to an integer with
    `decode_int64`.
    """
    # Similar to `bytes`, but *a different salt is provided* to make sure
    # strings and their utf-8 encoded blobs do not map to the same hash.
    # Always go through the wrapper: calling `hashlib.blake2b` directly with
    # `digest_size` / `salt` is the call that breaks on FIPS-enabled hosts.
    state = aim_hash_algorithm(salt=_HASH_STR_SALT)
    state.update(obj.encode('utf-8'))
    return decode_int64(state.digest())


Expand All @@ -95,11 +136,10 @@ def hash_array(obj: AimObjectArray) -> int:
We do not take into account whether it is a `list` or `tuple`, so
`hash([1, 2, ['x', 5]]) == hash((1, 2, ('x', 5)))`
"""
state = hashlib.blake2b(digest_size=_HASH_SIZE, salt=_HASH_ARRAY_SALT)
state = aim_hash_algorithm(salt=_HASH_ARRAY_SALT)
for i in obj:
piece_hash = hash_auto(i)
state.update(encode_int64(piece_hash))

return decode_int64(state.digest())


Expand All @@ -117,7 +157,7 @@ def hash_object(obj: AimObjectDict) -> int:
The implementation does not take into account the order
`hash({'a': 5, 'b': 7}) == hash({'b': 7, 'a': 5})`
"""
state = hashlib.blake2b(digest_size=_HASH_SIZE, salt=_HASH_OBJECT_SALT)
state = aim_hash_algorithm(salt=_HASH_OBJECT_SALT)
# Here we use `key_cmp` to run over the object keys in a (meaningless but)
# deterministic order.
for key_val_tuple in sorted(obj.items(), key=key_cmp):
Expand Down
23 changes: 12 additions & 11 deletions tests/api/test_run_images_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def tearDown(self) -> None:
@parameterized.expand([(1,), (5,), (10,)])
def test_images_uri_bulk_load_api(self, uri_count):
# take random N URIs
uris = random.sample(self.uri_map.keys(), uri_count)
uris = random.sample(list(self.uri_map.keys()), uri_count)

client = self.client
response = client.post('/api/runs/images/get-batch', json=uris)
Expand Down Expand Up @@ -420,12 +420,12 @@ def test_run_info_get_all_sequences_api(self, qparams, trace_type_count):
self.assertEqual('image_lists', response_data['traces']['images'][0]['name'])
metrics_data = response_data['traces']['metric']
self.assertEqual(3, len(metrics_data))
self.assertEqual('floats', metrics_data[0]['name'])
self.assertEqual('floats', metrics_data[1]['name'])
self.assertEqual('integers', metrics_data[2]['name'])
self.assertDictEqual({'subset': 'val'}, metrics_data[0]['context'])
self.assertDictEqual({'subset': 'train'}, metrics_data[1]['context'])
self.assertDictEqual({'subset': 'train'}, metrics_data[2]['context'])
contexts = []
for m in metrics_data:
contexts.append((m['context'], m['name']))
self.assertTrue(({'subset': 'val'}, 'floats') in contexts)
self.assertTrue(({'subset': 'train'}, 'floats') in contexts)
self.assertTrue(({'subset': 'train'}, 'integers') in contexts)

response = client.get(f'api/runs/{self.run2_hash}/info', params={'sequence': ('images', 'metric')})
self.assertEqual(200, response.status_code)
Expand All @@ -437,10 +437,11 @@ def test_run_info_get_all_sequences_api(self, qparams, trace_type_count):
self.assertEqual('single_images', response_data['traces']['images'][0]['name'])
metrics_data = response_data['traces']['metric']
self.assertEqual(2, len(metrics_data))
self.assertEqual('floats', metrics_data[0]['name'])
self.assertEqual('floats', metrics_data[1]['name'])
self.assertDictEqual({'subset': 'val'}, metrics_data[0]['context'])
self.assertDictEqual({'subset': 'train'}, metrics_data[1]['context'])
contexts = []
for m in metrics_data:
contexts.append((m['context'], m['name']))
self.assertTrue(({'subset': 'val'}, 'floats') in contexts)
self.assertTrue(({'subset': 'train'}, 'floats') in contexts)

def test_run_info_get_metrics_only_api(self):
client = self.client
Expand Down