Skip to content

Commit 75ae4eb

Browse files
committed
update code logic based on new output design
1 parent e9b82eb commit 75ae4eb

File tree

1 file changed

+182
-32
lines changed

1 file changed

+182
-32
lines changed

sdks/python/apache_beam/yaml/yaml_transform.py

Lines changed: 182 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -526,8 +526,16 @@ def expand_leaf_transform(spec, scope):
526526

527527
# Optional output schema was found, so let's expand on that before returning.
528528
if output_schema_spec:
529+
error_handling_spec = {}
530+
# Obtain original transform error_handling_spec, so that all validate
531+
# schema errors use that.
532+
if 'error_handling' in spec.get('config', None):
533+
error_handling_spec = spec.get('config').get('error_handling', {})
534+
529535
outputs = expand_output_schema_transform(
530-
spec=output_schema_spec, outputs=outputs)
536+
spec=output_schema_spec,
537+
outputs=outputs,
538+
error_handling_spec=error_handling_spec)
531539

532540
if isinstance(outputs, dict):
533541
# TODO: Handle (or at least reject) nested case.
@@ -544,77 +552,219 @@ def expand_leaf_transform(spec, scope):
544552
f'{type(outputs)}')
545553

546554

547-
def expand_output_schema_transform(spec, outputs):
548-
"""
549-
Expands to add a Validate transform after the current transform and
550-
before returning the output data for it to the next transform.
555+
def expand_output_schema_transform(spec, outputs, error_handling_spec):
556+
"""Applies a `Validate` transform to the output of another transform.
557+
558+
This function is called when an `output_schema` is defined on a transform.
559+
It wraps the original transform's output(s) with a `Validate` transform
560+
to ensure the data conforms to the specified schema.
561+
562+
If the original transform has error handling configured, validation errors
563+
will be routed to the specified error output. If not, validation failures
564+
will cause the pipeline to fail.
565+
566+
Args:
567+
spec (dict): The `output_schema` specification from the YAML config.
568+
outputs (beam.PCollection or dict[str, beam.PCollection]): The output(s)
569+
from the transform to be validated.
570+
error_handling_spec (dict): The `error_handling` configuration from the
571+
original transform.
572+
573+
Returns:
574+
The validated PCollection(s). If error handling is enabled, this will be a
575+
dictionary containing the 'good' output and any error outputs.
576+
577+
Raises:
578+
ValueError: If `error_handling` is incorrectly specified within the
579+
`output_schema` spec itself, or if the main output of a multi-output
580+
transform cannot be determined.
551581
"""
552-
# Check for error handling spec
553-
error_handling_spec = {}
554582
if 'error_handling' in spec:
555-
error_handling_spec = spec.pop('error_handling')
583+
raise ValueError(
584+
'error_handling config is not supported directly in '
585+
'the output_schema. Please use error_handling config in '
586+
'the transform.')
556587

557588
# Strip metadata such as __line__ and __uuid__ as these will interfere with
558589
# the validation downstream.
559590
clean_schema = SafeLineLoader.strip_metadata(spec)
560591

592+
# If no error handling is specified for the main transform, warn the user
593+
# that the pipeline may fail if any output data fails the output schema
594+
# validation.
595+
if not error_handling_spec:
596+
_LOGGER.warning("Output_schema config is attached to a transform that has "\
597+
"no error_handling config specified. Any failures validating on output " \
598+
"schema will fail the pipeline unless the user specifies an " \
599+
"error_handling config on a capable transform or the user can remove the " \
600+
"output_schema config on this transform and add a ValidateWithSchema " \
601+
"transform downstream of the current transform.")
602+
561603
# The transform produced outputs with a single beam.PCollection
562604
if isinstance(outputs, beam.PCollection):
563605
outputs = _enforce_schema(
564606
outputs, 'EnforceOutputSchema', error_handling_spec, clean_schema)
607+
if isinstance(outputs, dict):
608+
main_tag = error_handling_spec.get('main_tag', 'good')
609+
main_output = outputs.pop(main_tag)
610+
if error_handling_spec:
611+
error_output_tag = error_handling_spec.get('output')
612+
if error_output_tag in outputs:
613+
return {
614+
'output': main_output,
615+
error_output_tag: outputs.pop(error_output_tag)
616+
}
617+
return main_output
618+
565619
# The transform produced outputs with many named PCollections and need to
566620
# determine which PCollection should be validated on.
567621
elif isinstance(outputs, dict):
568-
main_output_key = 'output'
569-
if main_output_key not in outputs:
570-
if 'good' in outputs:
571-
main_output_key = 'good'
572-
elif len(outputs) == 1:
573-
main_output_key = next(iter(outputs.keys()))
574-
else:
575-
raise ValueError(
576-
f"Transform {identify_object(spec)} has outputs "
577-
f"{list(outputs.keys())}, but none are named 'output'. To apply "
578-
"an 'output_schema', please ensure the transform has exactly one "
579-
"output, or that the main output is named 'output'.")
622+
main_output_key = _get_main_output_key(spec, outputs)
580623

581624
validation_result = _enforce_schema(
582625
outputs[main_output_key],
583626
f'EnforceOutputSchema_{main_output_key}',
584627
error_handling_spec,
585628
clean_schema)
629+
outputs = _integrate_validation_results(
630+
outputs, validation_result, main_output_key, error_handling_spec)
586631

587-
# Integrate the validation results back into the 'outputs' dictionary.
588-
if isinstance(validation_result, dict):
589-
# The main output from validation is the good output.
590-
main_tag = error_handling_spec.get('main_tag', 'good')
591-
outputs[main_output_key] = validation_result.pop(main_tag)
592-
outputs.update(validation_result)
632+
return outputs
633+
634+
635+
def _get_main_output_key(spec, outputs):
636+
"""Determines the main output key from a dictionary of PCollections.
637+
638+
This is used to identify which output of a multi-output transform should be
639+
validated against an `output_schema`.
640+
641+
The main output is determined using the following precedence:
642+
1. An output with the key 'output'.
643+
2. An output with the key 'good'.
644+
3. The single output if there is only one.
645+
646+
Args:
647+
spec: The transform specification, used for creating informative error
648+
messages.
649+
outputs: A dictionary mapping output tags to their corresponding
650+
PCollections.
651+
652+
Returns:
653+
The key of the main output PCollection.
654+
655+
Raises:
656+
ValueError: If a main output cannot be determined because there are
657+
multiple outputs and none are named 'output' or 'good'.
658+
"""
659+
main_output_key = 'output'
660+
if main_output_key not in outputs:
661+
if 'good' in outputs:
662+
main_output_key = 'good'
663+
elif len(outputs) == 1:
664+
main_output_key = next(iter(outputs.keys()))
593665
else:
594-
outputs[main_output_key] = validation_result
666+
raise ValueError(
667+
f"Transform {identify_object(spec)} has outputs "
668+
f"{list(outputs.keys())}, but none are named 'output'. To apply "
669+
"an 'output_schema', please ensure the transform has exactly one "
670+
"output, or that the main output is named 'output'.")
671+
return main_output_key
672+
673+
674+
def _integrate_validation_results(
675+
outputs, validation_result, main_output_key, error_handling_spec):
676+
"""
677+
Integrates the results of a validation transform back into the outputs of
678+
the original transform.
679+
680+
This function handles merging the "good" and "bad" outputs from a
681+
`Validate` transform with the existing outputs of the transform that was
682+
validated.
683+
684+
Args:
685+
outputs: The original dictionary of output PCollections from the transform.
686+
validation_result: The output of the `Validate` transform. This can be a
687+
single PCollection (if all elements passed) or a dictionary of
688+
PCollections (if error handling was enabled for validation).
689+
main_output_key: The key in the `outputs` dictionary corresponding to the
690+
PCollection that was validated.
691+
error_handling_spec: The error handling configuration of the original
692+
transform.
693+
694+
Returns:
695+
The updated dictionary of output PCollections, with validation results
696+
integrated.
697+
698+
Raises:
699+
ValueError: If the validation transform produces unexpected outputs.
700+
"""
701+
if not isinstance(validation_result, dict):
702+
outputs[main_output_key] = validation_result
703+
return outputs
704+
705+
# The main output from validation is the good output.
706+
main_tag = error_handling_spec.get('main_tag', 'good')
707+
outputs[main_output_key] = validation_result.pop(main_tag)
708+
709+
if error_handling_spec:
710+
error_output_tag = error_handling_spec['output']
711+
if error_output_tag in validation_result:
712+
schema_error_pcoll = validation_result.pop(error_output_tag)
713+
if error_output_tag in outputs:
714+
# The original transform also had an error output. Merge them.
715+
outputs[error_output_tag] = (
716+
(outputs[error_output_tag], schema_error_pcoll)
717+
| f'FlattenErrors_{main_output_key}' >> beam.Flatten())
718+
else:
719+
# No error output in the original transform, so just add this one.
720+
outputs[error_output_tag] = schema_error_pcoll
721+
722+
# There should be no other outputs from validation.
723+
if validation_result:
724+
raise ValueError(
725+
"Unexpected outputs from validation: "
726+
f"{list(validation_result.keys())}")
595727

596728
return outputs
597729

598730

599731
def _enforce_schema(pcoll, label, error_handling_spec, clean_schema):
732+
"""Applies schema to PCollection elements if necessary, then validates.
733+
734+
This function ensures that the input PCollection conforms to a specified
735+
schema. If the PCollection is schemaless (i.e., its element_type is Any),
736+
it attempts to convert its elements into schema-aware `beam.Row` objects
737+
based on the provided `clean_schema`. After ensuring the PCollection has
738+
a defined schema, it applies a `Validate` transform to perform the actual
739+
schema validation.
740+
741+
Args:
742+
pcoll: The input PCollection to be schema-enforced and validated.
743+
label: A string label to be used for the Beam transforms created within this
744+
function.
745+
error_handling_spec: A dictionary specifying how to handle validation
746+
errors.
747+
clean_schema: A dictionary representing the schema to enforce and validate
748+
against.
749+
750+
Returns:
751+
A PCollection (or PCollectionTuple if error handling is enabled) resulting
752+
from the `Validate` transform.
600753
"""
601-
Applies schema to PCollection elements if necessary, then validates.
602-
"""
603-
# This was typically seen when a transform also had its own error handling.
604754
if pcoll.element_type == typehints.Any:
605755
_LOGGER.info(
606756
"PCollection for %s has no schema (element_type=Any). "
607757
"Converting elements to beam.Row based on provided output_schema.",
608758
label)
609759
try:
610-
# Attempt to conver the schemaless elements into schema-aware beam.Row
760+
# Attempt to convert the schemaless elements into schema-aware beam.Row
611761
# objects
612762
beam_schema = json_utils.json_schema_to_beam_schema(clean_schema)
613763
row_type_constraint = schemas.named_tuple_from_schema(beam_schema)
614764

615765
def to_row(element):
616766
"""
617-
Convert a single element inte the row type constraint type.
767+
Convert a single element into the row type constraint type.
618768
"""
619769
if isinstance(element, dict):
620770
return row_type_constraint(**element)

0 commit comments

Comments
 (0)