Skip to content

Commit e0896b2

Browse files
authored
Merge pull request #501 from Dstack-TEE/fix/algo-k256-compat
fix: accept k256 as algorithm alias and add Version RPC
2 parents fe5f5d0 + 6b610ab commit e0896b2

File tree

12 files changed

+488
-6
lines changed

12 files changed

+488
-6
lines changed

guest-agent/rpc/proto/agent_rpc.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ service Tappd {
2929

3030
// Get app info
3131
rpc Info(google.protobuf.Empty) returns (AppInfo) {}
32+
33+
// Get the guest agent version
34+
rpc Version(google.protobuf.Empty) returns (WorkerVersion) {}
3235
}
3336

3437
// The service for the dstack guest agent
@@ -58,6 +61,9 @@ service DstackGuest {
5861

5962
// Verify a signature
6063
rpc Verify(VerifyRequest) returns (VerifyResponse) {}
64+
65+
// Get the guest agent version
66+
rpc Version(google.protobuf.Empty) returns (WorkerVersion) {}
6167
}
6268

6369
// The request to derive a key

guest-agent/src/rpc_service.rs

Lines changed: 188 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ impl DstackGuestRpc for InternalRpcHandler {
268268
async fn get_key(self, request: GetKeyArgs) -> Result<GetKeyResponse> {
269269
let k256_app_key = &self.state.inner.keys.k256_key;
270270

271-
let (key, pubkey_hex) = match request.algorithm.as_str() {
271+
let algorithm = normalize_algorithm(&request.algorithm);
272+
let (key, pubkey_hex) = match algorithm {
272273
"ed25519" => {
273274
let derived_key = derive_key(k256_app_key, &[request.path.as_bytes()], 32)
274275
.context("Failed to derive ed25519 key")?;
@@ -281,7 +282,7 @@ impl DstackGuestRpc for InternalRpcHandler {
281282
let pubkey_hex = hex::encode(signing_key.verifying_key().as_bytes());
282283
(derived_key, pubkey_hex)
283284
}
284-
"secp256k1" | "secp256k1_prehashed" | "" => {
285+
"secp256k1" | "" => {
285286
let derived_key = derive_key(k256_app_key, &[request.path.as_bytes()], 32)
286287
.context("Failed to derive k256 key")?;
287288

@@ -339,14 +340,20 @@ impl DstackGuestRpc for InternalRpcHandler {
339340
}
340341

341342
async fn sign(self, request: SignRequest) -> Result<SignResponse> {
343+
let algorithm = normalize_algorithm(&request.algorithm);
344+
// Use the base algorithm for key derivation (e.g. secp256k1_prehashed -> secp256k1)
345+
let key_algorithm = match algorithm {
346+
"secp256k1_prehashed" => "secp256k1",
347+
other => other,
348+
};
342349
let key_response = self
343350
.get_key(GetKeyArgs {
344351
path: "vms".to_string(),
345352
purpose: "signing".to_string(),
346-
algorithm: request.algorithm.clone(),
353+
algorithm: key_algorithm.to_string(),
347354
})
348355
.await?;
349-
let (signature, public_key) = match request.algorithm.as_str() {
356+
let (signature, public_key) = match algorithm {
350357
"ed25519" => {
351358
let key_bytes: [u8; 32] = key_response
352359
.key
@@ -392,7 +399,8 @@ impl DstackGuestRpc for InternalRpcHandler {
392399
}
393400

394401
async fn verify(self, request: VerifyRequest) -> Result<VerifyResponse> {
395-
let valid = match request.algorithm.as_str() {
402+
let algorithm = normalize_algorithm(&request.algorithm);
403+
let valid = match algorithm {
396404
"ed25519" => {
397405
let verifying_key = ed25519_dalek::VerifyingKey::from_bytes(
398406
&request
@@ -436,6 +444,22 @@ impl DstackGuestRpc for InternalRpcHandler {
436444
attestation: attestation.into_versioned().to_scale(),
437445
})
438446
}
447+
448+
async fn version(self) -> Result<WorkerVersion> {
449+
Ok(WorkerVersion {
450+
version: env!("CARGO_PKG_VERSION").to_string(),
451+
rev: super::GIT_REV.to_string(),
452+
})
453+
}
454+
}
455+
456+
/// Normalize algorithm name to canonical form.
457+
/// Accepts "k256" as an alias for "secp256k1".
458+
fn normalize_algorithm(algorithm: &str) -> &str {
459+
match algorithm {
460+
"k256" => "secp256k1",
461+
other => other,
462+
}
439463
}
440464

441465
fn pad64(data: &[u8]) -> Option<[u8; 64]> {
@@ -595,6 +619,13 @@ impl TappdRpc for InternalRpcHandlerV0 {
595619
async fn info(self) -> Result<AppInfo> {
596620
get_info(&self.state, false).await
597621
}
622+
623+
async fn version(self) -> Result<WorkerVersion> {
624+
Ok(WorkerVersion {
625+
version: env!("CARGO_PKG_VERSION").to_string(),
626+
rev: super::GIT_REV.to_string(),
627+
})
628+
}
598629
}
599630

600631
impl RpcCall<AppState> for InternalRpcHandlerV0 {
@@ -643,7 +674,8 @@ impl WorkerRpc for ExternalRpcHandler {
643674
})
644675
.await?;
645676

646-
match request.algorithm.as_str() {
677+
let algorithm = normalize_algorithm(&request.algorithm);
678+
match algorithm {
647679
"ed25519" => {
648680
let key_bytes: [u8; 32] = key_response
649681
.key
@@ -1128,4 +1160,154 @@ pNs85uhOZE8z2jr8Pg==
11281160
assert!(result.is_err());
11291161
assert_eq!(result.unwrap_err().to_string(), "Unsupported algorithm");
11301162
}
1163+
1164+
#[test]
1165+
fn test_normalize_algorithm() {
1166+
assert_eq!(normalize_algorithm("k256"), "secp256k1");
1167+
assert_eq!(normalize_algorithm("secp256k1"), "secp256k1");
1168+
assert_eq!(normalize_algorithm("ed25519"), "ed25519");
1169+
assert_eq!(normalize_algorithm(""), "");
1170+
assert_eq!(normalize_algorithm("unknown"), "unknown");
1171+
}
1172+
1173+
#[tokio::test]
1174+
async fn test_get_key_k256_alias() {
1175+
let (state, _guard) = setup_test_state().await;
1176+
let handler_k256 = InternalRpcHandler {
1177+
state: state.clone(),
1178+
};
1179+
let handler_secp = InternalRpcHandler {
1180+
state: state.clone(),
1181+
};
1182+
1183+
let req_k256 = GetKeyArgs {
1184+
path: "test".to_string(),
1185+
purpose: "signing".to_string(),
1186+
algorithm: "k256".to_string(),
1187+
};
1188+
let req_secp = GetKeyArgs {
1189+
path: "test".to_string(),
1190+
purpose: "signing".to_string(),
1191+
algorithm: "secp256k1".to_string(),
1192+
};
1193+
1194+
let resp_k256 = handler_k256.get_key(req_k256).await.unwrap();
1195+
let resp_secp = handler_secp.get_key(req_secp).await.unwrap();
1196+
1197+
// k256 alias should produce the same key as secp256k1
1198+
assert_eq!(resp_k256.key, resp_secp.key);
1199+
}
1200+
1201+
#[tokio::test]
1202+
async fn test_get_key_secp256k1_prehashed_rejected() {
1203+
let (state, _guard) = setup_test_state().await;
1204+
let handler = InternalRpcHandler { state };
1205+
1206+
let request = GetKeyArgs {
1207+
path: "test".to_string(),
1208+
purpose: "signing".to_string(),
1209+
algorithm: "secp256k1_prehashed".to_string(),
1210+
};
1211+
1212+
let result = handler.get_key(request).await;
1213+
assert!(result.is_err());
1214+
assert_eq!(result.unwrap_err().to_string(), "Unsupported algorithm");
1215+
}
1216+
1217+
#[tokio::test]
1218+
async fn test_get_key_ed25519_success() {
1219+
let (state, _guard) = setup_test_state().await;
1220+
let handler = InternalRpcHandler { state };
1221+
1222+
let request = GetKeyArgs {
1223+
path: "test".to_string(),
1224+
purpose: "signing".to_string(),
1225+
algorithm: "ed25519".to_string(),
1226+
};
1227+
1228+
let response = handler.get_key(request).await.unwrap();
1229+
assert!(!response.key.is_empty());
1230+
assert_eq!(response.signature_chain.len(), 2);
1231+
}
1232+
1233+
#[tokio::test]
1234+
async fn test_get_key_default_algorithm() {
1235+
let (state, _guard) = setup_test_state().await;
1236+
let handler_default = InternalRpcHandler {
1237+
state: state.clone(),
1238+
};
1239+
let handler_secp = InternalRpcHandler {
1240+
state: state.clone(),
1241+
};
1242+
1243+
let req_default = GetKeyArgs {
1244+
path: "test".to_string(),
1245+
purpose: "signing".to_string(),
1246+
algorithm: "".to_string(),
1247+
};
1248+
let req_secp = GetKeyArgs {
1249+
path: "test".to_string(),
1250+
purpose: "signing".to_string(),
1251+
algorithm: "secp256k1".to_string(),
1252+
};
1253+
1254+
let resp_default = handler_default.get_key(req_default).await.unwrap();
1255+
let resp_secp = handler_secp.get_key(req_secp).await.unwrap();
1256+
1257+
// Empty algorithm should default to secp256k1
1258+
assert_eq!(resp_default.key, resp_secp.key);
1259+
}
1260+
1261+
#[tokio::test]
1262+
async fn test_get_key_unsupported_algorithm_fails() {
1263+
let (state, _guard) = setup_test_state().await;
1264+
let handler = InternalRpcHandler { state };
1265+
1266+
let request = GetKeyArgs {
1267+
path: "test".to_string(),
1268+
purpose: "signing".to_string(),
1269+
algorithm: "rsa".to_string(),
1270+
};
1271+
1272+
let result = handler.get_key(request).await;
1273+
assert!(result.is_err());
1274+
assert_eq!(result.unwrap_err().to_string(), "Unsupported algorithm");
1275+
}
1276+
1277+
#[tokio::test]
1278+
async fn test_version() {
1279+
let (state, _guard) = setup_test_state().await;
1280+
let handler = InternalRpcHandler { state };
1281+
1282+
let response = handler.version().await.unwrap();
1283+
assert!(!response.version.is_empty());
1284+
}
1285+
1286+
#[tokio::test]
1287+
async fn test_sign_k256_alias() {
1288+
let (state, _guard) = setup_test_state().await;
1289+
let handler_k256 = InternalRpcHandler {
1290+
state: state.clone(),
1291+
};
1292+
let handler_secp = InternalRpcHandler {
1293+
state: state.clone(),
1294+
};
1295+
1296+
let data = b"test message".to_vec();
1297+
1298+
let req_k256 = SignRequest {
1299+
algorithm: "k256".to_string(),
1300+
data: data.clone(),
1301+
};
1302+
let req_secp = SignRequest {
1303+
algorithm: "secp256k1".to_string(),
1304+
data: data.clone(),
1305+
};
1306+
1307+
let resp_k256 = handler_k256.sign(req_k256).await.unwrap();
1308+
let resp_secp = handler_secp.sign(req_secp).await.unwrap();
1309+
1310+
// k256 alias should produce the same public key as secp256k1
1311+
assert_eq!(resp_k256.public_key, resp_secp.public_key);
1312+
}
11311313
}

sdk/go/dstack/client.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,33 @@ func (c *DstackClient) GetTlsKey(
369369
return &response, nil
370370
}
371371

372+
// requiresVersionCheck returns true for algorithms that need OS >= 0.5.7.
373+
func requiresVersionCheck(algorithm string) bool {
374+
switch algorithm {
375+
case "secp256k1", "k256", "":
376+
return false
377+
default:
378+
return true
379+
}
380+
}
381+
382+
// ensureAlgorithmSupported checks the OS version when a non-secp256k1 algorithm is requested.
383+
// On old OS (no Version RPC), it returns an error to prevent silent key type mismatch.
384+
func (c *DstackClient) ensureAlgorithmSupported(ctx context.Context, algorithm string) error {
385+
if !requiresVersionCheck(algorithm) {
386+
return nil
387+
}
388+
if _, err := c.GetVersion(ctx); err != nil {
389+
return fmt.Errorf("algorithm %q is not supported: OS version too old (Version RPC unavailable)", algorithm)
390+
}
391+
return nil
392+
}
393+
372394
// Gets a key from the dstack service.
373395
func (c *DstackClient) GetKey(ctx context.Context, path string, purpose string, algorithm string) (*GetKeyResponse, error) {
396+
if err := c.ensureAlgorithmSupported(ctx, algorithm); err != nil {
397+
return nil, err
398+
}
374399
payload := map[string]interface{}{
375400
"path": path,
376401
"purpose": purpose,
@@ -460,6 +485,29 @@ func (c *DstackClient) Attest(ctx context.Context, reportData []byte) (*AttestRe
460485
return &AttestResponse{Attestation: attestation}, nil
461486
}
462487

488+
// Represents the response from a Version request.
489+
type VersionResponse struct {
490+
Version string `json:"version"`
491+
Rev string `json:"rev"`
492+
}
493+
494+
// Gets the guest-agent version.
495+
//
496+
// Returns the version on OS >= 0.5.7.
497+
// Returns an error on older OS versions that lack the Version RPC.
498+
func (c *DstackClient) GetVersion(ctx context.Context) (*VersionResponse, error) {
499+
data, err := c.sendRPCRequest(ctx, "/Version", map[string]interface{}{})
500+
if err != nil {
501+
return nil, err
502+
}
503+
504+
var response VersionResponse
505+
if err := json.Unmarshal(data, &response); err != nil {
506+
return nil, err
507+
}
508+
return &response, nil
509+
}
510+
463511
// Sends a request to get information about the CVM instance
464512
func (c *DstackClient) Info(ctx context.Context) (*InfoResponse, error) {
465513
data, err := c.sendRPCRequest(ctx, "/Info", map[string]interface{}{})

0 commit comments

Comments
 (0)