Skip to content

Commit 4eafbb5

Browse files
committed
Add configurable retries for the pin
Support for an optional field for setting the number of retries. This can be either a positive number or 'inifinity'. If no value is specify, then the pin tries 10 times all the servers. Signed-off-by: Alice Frosi <[email protected]>
1 parent a605f02 commit 4eafbb5

File tree

2 files changed

+147
-31
lines changed

2 files changed

+147
-31
lines changed

cli/src/main.rs

Lines changed: 71 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,18 @@ struct ClevisHeader {
111111
servers: Vec<Server>,
112112
path: String,
113113
initdata: Option<String>,
114+
#[serde(default)]
115+
num_retries: Option<NumRetries>,
114116
}
115117

116118
fn fetch_and_prepare_jwk<E: CommandExecutor>(
117119
servers: &[Server],
118120
path: &str,
119121
initdata: Option<String>,
122+
num_retries: &NumRetries,
120123
executor: &E,
121124
) -> Result<Jwk> {
122-
let key = fetch_luks_key(servers, path, initdata, executor)?;
125+
let key = fetch_luks_key(servers, path, initdata, num_retries, executor)?;
123126
let key = String::from_utf8(
124127
general_purpose::STANDARD
125128
.decode(&key)
@@ -160,7 +163,8 @@ fn encrypt(config: &str) -> Result<()> {
160163
io::stdin().read_to_end(&mut input)?;
161164

162165
let executor = RealCommandExecutor;
163-
let jwk = fetch_and_prepare_jwk(&config.servers, &config.path, initdata.clone(), &executor)?;
166+
let num_retries = config.num_retries.as_ref().unwrap_or(&NumRetries::Finite(10));
167+
let jwk = fetch_and_prepare_jwk(&config.servers, &config.path, initdata.clone(), num_retries, &executor)?;
164168

165169
eprintln!("{}", jwk);
166170
let encrypter = Dir
@@ -172,6 +176,7 @@ fn encrypt(config: &str) -> Result<()> {
172176
servers: config.servers.clone(),
173177
path: config.path,
174178
initdata,
179+
num_retries: config.num_retries,
175180
};
176181

177182
let mut hdr = josekit::jwe::JweHeader::new();
@@ -207,10 +212,12 @@ fn decrypt() -> Result<()> {
207212
eprintln!("Decrypt with header: {:?}", hdr_clevis);
208213

209214
let executor = RealCommandExecutor;
215+
let num_retries = hdr_clevis.num_retries.as_ref().unwrap_or(&NumRetries::Finite(10));
210216
let decrypter_jwk = fetch_and_prepare_jwk(
211217
&hdr_clevis.servers,
212218
&hdr_clevis.path,
213219
hdr_clevis.initdata,
220+
num_retries,
214221
&executor,
215222
)?;
216223

@@ -227,55 +234,87 @@ fn decrypt() -> Result<()> {
227234
Ok(())
228235
}
229236

237+
fn try_fetch_from_servers<E: CommandExecutor>(
238+
servers: &[Server],
239+
path: &str,
240+
initdata: &Option<String>,
241+
executor: &E,
242+
) -> Option<String> {
243+
for (index, server) in servers.iter().enumerate() {
244+
eprintln!("Trying URL {}/{}: {}", index + 1, servers.len(), server.url);
245+
match executor.try_fetch_luks_key(&server.url, path, &server.cert, initdata.clone()) {
246+
Ok(key) => {
247+
eprintln!("Successfully fetched LUKS key from URL: {}", server.url);
248+
return Some(key);
249+
}
250+
Err(e) => {
251+
eprintln!("Error with URL {}: {}", server.url, e);
252+
}
253+
}
254+
}
255+
None
256+
}
257+
230258
fn fetch_luks_key<E: CommandExecutor>(
231259
servers: &[Server],
232260
path: &str,
233261
initdata: Option<String>,
262+
num_retries: &NumRetries,
234263
executor: &E,
235264
) -> Result<String> {
236-
const MAX_ATTEMPTS: u32 = 3;
237265
const DELAY: Duration = Duration::from_secs(5);
238266

239267
if servers.is_empty() {
240268
return Err(anyhow!("No URLs provided"));
241269
}
242270

243-
(1..=MAX_ATTEMPTS)
244-
.find_map(|attempt| {
245-
eprintln!(
246-
"Attempting to fetch LUKS key (attempt {}/{})",
247-
attempt, MAX_ATTEMPTS
248-
);
249-
250-
for (index, server) in servers.iter().enumerate() {
251-
eprintln!("Trying URL {}/{}: {}", index + 1, servers.len(), server.url);
252-
match executor.try_fetch_luks_key(&server.url, path, &server.cert, initdata.clone())
253-
{
254-
Ok(key) => {
255-
eprintln!("Successfully fetched LUKS key from URL: {}", server.url);
271+
match num_retries {
272+
NumRetries::Finite(max_attempts) => {
273+
(1..=*max_attempts)
274+
.find_map(|attempt| {
275+
eprintln!(
276+
"Attempting to fetch LUKS key (attempt {}/{})",
277+
attempt, max_attempts
278+
);
279+
280+
if let Some(key) = try_fetch_from_servers(servers, path, &initdata, executor) {
256281
return Some(Ok(key));
257282
}
258-
Err(e) => {
259-
eprintln!("Error with URL {}: {}", server.url, e);
283+
284+
if attempt < *max_attempts {
285+
eprintln!(
286+
"All URLs failed for attempt {}. Retrying in {:?} seconds...",
287+
attempt, DELAY
288+
);
289+
thread::sleep(DELAY);
260290
}
291+
None
292+
})
293+
.unwrap_or_else(|| {
294+
Err(anyhow!(
295+
"Failed to fetch the LUKS key from all URLs after {} attempts",
296+
max_attempts
297+
))
298+
})
299+
}
300+
NumRetries::Infinity => {
301+
let mut attempt = 0;
302+
loop {
303+
attempt += 1;
304+
eprintln!("Attempting to fetch LUKS key (attempt {})", attempt);
305+
306+
if let Some(key) = try_fetch_from_servers(servers, path, &initdata, executor) {
307+
return Ok(key);
261308
}
262-
}
263309

264-
if attempt < MAX_ATTEMPTS {
265310
eprintln!(
266311
"All URLs failed for attempt {}. Retrying in {:?} seconds...",
267312
attempt, DELAY
268313
);
269314
thread::sleep(DELAY);
270315
}
271-
None
272-
})
273-
.unwrap_or_else(|| {
274-
Err(anyhow!(
275-
"Failed to fetch the LUKS key from all URLs after {} attempts",
276-
MAX_ATTEMPTS
277-
))
278-
})
316+
}
317+
}
279318
}
280319

281320
/// Clevis PIN for Trustee
@@ -323,7 +362,8 @@ mod tests {
323362
cert: String::new(),
324363
}];
325364

326-
let result = fetch_luks_key(&servers, "/test/path", None, &mock);
365+
let num_retries = NumRetries::Finite(3);
366+
let result = fetch_luks_key(&servers, "/test/path", None, &num_retries, &mock);
327367

328368
assert!(result.is_ok());
329369
assert_eq!(result.unwrap(), "test_luks_key_12345");
@@ -340,7 +380,8 @@ mod tests {
340380
cert: String::new(),
341381
}];
342382

343-
let result = fetch_luks_key(&servers, "/test/path", None, &mock);
383+
let num_retries = NumRetries::Finite(3);
384+
let result = fetch_luks_key(&servers, "/test/path", None, &num_retries, &mock);
344385

345386
assert!(result.is_err());
346387
assert_eq!(

lib/src/lib.rs

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,83 @@
33
//
44
// SPDX-License-Identifier: MIT
55

6-
use serde::{Deserialize, Serialize};
6+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
77
use std::collections::HashMap;
88

9+
#[derive(Debug, Clone, PartialEq)]
10+
pub enum NumRetries {
11+
Finite(u32),
12+
Infinity,
13+
}
14+
15+
impl Serialize for NumRetries {
16+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
17+
where
18+
S: Serializer,
19+
{
20+
match self {
21+
NumRetries::Finite(n) => serializer.serialize_u32(*n),
22+
NumRetries::Infinity => serializer.serialize_str("infinity"),
23+
}
24+
}
25+
}
26+
27+
impl<'de> Deserialize<'de> for NumRetries {
28+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
29+
where
30+
D: Deserializer<'de>,
31+
{
32+
struct NumRetriesVisitor;
33+
34+
impl<'de> serde::de::Visitor<'de> for NumRetriesVisitor {
35+
type Value = NumRetries;
36+
37+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
38+
formatter.write_str("a positive number (>= 1) or the string 'infinity'")
39+
}
40+
41+
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
42+
where
43+
E: serde::de::Error,
44+
{
45+
if value == 0 {
46+
return Err(E::custom("number must be at least 1, got: 0"));
47+
}
48+
if value > u32::MAX as u64 {
49+
return Err(E::custom(format!("number too large: {}", value)));
50+
}
51+
Ok(NumRetries::Finite(value as u32))
52+
}
53+
54+
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
55+
where
56+
E: serde::de::Error,
57+
{
58+
if value <= 0 {
59+
return Err(E::custom(format!("number must be at least 1, got: {}", value)));
60+
}
61+
if value > u32::MAX as i64 {
62+
return Err(E::custom(format!("number too large: {}", value)));
63+
}
64+
Ok(NumRetries::Finite(value as u32))
65+
}
66+
67+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
68+
where
69+
E: serde::de::Error,
70+
{
71+
if value == "infinity" {
72+
Ok(NumRetries::Infinity)
73+
} else {
74+
Err(E::custom(format!("expected 'infinity', got: '{}'", value)))
75+
}
76+
}
77+
}
78+
79+
deserializer.deserialize_any(NumRetriesVisitor)
80+
}
81+
}
82+
983
#[derive(Debug, Serialize, Deserialize, Clone)]
1084
pub struct Server {
1185
pub url: String,
@@ -17,6 +91,7 @@ pub struct Config {
1791
pub servers: Vec<Server>,
1892
pub path: String,
1993
pub initdata: Option<String>,
94+
pub num_retries: Option<NumRetries>,
2095
}
2196

2297
#[derive(Debug, Serialize, Deserialize)]

0 commit comments

Comments
 (0)