Skip to content
Open
29 changes: 26 additions & 3 deletions sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
import warnings
import weakref

from apache_beam.internal.code_object_pickler import get_code_from_identifier
from apache_beam.internal.code_object_pickler import get_code_object_identifier

# The following import is required to be imported in the cloudpickle
# namespace to be able to load pickle files generated with older versions of
# cloudpickle. See: tests/test_backward_compat.py
Expand Down Expand Up @@ -526,6 +529,11 @@ def _make_function(code, globals, name, argdefs, closure):
return types.FunctionType(code, globals, name, argdefs, closure)


def _make_function_from_identifier(code_path, globals, name, argdefs, closure):
fcode = get_code_from_identifier(code_path)
return _make_function(fcode, globals, name, argdefs, closure)


def _make_empty_cell():
if False:
# trick the compiler into creating an empty cell in our lambda
Expand Down Expand Up @@ -1266,7 +1274,11 @@ def _dynamic_function_reduce(self, func):
"""Reduce a function that is not pickleable via attribute lookup."""
newargs = self._function_getnewargs(func)
state = _function_getstate(func)
return (_make_function, newargs, state, None, None, _function_setstate)
if type(newargs[0]) == str:
make_function = _make_function_from_identifier
else:
make_function = _make_function
return (make_function, newargs, state, None, None, _function_setstate)

