@@ -526,8 +526,16 @@ def expand_leaf_transform(spec, scope):
526
526
527
527
# Optional output schema was found, so let's expand on that before returning.
528
528
if output_schema_spec :
529
+ error_handling_spec = {}
530
+ # Obtain original transform error_handling_spec, so that all validate
531
+ # schema errors use that.
532
+ if 'error_handling' in spec .get ('config' , None ):
533
+ error_handling_spec = spec .get ('config' ).get ('error_handling' , {})
534
+
529
535
outputs = expand_output_schema_transform (
530
- spec = output_schema_spec , outputs = outputs )
536
+ spec = output_schema_spec ,
537
+ outputs = outputs ,
538
+ error_handling_spec = error_handling_spec )
531
539
532
540
if isinstance (outputs , dict ):
533
541
# TODO: Handle (or at least reject) nested case.
@@ -544,77 +552,219 @@ def expand_leaf_transform(spec, scope):
544
552
f'{ type (outputs )} ' )
545
553
546
554
547
def expand_output_schema_transform(spec, outputs, error_handling_spec=None):
  """Applies a `Validate` transform to the output of another transform.

  This function is called when an `output_schema` is defined on a transform.
  It passes the transform's main output through `_enforce_schema` so that the
  data is checked against the specified schema.

  If the original transform has error handling configured, that configuration
  is forwarded to the validation step. If it has none, a warning is logged
  because validation failures will then fail the pipeline.

  Args:
    spec (dict): The `output_schema` specification from the YAML config. It
      must not contain an `error_handling` key.
    outputs: The output(s) of the transform to be validated; either a single
      `beam.PCollection` or a dict mapping output tags to PCollections. Any
      other value is returned unchanged.
    error_handling_spec (dict): The `error_handling` configuration from the
      original transform. `None` (the default) is treated as an empty dict,
      i.e. no error handling.

  Returns:
    * For a single PCollection input: the validated PCollection, or, when
      validation produced an error output named by
      `error_handling_spec['output']`, a dict with the validated data under
      'output' and the errors under that error tag.
    * For a dict input: the dict with the main output replaced by its
      validated version and validation errors merged in.

  Raises:
    ValueError: If `error_handling` is specified within the `output_schema`
      spec itself, or if the main output of a multi-output transform cannot
      be determined.
  """
  if 'error_handling' in spec:
    raise ValueError(
        'error_handling config is not supported directly in '
        'the output_schema. Please use error_handling config in '
        'the transform.')

  # Treat a missing spec as "no error handling" so the .get() calls below are
  # always safe.
  error_handling_spec = error_handling_spec or {}

  # Strip metadata such as __line__ and __uuid__ as these will interfere with
  # the validation downstream.
  clean_schema = SafeLineLoader.strip_metadata(spec)

  # If no error handling is specified for the main transform, warn the user
  # that the pipeline may fail if any output data fails the output schema
  # validation.
  if not error_handling_spec:
    # Adjacent string literals are concatenated verbatim, so every fragment
    # that continues a sentence must end with an explicit space.
    _LOGGER.warning(
        "Output_schema config is attached to a transform that has "
        "no error_handling config specified. Any failures validating on "
        "output schema will fail the pipeline unless the user specifies an "
        "error_handling config on a capable transform or the user can remove "
        "the output_schema config on this transform and add a "
        "ValidateWithSchema transform downstream of the current transform.")

  # The transform produced outputs with a single beam.PCollection
  if isinstance(outputs, beam.PCollection):
    outputs = _enforce_schema(
        outputs, 'EnforceOutputSchema', error_handling_spec, clean_schema)
    if isinstance(outputs, dict):
      # Validation split the data into a main output and (possibly) errors.
      main_tag = error_handling_spec.get('main_tag', 'good')
      main_output = outputs.pop(main_tag)
      if error_handling_spec:
        error_output_tag = error_handling_spec.get('output')
        if error_output_tag in outputs:
          return {
              'output': main_output,
              error_output_tag: outputs.pop(error_output_tag)
          }
      return main_output

  # The transform produced outputs with many named PCollections and need to
  # determine which PCollection should be validated on.
  elif isinstance(outputs, dict):
    main_output_key = _get_main_output_key(spec, outputs)

    validation_result = _enforce_schema(
        outputs[main_output_key],
        f'EnforceOutputSchema_{main_output_key}',
        error_handling_spec,
        clean_schema)
    outputs = _integrate_validation_results(
        outputs, validation_result, main_output_key, error_handling_spec)

  return outputs
