diff --git a/src/python/tests/tdx_check.py b/src/python/tests/tdx_check.py index 59fc8d8..83c94ef 100644 --- a/src/python/tests/tdx_check.py +++ b/src/python/tests/tdx_check.py @@ -23,11 +23,7 @@ def _replay_eventlog(): rtmrs = [bytearray(rtmr_len)] * rtmr_cnt event_logs = CCTrustedVmSdk.inst().get_cc_eventlog() assert event_logs is not None - for event in event_logs: - if event.event_type != TcgEventType.EV_NO_ACTION: - sha384_algo = hashlib.sha384() - sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash) - rtmrs[event.imr_index] = sha384_algo.digest() + rtmrs = CCTrustedApi.replay_cc_eventlog(event_logs) return rtmrs def _check_imr(imr_index: int, alg_id: int, rtmr: bytes): @@ -50,7 +46,8 @@ def _check_imr(imr_index: int, alg_id: int, rtmr: bytes): assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 digest_hash = digest_obj.hash assert digest_hash is not None - assert digest_hash == rtmr, f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}" + assert digest_hash == rtmr, \ + f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}" def tdx_check_measurement_imrs(): """Test measurement result. @@ -59,10 +56,8 @@ def tdx_check_measurement_imrs(): """ alg = CCTrustedVmSdk.inst().get_default_algorithms() rtmrs = _replay_eventlog() - _check_imr(0, alg.alg_id, rtmrs[0]) - _check_imr(1, alg.alg_id, rtmrs[1]) - _check_imr(2, alg.alg_id, rtmrs[2]) - _check_imr(3, alg.alg_id, rtmrs[3]) + for imr_idx, _ in rtmrs.items(): + _check_imr(imr_idx, alg.alg_id, rtmrs[imr_idx][alg.alg_id]) def _gen_valid_nonce(): """Generate nonce for test. @@ -129,10 +124,18 @@ def _check_quote_rtmrs(quote): body = quote.body assert body is not None and isinstance(body, TdxQuoteBody) rtmrs = _replay_eventlog() - assert body.rtmr0 == rtmrs[0], "RTMR0 doesn't equal the replay from event log!" - assert body.rtmr1 == rtmrs[1], "RTMR1 doesn't equal the replay from event log!" - assert body.rtmr2 == rtmrs[2], "RTMR2 doesn't equal the replay from event log!" - assert body.rtmr3 == rtmrs[3], "RTMR3 doesn't equal the replay from event log!" + alg = CCTrustedVmSdk.inst().get_default_algorithms() + # Replay result only contains the RTMR values covered by the event logs + # Need to fill back the RTMRs that are not covered by the event logs + for idx in range(TdxRTMR.RTMR_COUNT): + if idx not in rtmrs.keys(): + rtmrs[idx] = {} + rtmrs[idx][alg.alg_id] = bytearray(TdxRTMR.RTMR_LENGTH_BY_BYTES) + # Compare all the RTMR values + assert body.rtmr0 == rtmrs[0][alg.alg_id], "RTMR0 doesn't equal the replay from event log!" + assert body.rtmr1 == rtmrs[1][alg.alg_id], "RTMR1 doesn't equal the replay from event log!" + assert body.rtmr2 == rtmrs[2][alg.alg_id], "RTMR2 doesn't equal the replay from event log!" + assert body.rtmr3 == rtmrs[3][alg.alg_id], "RTMR3 doesn't equal the replay from event log!" def _check_quote_reportdata(quote, nonce=None, userdata=None): """Check the userdata in quote result.""" @@ -253,6 +256,6 @@ def tdx_check_replay_eventlog_with_invalid_input(): # Check the replay result when input invalid eventlog. invalid_eventlog = _gen_invalid_eventlog() - replay_result = CCTrustedVmSdk.inst().replay_cc_eventlog(invalid_eventlog.event_logs) + replay_result = CCTrustedApi.replay_cc_eventlog(invalid_eventlog.event_logs) assert replay_result is not None assert 0 == len(replay_result)