def _function_reduce(self, obj):
"""Reducer for function objects.
Expand All @@ -1283,6 +1295,8 @@ def _function_reduce(self, obj):
return self._dynamic_function_reduce(obj)

def _function_getnewargs(self, func):
code_path = get_code_object_identifier(
func) if self.enable_lambda_name else None
code = func.__code__

# base_globals represents the future global namespace of func at
Expand Down Expand Up @@ -1313,7 +1327,10 @@ def _function_getnewargs(self, func):
else:
closure = tuple(_make_empty_cell() for _ in range(len(code.co_freevars)))

return code, base_globals, None, None, closure
if code_path:
return code_path, base_globals, None, None, closure
else:
return code, base_globals, None, None, closure

def dump(self, obj):
try:
Expand All @@ -1326,7 +1343,12 @@ def dump(self, obj):
raise

def __init__(
self, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
self,
file,
protocol=None,
buffer_callback=None,
config=DEFAULT_CONFIG,
enable_lambda_name=False):
if protocol is None:
protocol = DEFAULT_PROTOCOL
super().__init__(file, protocol=protocol, buffer_callback=buffer_callback)
Expand All @@ -1336,6 +1358,7 @@ def __init__(
self.globals_ref = {}
self.proto = int(protocol)
self.config = config
self.enable_lambda_name = enable_lambda_name

if not PYPY:
# pickle.Pickler is the C implementation of the CPython pickler and
Expand Down
4 changes: 3 additions & 1 deletion sdks/python/apache_beam/internal/cloudpickle_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def dumps(
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False,
enable_lambda_name=False,
config: cloudpickle.CloudPickleConfig = DEFAULT_CONFIG) -> bytes:
"""For internal use only; no backwards-compatibility guarantees."""
if enable_best_effort_determinism:
Expand All @@ -127,7 +128,8 @@ def dumps(
'This has only been implemented for dill.')
with _pickle_lock:
with io.BytesIO() as file:
pickler = cloudpickle.CloudPickler(file, config=config)
pickler = cloudpickle.CloudPickler(
file, config=config, enable_lambda_name=enable_lambda_name)
try:
pickler.dispatch_table[type(flags.FLAGS)] = _pickle_absl_flags
except NameError:
Expand Down
90 changes: 58 additions & 32 deletions sdks/python/apache_beam/internal/code_object_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

This module provides helper functions to improve pickling code objects,
especially lambdas, in a consistent way by using code object identifiers. These
helper functions will be used to patch pickler implementations used by Beam
helper functions are used to patch pickler implementations used by Beam
(e.g. Cloudpickle).

A code object identifier is a unique identifier for a code object that provides
Expand Down Expand Up @@ -81,8 +81,9 @@ def get_code_object_identifier(callable: types.FunctionType):
- __main__.ClassWithNestedLambda.process.__code__.co_consts[
<lambda>, ('x',), 1234567890]
"""
if not hasattr(callable, '__module__') or not hasattr(callable,
'__qualname__'):
if (not hasattr(callable, '__module__') or
not hasattr(callable, '__qualname__') or not callable.__module__ or
callable.__module__ not in sys.modules):
return None
code_path: str = _extend_path(
callable.__module__,
Expand All @@ -100,7 +101,7 @@ def _extend_path(prefix: str, current_path: Optional[str]):

Args:
prefix: The prefix of the path.
suffix: The rest of the path.
current_path: The rest of the path.

Returns:
The extended path.
Expand Down Expand Up @@ -189,6 +190,8 @@ def _search_module_or_class(
if path is not None:
return _extend_path(name, _extend_path(f'__defaults__[{i}]', path))
else:
if not hasattr(node, first_part):
return None
return _extend_path(
first_part, _search(callable, getattr(node, first_part), rest))

Expand Down Expand Up @@ -281,6 +284,8 @@ def _search_lambda(
lambda_code_objects_by_name = collections.defaultdict(list)
name = qual_name_parts[0]
code_objects = code_objects_by_name[name]
if not code_objects:
return None
if name == '<lambda>':
for code_object in code_objects:
lambda_name = f'<lambda>, {_signature(code_object)}'
Expand Down Expand Up @@ -315,10 +320,10 @@ def _search_lambda(
_SINGLE_NAME_PATTERN = re.compile(r'co_consts\[([a-zA-Z0-9\<\>_-]+)]')
# Matches a path like: co_consts[<lambda>, ('x',)]
_LAMBDA_WITH_ARGS_PATTERN = re.compile(
r"co_consts\[(<[^>]+>),\s*(\('[^']*'\s*,\s*\))\]")
r"co_consts\[(<.*?>),\s(\('[^']+'(?:,\s*'[^']+')*,?\))\]")
# Matches a path like: co_consts[<lambda>, ('x',), 1234567890]
_LAMBDA_WITH_HASH_PATTERN = re.compile(
r"co_consts\[(<[^>]+>),\s*(\('[^']*'\s*,\s*\)),\s*(.+)\]")
r"co_consts\[(<[^>]+>),\s*(\([^\)]*\)),?\s*(.*)\]")
# Matches a path like: __defaults__[0]
_DEFAULT_PATTERN = re.compile(r'(__defaults__)\[(\d+)\]')
# Matches an argument like: 'x'
Expand All @@ -345,9 +350,10 @@ def _get_code_object_from_single_name_pattern(
raise ValueError(f'Invalid pattern for single name: {name_result.group(0)}')
# Groups are indexed starting at 1, group(0) is the entire match.
name = name_result.group(1)
for co_const in obj.co_consts:
if inspect.iscode(co_const) and co_const.co_name == name:
return co_const
if hasattr(obj, 'co_consts'):
for co_const in obj.co_consts:
if inspect.iscode(co_const) and co_const.co_name == name:
return co_const
raise AttributeError(f'Could not find code object with path: {path}')


Expand All @@ -368,15 +374,16 @@ def _get_code_object_from_lambda_with_args_pattern(
"""
name = lambda_with_args_result.group(1)
code_objects = collections.defaultdict(list)
for co_const in obj.co_consts:
if inspect.iscode(co_const) and co_const.co_name == name:
code_objects[co_const.co_name].append(co_const)
for name, objects in code_objects.items():
for obj_ in objects:
args = tuple(
re.findall(_ARGUMENT_PATTERN, lambda_with_args_result.group(2)))
if obj_.co_varnames == args:
return obj_
if hasattr(obj, 'co_consts'):
for co_const in obj.co_consts:
if inspect.iscode(co_const) and co_const.co_name == name:
code_objects[co_const.co_name].append(co_const)
for name, objects in code_objects.items():
for obj_ in objects:
args = tuple(
re.findall(_ARGUMENT_PATTERN, lambda_with_args_result.group(2)))
if obj_.co_varnames[:_get_arg_count(obj_)] == args:
return obj_
raise AttributeError(f'Could not find code object with path: {path}')


Expand All @@ -397,17 +404,18 @@ def _get_code_object_from_lambda_with_hash_pattern(
"""
name = lambda_with_hash_result.group(1)
code_objects = collections.defaultdict(list)
for co_const in obj.co_consts:
if inspect.iscode(co_const) and co_const.co_name == name:
code_objects[co_const.co_name].append(co_const)
for name, objects in code_objects.items():
for obj_ in objects:
args = tuple(
re.findall(_ARGUMENT_PATTERN, lambda_with_hash_result.group(2)))
if obj_.co_varnames == args:
hash_value = lambda_with_hash_result.group(3)
if hash_value == str(_create_bytecode_hash(obj_)):
return obj_
if hasattr(obj, 'co_consts'):
for co_const in obj.co_consts:
if inspect.iscode(co_const) and co_const.co_name == name:
code_objects[co_const.co_name].append(co_const)
for name, objects in code_objects.items():
for obj_ in objects:
args = tuple(
re.findall(_ARGUMENT_PATTERN, lambda_with_hash_result.group(2)))
if obj_.co_varnames[:_get_arg_count(obj_)] == args:
hash_value = lambda_with_hash_result.group(3)
if hash_value == str(_create_bytecode_hash(obj_)):
return obj_
raise AttributeError(f'Could not find code object with path: {path}')


Expand All @@ -427,6 +435,8 @@ def get_code_from_identifier(code_object_identifier: str):
if not code_object_identifier:
raise ValueError('Path must not be empty.')
parts = code_object_identifier.split('.')
if parts[0] not in sys.modules:
raise AttributeError(f'Module {parts[0]} not found in sys.modules')
obj = sys.modules[parts[0]]
for part in parts[1:]:
if name_result := _SINGLE_NAME_PATTERN.fullmatch(part):
Expand All @@ -447,7 +457,11 @@ def get_code_from_identifier(code_object_identifier: str):
obj = getattr(obj, '__defaults__')[index]
else:
obj = getattr(obj, part)
return obj
if isinstance(obj, types.CodeType):
return obj
else:
raise AttributeError(
f'Could not find code object with path: {code_object_identifier}')


def _signature(obj: types.CodeType):
Expand All @@ -462,12 +476,24 @@ def _signature(obj: types.CodeType):
Returns:
A tuple of the names of the arguments of the code object.
"""
arg_count = (
return obj.co_varnames[:_get_arg_count(obj)]


def _get_arg_count(obj: types.CodeType):
"""Returns the number of arguments of a code object.

Args:
obj: A code object, function, method, or cell.

Returns:
The number of arguments of the code object, or None if the object is not a
code object.
"""
return (
obj.co_argcount + obj.co_kwonlyargcount +
(obj.co_flags & 4 == 4) # PyCF_VARARGS
+ (obj.co_flags & 8 == 8) # PyCF_VARKEYWORDS
)
return obj.co_varnames[:arg_count]


def _create_bytecode_hash(code_object: types.CodeType):
Expand Down
27 changes: 16 additions & 11 deletions sdks/python/apache_beam/internal/code_object_pickler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,14 @@ def test_adding_lambda_variable_in_class_preserves_object(self):
module_2_modified.AddLambdaVariable.my_method(self).__code__,
)

def test_removing_lambda_variable_in_class_changes_object(self):
with self.assertRaisesRegex(AttributeError, "object has no attribute"):
code_object_pickler.get_code_from_identifier(
code_object_pickler.get_code_object_identifier(
module_2.RemoveLambdaVariable.my_method(self)).replace(
"module_2", "module_2_modified"))
def test_removing_lambda_variable_in_class_preserves_object(self):
self.assertEqual(
code_object_pickler.get_code_from_identifier(
code_object_pickler.get_code_object_identifier(
module_2.RemoveLambdaVariable.my_method(self)).replace(
"module_2", "module_2_modified")),
module_2_modified.RemoveLambdaVariable.my_method(self).__code__,
)

def test_adding_nested_function_in_class_preserves_object(self):
self.assertEqual(
Expand Down Expand Up @@ -391,11 +393,14 @@ def test_adding_lambda_variable_in_function_preserves_object(self):
module_1_lambda_variable_added.my_function().__code__,
)

def test_removing_lambda_variable_in_function_raises_exception(self):
with self.assertRaisesRegex(AttributeError, "object has no attribute"):
code_object_pickler.get_code_from_identifier(
code_object_pickler.get_code_object_identifier(
module_3.my_function()).replace("module_3", "module_3_modified"))
def test_removing_lambda_variable_in_function_preserves_object(self):
self.assertEqual(
code_object_pickler.get_code_from_identifier(
code_object_pickler.get_code_object_identifier(
module_3.my_function()).replace(
"module_3", "module_3_modified")),
module_3_modified.my_function().__code__,
)


class CodePathStabilityTest(unittest.TestCase):
Expand Down
7 changes: 6 additions & 1 deletion sdks/python/apache_beam/internal/dill_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,13 @@ def dumps(
o,
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False) -> bytes:
enable_best_effort_determinism=False,
enable_lambda_name=False) -> bytes:
"""For internal use only; no backwards-compatibility guarantees."""
if enable_lambda_name:
logging.info(
'Ignoring unsupported option: enable_lambda_name. '
'This has only been implemented for CloudPickle.')
with _pickle_lock:
if enable_best_effort_determinism:
old_save_set = dill.dill.Pickler.dispatch[set]
Expand Down
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/internal/pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def dumps(
o,
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False) -> bytes:
enable_best_effort_determinism=False,
enable_lambda_name=False) -> bytes:

return desired_pickle_lib.dumps(
o,
enable_trace=enable_trace,
use_zlib=use_zlib,
enable_best_effort_determinism=enable_best_effort_determinism)
enable_best_effort_determinism=enable_best_effort_determinism,
enable_lambda_name=enable_lambda_name)


def loads(encoded, enable_trace=True, use_zlib=False):
Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/internal/pickler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
from apache_beam.internal.pickler import loads


def pickle_depickle(obj, enable_lambda_name):
return loads(dumps(obj, enable_lambda_name=enable_lambda_name))


class PicklerTest(unittest.TestCase):

NO_MAPPINGPROXYTYPE = not hasattr(types, "MappingProxyType")
Expand Down Expand Up @@ -278,6 +282,18 @@ def test_disable_best_effort_determinism(self):
dumps(set1, enable_best_effort_determinism=False),
dumps(set2, enable_best_effort_determinism=False))

def test_enable_lambda_name_pickling(self):
pickler.set_library('cloudpickle')
pickled = pickle_depickle(lambda x: x, enable_lambda_name=True)
pickled_type = type(pickled)
self.assertIsInstance(pickled, pickled_type)

def test_disable_lambda_name_pickling(self):
pickler.set_library('cloudpickle')
pickled = pickle_depickle(lambda x: x, enable_lambda_name=False)
pickled_type = type(pickled)
self.assertIsInstance(pickled, pickled_type)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions sdks/python/apache_beam/runners/pipeline_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(
self.iterable_state_write = iterable_state_write
self._requirements = set(requirements)
self.enable_best_effort_deterministic_pickling = False
self.enable_lambda_name_pickling = False

def add_requirement(self, requirement: str) -> None:
self._requirements.add(requirement)
Expand Down
1 change: 1 addition & 0 deletions sdks/python/apache_beam/transforms/ptransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ def to_runner_api_pickled(self, context):
self,
enable_best_effort_determinism=context.
enable_best_effort_deterministic_pickling,
enable_lambda_name=context.enable_lambda_name_pickling,
),
)

Expand Down
Loading