1use crate::errors::{AppError, AppResult};
2use crate::models::ProcurementType;
3use crate::utils::{format_duration, mb_from_bytes, round_two_decimals};
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Instant;
7use tokio::fs;
8use tokio::fs::File;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::Semaphore;
11use tokio::task::JoinHandle;
12use tracing::{debug, info, warn};
13
/// Value produced by each spawned download task.
///
/// On `Ok`, the tuple holds the file name, whether the download succeeded, and
/// the failure message when it did not. `Err` is used when the task could not
/// run the download at all (for example, the semaphore permit was unavailable).
type DownloadTaskResult = Result<(String, bool, Option<String>), AppError>;
17
/// Pulls the numeric status code out of a message containing `HTTP <code>`.
///
/// The text between the first `"HTTP "` marker and the following `:` (or the
/// end of the message when there is no colon) is trimmed and parsed as a `u16`.
/// Returns `None` when the marker is missing or that text is not a valid `u16`.
fn extract_status_code(msg: &str) -> Option<u16> {
    const MARKER: &str = "HTTP ";

    let (_, after_marker) = msg.split_once(MARKER)?;
    // `split` always yields at least one piece: everything before the first ':'.
    let candidate = after_marker.split(':').next().unwrap_or(after_marker);
    candidate.trim().parse().ok()
}
32
33fn should_retry(error: &AppError) -> bool {
38 match error {
39 AppError::NetworkError(msg) => {
40 if let Some(status_code) = extract_status_code(msg) {
42 status_code >= 500
45 } else {
46 !msg.contains("400")
49 && !msg.contains("401")
50 && !msg.contains("403")
51 && !msg.contains("404")
52 && !msg.contains("client error")
53 }
54 }
55 AppError::IoError(_) => false, AppError::ParseError(_) => false, AppError::UrlError(_) => false, AppError::RegexError(_) => false, AppError::SelectorError(_) => false, AppError::PeriodValidationError { .. } => false, AppError::InvalidInput(_) => false, }
63}
64
/// Retry policy applied to a single file download.
#[derive(Debug, Clone)]
pub(crate) struct RetryConfig {
    /// Number of retries allowed after the first attempt has failed.
    max_retries: u32,
    /// Delay before the first retry, in milliseconds; doubled for each later retry.
    initial_delay_ms: u64,
    /// Upper bound, in milliseconds, applied to every computed delay.
    max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Three retries, starting at one second and capped at ten seconds.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay_ms: 1_000,
            max_delay_ms: 10_000,
        }
    }
}
81
82fn calculate_backoff(attempt: u32, config: &RetryConfig) -> u64 {
86 let delay = config.initial_delay_ms * 2_u64.pow(attempt);
87 delay.min(config.max_delay_ms)
88}
89
90pub(crate) async fn download_with_retry_internal(
92 client: &reqwest::Client,
93 url: &str,
94 tmp_path: &Path,
95 file_path: &Path,
96 filename: &str,
97 retry_config: &RetryConfig,
98) -> AppResult<()> {
99 let mut last_error: Option<AppError> = None;
100
101 for attempt in 0..=retry_config.max_retries {
102 match download_single_file(client, url, tmp_path, file_path, filename).await {
103 Ok(()) => return Ok(()),
104 Err(e) => {
105 if attempt < retry_config.max_retries && should_retry(&e) {
106 let delay_ms = calculate_backoff(attempt, retry_config);
107 warn!(
108 filename = filename,
109 attempt = attempt + 1,
110 max_retries = retry_config.max_retries + 1,
111 delay_ms = delay_ms,
112 error = %e,
113 "Retrying download after error"
114 );
115 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
116 last_error = Some(e);
117 continue;
118 }
119 return Err(e);
120 }
121 }
122 }
123
124 Err(last_error.unwrap_or_else(|| {
125 AppError::NetworkError(format!(
126 "Download failed after {} retries (no error recorded)",
127 retry_config.max_retries + 1
128 ))
129 }))
130}
131
132async fn download_single_file(
137 client: &reqwest::Client,
138 url: &str,
139 tmp_path: &Path,
140 file_path: &Path,
141 filename: &str,
142) -> AppResult<()> {
143 let response = client.get(url).send().await.map_err(|e| {
145 AppError::NetworkError(format!("Failed to download {filename}: {e}"))
147 })?;
148
149 let status = response.status();
151 let mut response = response.error_for_status().map_err(|e| {
152 let status_code = status.as_u16();
154 AppError::NetworkError(format!(
155 "HTTP {status_code}: Failed to download {filename}: {e}"
156 ))
157 })?;
158
159 let mut file = File::create(tmp_path).await.map_err(|e| {
160 AppError::IoError(format!(
161 "Failed to create temp file {}: {}",
162 tmp_path.display(),
163 e
164 ))
165 })?;
166
167 while let Some(chunk) = response.chunk().await? {
168 file.write_all(&chunk).await.map_err(|e| {
169 AppError::IoError(format!(
170 "Failed to write to temp file {}: {}",
171 tmp_path.display(),
172 e
173 ))
174 })?;
175 }
176
177 drop(file);
179
180 fs::rename(tmp_path, file_path).await.map_err(|e| {
182 AppError::IoError(format!(
183 "Failed to rename temp file {} to {}: {}",
184 tmp_path.display(),
185 file_path.display(),
186 e
187 ))
188 })?;
189
190 Ok(())
191}
192
193pub async fn download_files(
220 client: &reqwest::Client,
221 filtered_links: &std::collections::BTreeMap<String, String>,
222 proc_type: &ProcurementType,
223 config: &crate::config::ResolvedConfig,
224) -> AppResult<()> {
225 let download_dir = proc_type.download_dir(config);
226 if !download_dir.exists() {
228 fs::create_dir_all(&download_dir)
229 .await
230 .map_err(|e| AppError::IoError(format!("Failed to create directory: {e}")))?;
231 }
232
233 let files_to_download: Vec<(String, String)> = filtered_links
236 .iter()
237 .filter(|(period, _)| {
238 let file_path = download_dir.join(format!("{period}.zip"));
239 !file_path.exists()
240 })
241 .map(|(period, url)| (period.clone(), url.clone()))
242 .collect();
243
244 let total_files = files_to_download.len();
245 let skipped_count = filtered_links.len() - total_files;
246
247 if total_files == 0 {
248 info!(
249 count = filtered_links.len(),
250 "All files already exist, skipping downloads"
251 );
252 return Ok(());
253 }
254
255 let start = Instant::now();
256 let mut total_bytes = 0u64;
257
258 let concurrent_downloads = config.concurrent_downloads;
260 let semaphore = Arc::new(Semaphore::new(concurrent_downloads));
261 let client = Arc::new(client.clone());
262 let download_dir_path = download_dir.clone();
263 let download_dir_arc = Arc::new(download_dir_path);
264
265 let retry_max_retries = config.max_retries;
267 let retry_initial_delay_ms = config.retry_initial_delay_ms;
268 let retry_max_delay_ms = config.retry_max_delay_ms;
269
270 let mut errors = Vec::with_capacity(10);
272 let mut success_count = 0;
273
274 let mut handles: Vec<JoinHandle<DownloadTaskResult>> = Vec::with_capacity(total_files);
276
277 for (period, url) in files_to_download.iter() {
278 let filename = format!("{period}.zip");
279
280 let semaphore = semaphore.clone();
282 let client = client.clone();
283 let download_dir = download_dir_arc.clone();
284 let period = period.clone();
285 let url = url.clone();
286 let filename_for_task = filename.clone();
287
288 let max_retries = retry_max_retries;
290 let initial_delay_ms = retry_initial_delay_ms;
291 let max_delay_ms = retry_max_delay_ms;
292
293 let handle = tokio::spawn(async move {
295 let file_path = download_dir.join(&filename_for_task);
297 let tmp_path = download_dir.join(format!("{period}.zip.part"));
298
299 let _permit = semaphore.acquire().await.map_err(|e| {
301 AppError::IoError(format!("Failed to acquire semaphore permit: {e}"))
302 })?;
303
304 if tmp_path.exists() {
306 if let Err(e) = fs::remove_file(&tmp_path).await {
307 warn!(
308 file_path = %tmp_path.display(),
309 error = %e,
310 "Failed to remove stale temp file"
311 );
312 }
313 }
314
315 let retry_config = RetryConfig {
318 max_retries,
319 initial_delay_ms,
320 max_delay_ms,
321 };
322
323 let result = download_with_retry_internal(
324 &client,
325 &url,
326 &tmp_path,
327 &file_path,
328 &filename_for_task,
329 &retry_config,
330 )
331 .await;
332
333 match &result {
335 Ok(_) => Ok((filename_for_task, true, None)),
336 Err(e) => {
337 let error_msg = format!("Failed to download {filename_for_task}: {e}");
338 warn!(
339 filename = filename_for_task,
340 error = %e,
341 "Failed to download file"
342 );
343 Ok((filename_for_task, false, Some(error_msg)))
344 }
345 }
346 });
347
348 handles.push(handle);
349 }
350
351 for handle in handles {
353 match handle.await {
354 Ok(Ok((filename, success, error_msg))) => {
355 if success {
356 success_count += 1;
357 let file_path = download_dir.join(&filename);
358 match fs::metadata(&file_path).await {
359 Ok(metadata) => total_bytes += metadata.len(),
360 Err(e) => warn!(
361 file = %file_path.display(),
362 error = %e,
363 "Failed to read downloaded file metadata"
364 ),
365 }
366 } else if let Some(msg) = error_msg {
367 errors.push(msg);
368 }
369 }
370 Ok(Err(e)) => {
371 errors.push(format!("Task error: {e}"));
372 }
373 Err(e) => {
374 errors.push(format!("Task join error: {e}"));
375 }
376 }
377 }
378
379 let elapsed = start.elapsed();
380 let elapsed_str = format_duration(elapsed);
381 let total_mb = mb_from_bytes(total_bytes);
382 let throughput = if elapsed.as_secs_f64() > 0.0 {
383 total_mb / elapsed.as_secs_f64()
384 } else {
385 total_mb
386 };
387 let size_mb = round_two_decimals(total_mb);
388 let throughput_mb_s = round_two_decimals(throughput);
389
390 if errors.is_empty() {
391 info!(
392 downloaded = success_count,
393 skipped = skipped_count,
394 failed = 0,
395 elapsed = elapsed_str,
396 size_mb = size_mb,
397 throughput_mb_s = throughput_mb_s,
398 "Download completed"
399 );
400 } else {
401 info!(
402 downloaded = success_count,
403 failed = errors.len(),
404 skipped = skipped_count,
405 elapsed = elapsed_str,
406 size_mb = size_mb,
407 throughput_mb_s = throughput_mb_s,
408 "Download completed with errors"
409 );
410 }
411
412 if skipped_count > 0 {
413 debug!(skipped = skipped_count, "Skipped existing files");
414 }
415
416 if !errors.is_empty() {
418 return Err(AppError::NetworkError(format!(
419 "Failed to download {} file(s): {}",
420 errors.len(),
421 errors.join("; ")
422 )));
423 }
424
425 Ok(())
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// A message without the `HTTP ` marker yields no status code.
    #[test]
    fn extract_status_code_no_prefix() {
        let parsed = extract_status_code("network error");
        assert!(parsed.is_none());
    }

    /// The code between `HTTP ` and the colon is parsed.
    #[test]
    fn extract_status_code_with_http() {
        let not_found = extract_status_code("HTTP 404: not found");
        let server_error = extract_status_code("HTTP 500: oh no");
        assert_eq!(not_found, Some(404));
        assert_eq!(server_error, Some(500));
    }

    /// Server-side failures are retried.
    #[test]
    fn should_retry_network_5xx() {
        let err = AppError::NetworkError(String::from("HTTP 500: server"));
        assert!(should_retry(&err));
    }

    /// Client-side failures are not retried.
    #[test]
    fn should_not_retry_network_4xx() {
        let err = AppError::NetworkError(String::from("HTTP 404: client"));
        assert!(!should_retry(&err));
    }

    /// Non-network errors are never retried.
    #[test]
    fn should_not_retry_io_error() {
        let err = AppError::IoError(String::from("disk full"));
        assert!(!should_retry(&err));
    }

    /// The delay doubles per attempt and stops at the configured maximum.
    #[test]
    fn calculate_backoff_capped() {
        let config = RetryConfig::default();
        let first = calculate_backoff(0, &config);
        let second = calculate_backoff(1, &config);
        let late = calculate_backoff(10, &config);
        assert_eq!(first, 1000);
        assert_eq!(second, 2000);
        assert_eq!(late, 10000);
    }
}