633
+
634
+
635
+ def _get_main_output_key (spec , outputs ):
636
+ """Determines the main output key from a dictionary of PCollections.
637
+
638
+ This is used to identify which output of a multi-output transform should be
639
+ validated against an `output_schema`.
640
+
641
+ The main output is determined using the following precedence:
642
+ 1. An output with the key 'output'.
643
+ 2. An output with the key 'good'.
644
+ 3. The single output if there is only one.
645
+
646
+ Args:
647
+ spec: The transform specification, used for creating informative error
648
+ messages.
649
+ outputs: A dictionary mapping output tags to their corresponding
650
+ PCollections.
651
+
652
+ Returns:
653
+ The key of the main output PCollection.
654
+
655
+ Raises:
656
+ ValueError: If a main output cannot be determined because there are
657
+ multiple outputs and none are named 'output' or 'good'.
658
+ """
659
+ main_output_key = 'output'
660
+ if main_output_key not in outputs :
661
+ if 'good' in outputs :
662
+ main_output_key = 'good'
663
+ elif len (outputs ) == 1 :
664
+ main_output_key = next (iter (outputs .keys ()))
593
665
else :
594
- outputs [main_output_key ] = validation_result
666
+ raise ValueError (
667
+ f"Transform { identify_object (spec )} has outputs "
668
+ f"{ list (outputs .keys ())} , but none are named 'output'. To apply "
669
+ "an 'output_schema', please ensure the transform has exactly one "
670
+ "output, or that the main output is named 'output'." )
671
+ return main_output_key
672
+
673
+
674
+ def _integrate_validation_results (
675
+ outputs , validation_result , main_output_key , error_handling_spec ):
676
+ """
677
+ Integrates the results of a validation transform back into the outputs of
678
+ the original transform.
679
+
680
+ This function handles merging the "good" and "bad" outputs from a
681
+ `Validate` transform with the existing outputs of the transform that was
682
+ validated.
683
+
684
+ Args:
685
+ outputs: The original dictionary of output PCollections from the transform.
686
+ validation_result: The output of the `Validate` transform. This can be a
687
+ single PCollection (if all elements passed) or a dictionary of
688
+ PCollections (if error handling was enabled for validation).
689
+ main_output_key: The key in the `outputs` dictionary corresponding to the
690
+ PCollection that was validated.
691
+ error_handling_spec: The error handling configuration of the original
692
+ transform.
693
+
694
+ Returns:
695
+ The updated dictionary of output PCollections, with validation results
696
+ integrated.
697
+
698
+ Raises:
699
+ ValueError: If the validation transform produces unexpected outputs.
700
+ """
701
+ if not isinstance (validation_result , dict ):
702
+ outputs [main_output_key ] = validation_result
703
+ return outputs
704
+
705
+ # The main output from validation is the good output.
706
+ main_tag = error_handling_spec .get ('main_tag' , 'good' )
707
+ outputs [main_output_key ] = validation_result .pop (main_tag )
708
+
709
+ if error_handling_spec :
710
+ error_output_tag = error_handling_spec ['output' ]
711
+ if error_output_tag in validation_result :
712
+ schema_error_pcoll = validation_result .pop (error_output_tag )
713
+ if error_output_tag in outputs :
714
+ # The original transform also had an error output. Merge them.
715
+ outputs [error_output_tag ] = (
716
+ (outputs [error_output_tag ], schema_error_pcoll )
717
+ | f'FlattenErrors_{ main_output_key } ' >> beam .Flatten ())
718
+ else :
719
+ # No error output in the original transform, so just add this one.
720
+ outputs [error_output_tag ] = schema_error_pcoll
721
+
722
+ # There should be no other outputs from validation.
723
+ if validation_result :
724
+ raise ValueError (
725
+ "Unexpected outputs from validation: "
726
+ f"{ list (validation_result .keys ())} " )
595
727
596
728
return outputs
597
729
598
730
599
731
def _enforce_schema (pcoll , label , error_handling_spec , clean_schema ):
732
+ """Applies schema to PCollection elements if necessary, then validates.
733
+
734
+ This function ensures that the input PCollection conforms to a specified
735
+ schema. If the PCollection is schemaless (i.e., its element_type is Any),
736
+ it attempts to convert its elements into schema-aware `beam.Row` objects
737
+ based on the provided `clean_schema`. After ensuring the PCollection has
738
+ a defined schema, it applies a `Validate` transform to perform the actual
739
+ schema validation.
740
+
741
+ Args:
742
+ pcoll: The input PCollection to be schema-enforced and validated.
743
+ label: A string label to be used for the Beam transforms created within this
744
+ function.
745
+ error_handling_spec: A dictionary specifying how to handle validation
746
+ errors.
747
+ clean_schema: A dictionary representing the schema to enforce and validate
748
+ against.
749
+
750
+ Returns:
751
+ A PCollection (or PCollectionTuple if error handling is enabled) resulting
752
+ from the `Validate` transform.
600
753
"""
601
- Applies schema to PCollection elements if necessary, then validates.
602
- """
603
- # This was typically seen when a transform also had its own error handling.
604
754
if pcoll .element_type == typehints .Any :
605
755
_LOGGER .info (
606
756
"PCollection for %s has no schema (element_type=Any). "
607
757
"Converting elements to beam.Row based on provided output_schema." ,
608
758
label )
609
759
try :
610
- # Attempt to conver the schemaless elements into schema-aware beam.Row
760
+ # Attempt to convert the schemaless elements into schema-aware beam.Row
611
761
# objects
612
762
beam_schema = json_utils .json_schema_to_beam_schema (clean_schema )
613
763
row_type_constraint = schemas .named_tuple_from_schema (beam_schema )
614
764
615
765
def to_row (element ):
616
766
"""
617
- Convert a single element inte the row type constraint type.
767
+ Convert a single element into the row type constraint type.
618
768
"""
619
769
if isinstance (element , dict ):
620
770
return row_type_constraint (** element )
0 commit comments