sppd_cli/downloader/
file_downloader.rs

1use crate::errors::{AppError, AppResult};
2use crate::models::ProcurementType;
3use crate::utils::{format_duration, mb_from_bytes, round_two_decimals};
4use std::path::Path;
5use std::sync::Arc;
6use std::time::Instant;
7use tokio::fs;
8use tokio::fs::File;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::Semaphore;
11use tokio::task::JoinHandle;
12use tracing::{debug, info, warn};
13
/// Result type produced by each spawned download task.
///
/// On `Ok`, carries `(filename, success, optional_error_message)`: a failed
/// download is still reported as `Ok` with `success == false` and a message.
/// In `download_files`, `Err` is only produced when the concurrency semaphore
/// permit cannot be acquired.
type DownloadTaskResult = Result<(String, bool, Option<String>), AppError>;
17
/// Pulls an HTTP status code out of an error message, if one is embedded.
///
/// The message is scanned for the first `"HTTP "` marker. The text between that
/// marker and the next `':'` (or the end of the message when there is no colon)
/// is trimmed and parsed as a `u16`.
///
/// Returns `None` when the marker is absent or the text after it is not a number.
fn extract_status_code(msg: &str) -> Option<u16> {
    const MARKER: &str = "HTTP ";

    // Everything after the first marker; bail out early when there is none.
    let (_, after_marker) = msg.split_once(MARKER)?;

    // The candidate code ends at the first colon, or at the end of the message.
    let candidate = after_marker
        .split_once(':')
        .map_or(after_marker, |(code, _)| code);

    candidate.trim().parse().ok()
}
32
33/// Determines if an error should trigger a retry attempt.
34///
35/// Returns `true` for retryable errors (network errors, timeouts, 5xx HTTP status codes).
36/// Returns `false` for non-retryable errors (4xx client errors, I/O errors, validation errors).
37fn should_retry(error: &AppError) -> bool {
38    match error {
39        AppError::NetworkError(msg) => {
40            // Extract status code from message if present
41            if let Some(status_code) = extract_status_code(msg) {
42                // 4xx = client error, don't retry
43                // 5xx = server error, retry
44                status_code >= 500
45            } else {
46                // No status code means network/timeout error - retry by default
47                // Legacy string matching fallback for older error formats
48                !msg.contains("400")
49                    && !msg.contains("401")
50                    && !msg.contains("403")
51                    && !msg.contains("404")
52                    && !msg.contains("client error")
53            }
54        }
55        AppError::IoError(_) => false,       // Don't retry I/O errors
56        AppError::ParseError(_) => false,    // Don't retry parse errors
57        AppError::UrlError(_) => false,      // Don't retry URL errors
58        AppError::RegexError(_) => false,    // Don't retry regex errors
59        AppError::SelectorError(_) => false, // Don't retry selector errors
60        AppError::PeriodValidationError { .. } => false, // Don't retry validation errors
61        AppError::InvalidInput(_) => false,  // Don't retry invalid input errors
62    }
63}
64
/// Configuration for retry behavior.
///
/// The type is a small bundle of plain integers, so it is `Copy` and can be
/// handed to spawned tasks, compared, and logged without extra ceremony.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct RetryConfig {
    /// Number of retries allowed after the initial attempt.
    max_retries: u32,
    /// Delay before the first retry, in milliseconds.
    initial_delay_ms: u64,
    /// Upper bound for the delay between two attempts, in milliseconds.
    max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Three retries, starting with a 1000 ms delay and capped at 10000 ms.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 10000,
        }
    }
}
81
82/// Calculates exponential backoff delay in milliseconds.
83///
84/// Formula: `min(initial_delay * 2^attempt, max_delay)`
85fn calculate_backoff(attempt: u32, config: &RetryConfig) -> u64 {
86    let delay = config.initial_delay_ms * 2_u64.pow(attempt);
87    delay.min(config.max_delay_ms)
88}
89
90/// Internal retry function that takes RetryConfig directly.
91pub(crate) async fn download_with_retry_internal(
92    client: &reqwest::Client,
93    url: &str,
94    tmp_path: &Path,
95    file_path: &Path,
96    filename: &str,
97    retry_config: &RetryConfig,
98) -> AppResult<()> {
99    let mut last_error: Option<AppError> = None;
100
101    for attempt in 0..=retry_config.max_retries {
102        match download_single_file(client, url, tmp_path, file_path, filename).await {
103            Ok(()) => return Ok(()),
104            Err(e) => {
105                if attempt < retry_config.max_retries && should_retry(&e) {
106                    let delay_ms = calculate_backoff(attempt, retry_config);
107                    warn!(
108                        filename = filename,
109                        attempt = attempt + 1,
110                        max_retries = retry_config.max_retries + 1,
111                        delay_ms = delay_ms,
112                        error = %e,
113                        "Retrying download after error"
114                    );
115                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
116                    last_error = Some(e);
117                    continue;
118                }
119                return Err(e);
120            }
121        }
122    }
123
124    Err(last_error.unwrap_or_else(|| {
125        AppError::NetworkError(format!(
126            "Download failed after {} retries (no error recorded)",
127            retry_config.max_retries + 1
128        ))
129    }))
130}
131
132/// Downloads a single ZIP file.
133///
134/// This is a helper function that performs the download of a single file,
135/// used by `download_files` to enable error collection and continuation.
136async fn download_single_file(
137    client: &reqwest::Client,
138    url: &str,
139    tmp_path: &Path,
140    file_path: &Path,
141    filename: &str,
142) -> AppResult<()> {
143    // Send request and handle send errors (network/timeout errors)
144    let response = client.get(url).send().await.map_err(|e| {
145        // For send errors, these are typically network/timeout errors (retryable)
146        AppError::NetworkError(format!("Failed to download {filename}: {e}"))
147    })?;
148
149    // Check status before error_for_status (which converts 4xx/5xx to errors)
150    let status = response.status();
151    let mut response = response.error_for_status().map_err(|e| {
152        // Include status code in error message for retry logic
153        let status_code = status.as_u16();
154        AppError::NetworkError(format!(
155            "HTTP {status_code}: Failed to download {filename}: {e}"
156        ))
157    })?;
158
159    let mut file = File::create(tmp_path).await.map_err(|e| {
160        AppError::IoError(format!(
161            "Failed to create temp file {}: {}",
162            tmp_path.display(),
163            e
164        ))
165    })?;
166
167    while let Some(chunk) = response.chunk().await? {
168        file.write_all(&chunk).await.map_err(|e| {
169            AppError::IoError(format!(
170                "Failed to write to temp file {}: {}",
171                tmp_path.display(),
172                e
173            ))
174        })?;
175    }
176
177    // Ensure the file is closed before renaming
178    drop(file);
179
180    // Atomically move the temp file to the final destination
181    fs::rename(tmp_path, file_path).await.map_err(|e| {
182        AppError::IoError(format!(
183            "Failed to rename temp file {} to {}: {}",
184            tmp_path.display(),
185            file_path.display(),
186            e
187        ))
188    })?;
189
190    Ok(())
191}
192
193/// Downloads ZIP files to the appropriate directory based on procurement type.
194///
195/// This function downloads ZIP files from the provided URLs to the directory
196/// specified by the procurement type (e.g., `data/tmp/mc` or `data/tmp/pt`).
197///
198/// # Behavior
199///
200/// - **Atomic downloads**: Files are downloaded to temporary `.part` files and
201///   atomically renamed when complete, preventing partial downloads.
202/// - **Skip existing**: Files that already exist are automatically skipped.
203/// - **Progress tracking**: Elapsed time and throughput are logged after downloads complete.
204///
205/// # Arguments
206///
207/// * `client` - HTTP client for making requests
208/// * `filtered_links` - Map of period strings to download URLs (typically from
209///   `filter_periods_by_range()`)
210/// * `proc_type` - Procurement type determining the download directory
211///
212/// # Errors
213///
214/// Returns an error if:
215/// - Directory creation fails
216/// - Network requests fail
217/// - File I/O operations fail
218///
219pub async fn download_files(
220    client: &reqwest::Client,
221    filtered_links: &std::collections::BTreeMap<String, String>,
222    proc_type: &ProcurementType,
223    config: &crate::config::ResolvedConfig,
224) -> AppResult<()> {
225    let download_dir = proc_type.download_dir(config);
226    // Create directory if it doesn't exist
227    if !download_dir.exists() {
228        fs::create_dir_all(&download_dir)
229            .await
230            .map_err(|e| AppError::IoError(format!("Failed to create directory: {e}")))?;
231    }
232
233    // Count files that need downloading (excluding existing ones)
234    // Collect as owned values to avoid lifetime issues with spawned tasks
235    let files_to_download: Vec<(String, String)> = filtered_links
236        .iter()
237        .filter(|(period, _)| {
238            let file_path = download_dir.join(format!("{period}.zip"));
239            !file_path.exists()
240        })
241        .map(|(period, url)| (period.clone(), url.clone()))
242        .collect();
243
244    let total_files = files_to_download.len();
245    let skipped_count = filtered_links.len() - total_files;
246
247    if total_files == 0 {
248        info!(
249            count = filtered_links.len(),
250            "All files already exist, skipping downloads"
251        );
252        return Ok(());
253    }
254
255    let start = Instant::now();
256    let mut total_bytes = 0u64;
257
258    // Create semaphore to limit concurrent downloads
259    let concurrent_downloads = config.concurrent_downloads;
260    let semaphore = Arc::new(Semaphore::new(concurrent_downloads));
261    let client = Arc::new(client.clone());
262    let download_dir_path = download_dir.clone();
263    let download_dir_arc = Arc::new(download_dir_path);
264
265    // Extract retry config values before moving into async blocks
266    let retry_max_retries = config.max_retries;
267    let retry_initial_delay_ms = config.retry_initial_delay_ms;
268    let retry_max_delay_ms = config.retry_max_delay_ms;
269
270    // Pre-allocate errors Vec (usually small, but could accumulate)
271    let mut errors = Vec::with_capacity(10);
272    let mut success_count = 0;
273
274    // Spawn download tasks with bounded concurrency
275    let mut handles: Vec<JoinHandle<DownloadTaskResult>> = Vec::with_capacity(total_files);
276
277    for (period, url) in files_to_download.iter() {
278        let filename = format!("{period}.zip");
279
280        // Clone Arc references and owned values for the task
281        let semaphore = semaphore.clone();
282        let client = client.clone();
283        let download_dir = download_dir_arc.clone();
284        let period = period.clone();
285        let url = url.clone();
286        let filename_for_task = filename.clone();
287
288        // Clone retry config values for this task
289        let max_retries = retry_max_retries;
290        let initial_delay_ms = retry_initial_delay_ms;
291        let max_delay_ms = retry_max_delay_ms;
292
293        // Spawn task that will acquire semaphore permit before downloading
294        let handle = tokio::spawn(async move {
295            // Create paths inside the task
296            let file_path = download_dir.join(&filename_for_task);
297            let tmp_path = download_dir.join(format!("{period}.zip.part"));
298
299            // Acquire permit (will wait if 4 downloads are already in progress)
300            let _permit = semaphore.acquire().await.map_err(|e| {
301                AppError::IoError(format!("Failed to acquire semaphore permit: {e}"))
302            })?;
303
304            // Remove stale tmp file if present (best-effort)
305            if tmp_path.exists() {
306                if let Err(e) = fs::remove_file(&tmp_path).await {
307                    warn!(
308                        file_path = %tmp_path.display(),
309                        error = %e,
310                        "Failed to remove stale temp file"
311                    );
312                }
313            }
314
315            // Attempt download with retry logic
316            // Create RetryConfig from cloned values
317            let retry_config = RetryConfig {
318                max_retries,
319                initial_delay_ms,
320                max_delay_ms,
321            };
322
323            let result = download_with_retry_internal(
324                &client,
325                &url,
326                &tmp_path,
327                &file_path,
328                &filename_for_task,
329                &retry_config,
330            )
331            .await;
332
333            // Handle download result and collect errors
334            match &result {
335                Ok(_) => Ok((filename_for_task, true, None)),
336                Err(e) => {
337                    let error_msg = format!("Failed to download {filename_for_task}: {e}");
338                    warn!(
339                        filename = filename_for_task,
340                        error = %e,
341                        "Failed to download file"
342                    );
343                    Ok((filename_for_task, false, Some(error_msg)))
344                }
345            }
346        });
347
348        handles.push(handle);
349    }
350
351    // Await all tasks and collect results
352    for handle in handles {
353        match handle.await {
354            Ok(Ok((filename, success, error_msg))) => {
355                if success {
356                    success_count += 1;
357                    let file_path = download_dir.join(&filename);
358                    match fs::metadata(&file_path).await {
359                        Ok(metadata) => total_bytes += metadata.len(),
360                        Err(e) => warn!(
361                            file = %file_path.display(),
362                            error = %e,
363                            "Failed to read downloaded file metadata"
364                        ),
365                    }
366                } else if let Some(msg) = error_msg {
367                    errors.push(msg);
368                }
369            }
370            Ok(Err(e)) => {
371                errors.push(format!("Task error: {e}"));
372            }
373            Err(e) => {
374                errors.push(format!("Task join error: {e}"));
375            }
376        }
377    }
378
379    let elapsed = start.elapsed();
380    let elapsed_str = format_duration(elapsed);
381    let total_mb = mb_from_bytes(total_bytes);
382    let throughput = if elapsed.as_secs_f64() > 0.0 {
383        total_mb / elapsed.as_secs_f64()
384    } else {
385        total_mb
386    };
387    let size_mb = round_two_decimals(total_mb);
388    let throughput_mb_s = round_two_decimals(throughput);
389
390    if errors.is_empty() {
391        info!(
392            downloaded = success_count,
393            skipped = skipped_count,
394            failed = 0,
395            elapsed = elapsed_str,
396            size_mb = size_mb,
397            throughput_mb_s = throughput_mb_s,
398            "Download completed"
399        );
400    } else {
401        info!(
402            downloaded = success_count,
403            failed = errors.len(),
404            skipped = skipped_count,
405            elapsed = elapsed_str,
406            size_mb = size_mb,
407            throughput_mb_s = throughput_mb_s,
408            "Download completed with errors"
409        );
410    }
411
412    if skipped_count > 0 {
413        debug!(skipped = skipped_count, "Skipped existing files");
414    }
415
416    // Return error if any downloads failed
417    if !errors.is_empty() {
418        return Err(AppError::NetworkError(format!(
419            "Failed to download {} file(s): {}",
420            errors.len(),
421            errors.join("; ")
422        )));
423    }
424
425    Ok(())
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// A message without the `HTTP ` marker carries no status code.
    #[test]
    fn extract_status_code_no_prefix() {
        assert_eq!(extract_status_code("network error"), None);
    }

    /// The number between `HTTP ` and the colon is returned.
    #[test]
    fn extract_status_code_with_http() {
        let cases = [("HTTP 404: not found", 404), ("HTTP 500: oh no", 500)];
        for (message, expected) in cases {
            assert_eq!(extract_status_code(message), Some(expected));
        }
    }

    /// Server-side (5xx) failures are worth another attempt.
    #[test]
    fn should_retry_network_5xx() {
        assert!(should_retry(&AppError::NetworkError(
            "HTTP 500: server".to_string()
        )));
    }

    /// Client-side (4xx) failures are final.
    #[test]
    fn should_not_retry_network_4xx() {
        assert!(!should_retry(&AppError::NetworkError(
            "HTTP 404: client".to_string()
        )));
    }

    /// Local I/O failures are never retried.
    #[test]
    fn should_not_retry_io_error() {
        assert!(!should_retry(&AppError::IoError("disk full".to_string())));
    }

    /// The delay doubles per attempt until it hits the configured ceiling.
    #[test]
    fn calculate_backoff_capped() {
        let config = RetryConfig::default();
        let cases = [(0, 1000), (1, 2000), (10, 10000)];
        for (attempt, expected_ms) in cases {
            assert_eq!(calculate_backoff(attempt, &config), expected_ms);
        }
    }
}