1
// Copyright 2020, Collabora Ltd.
2
// SPDX-License-Identifier: MIT OR Apache-2.0
3

            
4
// Structures used when querying a deployment
5

            
6
use std::path::{Path, PathBuf};
7

            
8
use bytes::Bytes;
9
use futures::{prelude::*, TryStreamExt};
10
use reqwest::header::RANGE;
11
use reqwest::{Client, Response};
12
use serde::de::{Deserializer, Error as _, IgnoredAny, MapAccess, Visitor};
13
use serde::{Deserialize, Serialize};
14

            
15
use tokio::fs::OpenOptions;
16
use tokio::{
17
    fs::{DirBuilder, File},
18
    io::AsyncWriteExt,
19
};
20

            
21
use crate::ddi::client::Error;
22
use crate::ddi::common::{send_feedback_internal, Execution, Finished, Link};
23

            
24
/// Get the file size from metadata in a platform independent way.
///
/// `std::fs::Metadata::len()` already returns the size of the file in
/// bytes on every platform, so the previous per-platform `MetadataExt`
/// accessors (`size()` on unix, `file_size()` on windows) are not needed.
fn file_size(metadata: &std::fs::Metadata) -> u64 {
    metadata.len()
}
37

            
38
#[derive(Debug)]
/// A pending update whose details have not been retrieved yet.
///
/// Call [`UpdatePreFetch::fetch()`] to retrieve the details from server.
pub struct UpdatePreFetch {
    // HTTP client used to query the deployment resource.
    client: Client,
    // URL of the deployment base resource to fetch.
    url: String,
}
46

            
47
impl UpdatePreFetch {
48
24
    pub(crate) fn new(client: Client, url: String) -> Self {
49
24
        Self { client, url }
50
24
    }
51

            
52
    /// Retrieve details about the update.
53
25
    pub async fn fetch(self) -> Result<Update, Error> {
54
10
        let reply = self.client.get(&self.url).send().await?;
55
10
        reply.error_for_status_ref()?;
56

            
57
10
        let reply = reply.json::<Reply>().await?;
58
10
        Ok(Update::new(self.client, reply, self.url))
59
10
    }
60
}
61

            
62
// Top-level body of the deployment base reply; `id` is the action id
// reported back to the server when sending feedback.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Reply {
    id: String,
    deployment: Deployment,
    #[serde(rename = "actionHistory")]
    action_history: Option<ActionHistory>,
}
70

            
71
// Deployment section of the reply: how the download/update phases should
// be handled and the software chunks to deploy.
#[derive(Debug, Deserialize)]
pub(crate) struct Deployment {
    // Handling of the download part of the provisioning process.
    download: Type,
    // Handling of the update part of the provisioning process.
    update: Type,
    #[serde(rename = "maintenanceWindow")]
    maintenance_window: Option<MaintenanceWindow>,
    pub(crate) chunks: Vec<ChunkInternal>,
}
79

            
80
/// How the download or update should be processed by the target.
// Wire values are lowercased by serde: "skip", "attempt", "forced".
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Type {
    /// Do not process yet
    Skip,
    /// Server asks to process
    Attempt,
    /// Server requests immediate processing
    Forced,
}
91

            
92
/// Separation of download and installation by defining a maintenance window for the installation.
// Wire values are lowercased by serde: "available", "unavailable".
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum MaintenanceWindow {
    /// Maintenance window is available
    Available,
    /// Maintenance window is unavailable
    Unavailable,
}
101

            
102
// Wire representation of a software chunk; exposed to users through the
// borrowing [`Chunk`] wrapper.
#[derive(Debug, Deserialize)]
pub(crate) struct ChunkInternal {
    // Optional chunk metadata; defaults to empty when absent from the reply.
    #[serde(default)]
    metadata: Vec<Metadata>,
    part: String,
    name: String,
    version: String,
    artifacts: Vec<ArtifactInternal>,
}
111

            
112
// A single key/value metadata entry of a chunk.
#[derive(Debug, Deserialize)]
struct Metadata {
    key: String,
    value: String,
}
117

            
118
// Wire representation of an artifact; exposed to users through the
// borrowing [`Artifact`] wrapper.
#[derive(Debug, Deserialize)]
struct ArtifactInternal {
    filename: String,
    // Expected checksums for the artifact content.
    hashes: Hashes,
    // Size in bytes (compared against the on-disk size when resuming).
    size: u32,
    #[serde(rename = "_links")]
    links: Links,
}
126

            
127
// Checksums provided by the server for one artifact.
// Compared against lowercase hex digests (see DownloadHasher::finalize).
#[derive(Debug, Deserialize, Clone)]
#[allow(dead_code)]
struct Hashes {
    sha1: String,
    md5: String,
    sha256: String,
}
134

            
135
impl<'de> Deserialize<'de> for Links {
136
44
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
137
44
    where
138
44
        D: Deserializer<'de>,
139
    {
140
        struct V;
141

            
142
        impl<'de> Visitor<'de> for V {
143
            type Value = Links;
144

            
145
            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
146
                write!(f, "a map")
147
            }
148

            
149
44
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
150
44
            where
151
44
                A: MapAccess<'de>,
152
            {
153
44
                let mut download: Option<Link> = None;
154
44
                let mut md5sum: Option<Link> = None;
155
44
                let mut download_http: Option<Link> = None;
156
44
                let mut md5sum_http: Option<Link> = None;
157

            
158
164
                while let Some(key) = map.next_key()? {
159
120
                    match key {
160
120
                        "download" => {
161
30
                            download = match download {
162
                                Some(_) => return Err(A::Error::duplicate_field("download")),
163
30
                                None => Some(map.next_value()?),
164
                            };
165
                        }
166
90
                        "md5sum" => {
167
30
                            md5sum = match md5sum {
168
                                Some(_) => return Err(A::Error::duplicate_field("md5sum")),
169
30
                                None => Some(map.next_value()?),
170
                            };
171
                        }
172
60
                        "download-http" => {
173
30
                            download_http = match download_http {
174
                                Some(_) => return Err(A::Error::duplicate_field("download-http")),
175
30
                                None => Some(map.next_value()?),
176
                            };
177
                        }
178
30
                        "md5sum-http" => {
179
30
                            md5sum_http = match md5sum_http {
180
                                Some(_) => return Err(A::Error::duplicate_field("md5sum-http")),
181
30
                                None => Some(map.next_value()?),
182
                            };
183
                        }
184
                        _ => {
185
                            map.next_value::<IgnoredAny>()?;
186
                        }
187
                    }
188
                }
189

            
190
44
                let https = download.map(|content| Download { content, md5sum });
191
44
                let http = download_http.map(|content| Download {
192
30
                    content,
193
30
                    md5sum: md5sum_http,
194
30
                });
195

            
196
44
                if http.is_none() && https.is_none() {
197
                    Err(A::Error::missing_field("download or download-http"))
198
                } else {
199
44
                    Ok(Links { http, https })
200
                }
201
44
            }
202
        }
203

            
204
44
        let visitor = V;
205

            
206
44
        deserializer.deserialize_map(visitor)
207
44
    }
208
}
209

            
210
#[derive(Debug)]
#[allow(dead_code)]
struct Download {
    // Link to the artifact content itself.
    content: Link,
    // Optional companion md5sum link, when provided by the server.
    md5sum: Option<Link>,
}
216

            
217
/// Download links a single artifact, at least one of http or https will be
/// Some
#[derive(Debug)]
struct Links {
    // Plain-HTTP variant (from the "download-http"/"md5sum-http" keys).
    http: Option<Download>,
    // HTTPS variant (from the "download"/"md5sum" keys).
    https: Option<Download>,
}
224

            
225
// Optional status history of the action, as reported by the server.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ActionHistory {
    status: String,
    // Defaults to empty when the server sends no messages.
    #[serde(default)]
    messages: Vec<String>,
}
232

            
233
/// A pending update to deploy.
#[derive(Debug)]
pub struct Update {
    // HTTP client used for artifact downloads and feedback.
    client: Client,
    // Parsed deployment reply from the server.
    info: Reply,
    // URL of the deployment base resource, used when sending feedback.
    url: String,
}
240

            
241
impl Update {
242
20
    fn new(client: Client, info: Reply, url: String) -> Self {
243
20
        Self { client, info, url }
244
20
    }
245

            
246
    /// The action id of the current update.
247
8
    pub fn action_id(&self) -> &str {
248
8
        &self.info.id
249
8
    }
250

            
251
    /// Handling for the download part of the provisioning process.
252
8
    pub fn download_type(&self) -> Type {
253
8
        self.info.deployment.download
254
8
    }
255

            
256
    /// Handling for the update part of the provisioning process.
257
8
    pub fn update_type(&self) -> Type {
258
8
        self.info.deployment.update
259
8
    }
260

            
261
    /// If set, the update is part of a maintenance window.
262
8
    pub fn maintenance_window(&self) -> Option<MaintenanceWindow> {
263
8
        self.info.deployment.maintenance_window
264
8
    }
265

            
266
    /// An iterator on all the software chunks of the update.
267
20
    pub fn chunks(&self) -> impl Iterator<Item = Chunk<'_>> {
268
20
        let client = self.client.clone();
269

            
270
20
        self.info
271
20
            .deployment
272
20
            .chunks
273
20
            .iter()
274
33
            .map(move |c| Chunk::new(c, client.clone()))
275
20
    }
276

            
277
    /// Download all software chunks to the directory defined in `dir`.
278
    pub async fn download(&self, dir: &Path) -> Result<Vec<DownloadedArtifact>, Error> {
279
        let mut result = Vec::new();
280
        for c in self.chunks() {
281
            let downloaded = c.download(dir).await?;
282
            result.extend(downloaded);
283
        }
284

            
285
        Ok(result)
286
    }
287

            
288
    /// Send feedback to server about this update, with custom progress information.
289
    ///
290
    /// # Arguments
291
    /// * `execution`: status of the action execution.
292
    /// * `finished`: defined status of the result. The action will be kept open on the server until the controller on the device reports either [`Finished::Success`] or [`Finished::Failure`].
293
    /// * `progress`: progress assumption of the device.
294
    /// * `details`: list of details message information.
295
2
    pub async fn send_feedback_with_progress<T: Serialize>(
296
2
        &self,
297
2
        execution: Execution,
298
2
        finished: Finished,
299
2
        progress: T,
300
2
        details: Vec<&str>,
301
2
    ) -> Result<(), Error> {
302
2
        send_feedback_internal(
303
2
            &self.client,
304
2
            &self.url,
305
2
            &self.info.id,
306
2
            execution,
307
2
            finished,
308
2
            Some(progress),
309
2
            details,
310
2
        )
311
2
        .await
312
2
    }
313

            
314
    /// Send feedback to server about this update.
315
    ///
316
    /// Same as [`Update::send_feedback_with_progress`] but without passing custom progress information about the update.
317
4
    pub async fn send_feedback(
318
4
        &self,
319
4
        execution: Execution,
320
4
        finished: Finished,
321
4
        details: Vec<&str>,
322
5
    ) -> Result<(), Error> {
323
2
        send_feedback_internal::<bool>(
324
2
            &self.client,
325
2
            &self.url,
326
2
            &self.info.id,
327
2
            execution,
328
2
            finished,
329
2
            None,
330
2
            details,
331
2
        )
332
2
        .await
333
2
    }
334
}
335

            
336
/// Software chunk of an update.
#[derive(Debug)]
pub struct Chunk<'a> {
    // Borrowed wire data of the chunk.
    chunk: &'a ChunkInternal,
    // Client used to download the chunk's artifacts.
    client: Client,
}
342

            
343
impl<'a> Chunk<'a> {
344
72
    pub(crate) fn new(chunk: &'a ChunkInternal, client: Client) -> Self {
345
72
        Self { chunk, client }
346
72
    }
347

            
348
    /// Type of the chunk.
349
12
    pub fn part(&self) -> &str {
350
12
        &self.chunk.part
351
12
    }
352

            
353
    /// Name of the chunk.
354
32
    pub fn name(&self) -> &str {
355
32
        &self.chunk.name
356
32
    }
357

            
358
    /// Software version of the chunk.
359
12
    pub fn version(&self) -> &str {
360
12
        &self.chunk.version
361
12
    }
362

            
363
    /// An iterator on all the artifacts of the chunk.
364
56
    pub fn artifacts(&self) -> impl Iterator<Item = Artifact<'_>> {
365
56
        let client = self.client.clone();
366

            
367
56
        self.chunk
368
56
            .artifacts
369
56
            .iter()
370
70
            .map(move |a| Artifact::new(a, client.clone()))
371
56
    }
372

            
373
    /// An iterator on all the metadata of the chunk.
374
36
    pub fn metadata(&self) -> impl Iterator<Item = (&str, &str)> {
375
36
        self.chunk
376
36
            .metadata
377
36
            .iter()
378
45
            .map(|a| (a.key.as_str(), a.value.as_str()))
379
36
    }
380

            
381
    /// Download all artifacts of the chunk to the directory defined in `dir`.
382
20
    pub async fn download(&'a self, dir: &Path) -> Result<Vec<DownloadedArtifact>, Error> {
383
8
        let mut dir = dir.to_path_buf();
384
8
        dir.push(self.name());
385
8
        let mut result = Vec::new();
386

            
387
8
        for a in self.artifacts() {
388
8
            let downloaded = a.download(&dir).await?;
389
8
            result.push(downloaded);
390
        }
391

            
392
8
        Ok(result)
393
8
    }
394
}
395

            
396
/// A single file part of a [`Chunk`] to download.
#[derive(Debug)]
pub struct Artifact<'a> {
    // Borrowed wire data of the artifact.
    artifact: &'a ArtifactInternal,
    // Client used to perform the download requests.
    client: Client,
}
402

            
403
impl<'a> Artifact<'a> {
404
56
    fn new(artifact: &'a ArtifactInternal, client: Client) -> Self {
405
56
        Self { artifact, client }
406
56
    }
407

            
408
    /// The name of the file.
409
76
    pub fn filename(&self) -> &str {
410
76
        &self.artifact.filename
411
76
    }
412

            
413
    /// The size of the file.
414
16
    pub fn size(&self) -> u32 {
415
16
        self.artifact.size
416
16
    }
417

            
418
55
    async fn download_response(&'a self) -> Result<Response, Error> {
419
22
        let download = self
420
22
            .artifact
421
22
            .links
422
22
            .https
423
22
            .as_ref()
424
22
            .or(self.artifact.links.http.as_ref())
425
22
            .expect("Missing content link in for artifact");
426

            
427
22
        let resp = self.client.get(download.content.to_string()).send().await?;
428

            
429
22
        resp.error_for_status_ref()?;
430
22
        Ok(resp)
431
22
    }
432

            
433
5
    async fn download_response_range(&'a self, offset: u64) -> Result<Response, Error> {
434
2
        let download = self
435
2
            .artifact
436
2
            .links
437
2
            .https
438
2
            .as_ref()
439
2
            .or(self.artifact.links.http.as_ref())
440
2
            .expect("Missing content link in for artifact");
441

            
442
2
        let resp = self
443
2
            .client
444
2
            .get(download.content.to_string())
445
2
            .header(RANGE, format!("bytes={offset}-"))
446
2
            .send()
447
2
            .await?;
448

            
449
2
        resp.error_for_status_ref()?;
450
2
        Ok(resp)
451
2
    }
452

            
453
    /// Download the artifact file to the directory defined in `dir`.
454
25
    pub async fn download(&'a self, dir: &Path) -> Result<DownloadedArtifact, Error> {
455
10
        if !dir.exists() {
456
6
            DirBuilder::new().recursive(true).create(dir).await?;
457
4
        }
458

            
459
        // Check if the file is already there (e.g. downloaded in a previous try)
460
        // and seems to be the file we are expecting.
461
        // In this case we can use this file directly and skip the download.
462
        #[cfg(feature = "hash-sha256")]
463
        {
464
10
            let mut file_name = dir.to_path_buf();
465
10
            file_name.push(self.filename());
466
10
            if tokio::fs::try_exists(&file_name).await? {
467
                // lets check if the files size matches our expectation
468
                let metadata = tokio::fs::metadata(&file_name).await?;
469
                if file_size(&metadata) == self.artifact.size as u64 {
470
                    // lets check if the file hash matches our expectation
471
                    let artifact = DownloadedArtifact::new(file_name, self.artifact.hashes.clone());
472
                    if artifact.check_sha256().await.is_ok() {
473
                        // filename, size and hash are as expected.
474
                        // so we we can assume that the existant file, is the file given in the deployment
475
                        // so we can skip the download and use the file from cache
476
                        return Ok(artifact);
477
                    }
478
                }
479
10
            }
480
        }
481

            
482
        // the file is first downloaded to a .part file in order to
483
        // be able to resume the download in case of a disconnection.
484
        // If a part file already exists, we try to resume the download (if supported by the server).
485
10
        let mut file_name_part = dir.to_path_buf();
486
10
        file_name_part.push(format!("{}.part", self.filename()));
487

            
488
10
        let mut resp = if tokio::fs::try_exists(&file_name_part).await? {
489
            // try to resume the download
490
2
            let metadata = tokio::fs::metadata(&file_name_part).await?;
491
2
            self.download_response_range(file_size(&metadata)).await?
492
        } else {
493
8
            self.download_response().await?
494
        };
495

            
496
10
        let mut dest = if resp.status() == reqwest::StatusCode::PARTIAL_CONTENT {
497
            // the server supports range requests, we can resume the download
498
2
            OpenOptions::new()
499
2
                .append(true)
500
2
                .open(&file_name_part)
501
2
                .await?
502
        } else {
503
8
            File::create(&file_name_part).await?
504
        };
505

            
506
20
        while let Some(chunk) = resp.chunk().await? {
507
10
            dest.write_all(&chunk).await?;
508
        }
509

            
510
10
        let mut file_name = dir.to_path_buf();
511
10
        file_name.push(self.filename());
512

            
513
        // rename the file to remove the .part extension after the download is complete
514
10
        tokio::fs::rename(&file_name_part, &file_name).await?;
515

            
516
10
        Ok(DownloadedArtifact::new(
517
10
            file_name,
518
10
            self.artifact.hashes.clone(),
519
10
        ))
520
10
    }
521

            
522
    /// Provide a `Stream` of `Bytes` to download the artifact.
523
    ///
524
    /// This can be used as an alternative to [`Artifact::download`],
525
    /// for example, to extract an archive while it's being downloaded,
526
    /// saving the need to store the archive file on disk.
527
28
    pub async fn download_stream(
528
28
        &'a self,
529
35
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
530
14
        let resp = self.download_response().await?;
531

            
532
14
        Ok(resp.bytes_stream().map_err(|e| e.into()))
533
14
    }
534

            
535
    /// Provide a `Stream` of `Bytes` to download the artifact while checking md5 checksum.
536
    ///
537
    /// The stream will yield the same data as [`Artifact::download_stream`] but will raise
538
    /// an error if the md5sum of the downloaded data does not match the one provided by the server.
539
    #[cfg(feature = "hash-md5")]
540
8
    pub async fn download_stream_with_md5_check(
541
8
        &'a self,
542
10
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
543
4
        let stream = self.download_stream().await?;
544
4
        let hasher = DownloadHasher::new_md5(self.artifact.hashes.md5.clone());
545

            
546
4
        let stream = DownloadStreamHash {
547
4
            stream: Box::new(stream),
548
4
            hasher,
549
4
        };
550

            
551
4
        Ok(stream)
552
4
    }
553

            
554
    /// Provide a `Stream` of `Bytes` to download the artifact while checking sha1 checksum.
555
    ///
556
    /// The stream will yield the same data as [`Artifact::download_stream`] but will raise
557
    /// an error if the sha1sum of the downloaded data does not match the one provided by the server.
558
    #[cfg(feature = "hash-sha1")]
559
8
    pub async fn download_stream_with_sha1_check(
560
8
        &'a self,
561
10
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
562
4
        let stream = self.download_stream().await?;
563
4
        let hasher = DownloadHasher::new_sha1(self.artifact.hashes.sha1.clone());
564

            
565
4
        let stream = DownloadStreamHash {
566
4
            stream: Box::new(stream),
567
4
            hasher,
568
4
        };
569

            
570
4
        Ok(stream)
571
4
    }
572

            
573
    /// Provide a `Stream` of `Bytes` to download the artifact while checking sha256 checksum.
574
    ///
575
    /// The stream will yield the same data as [`Artifact::download_stream`] but will raise
576
    /// an error if the sha256sum of the downloaded data does not match the one provided by the server.
577
    #[cfg(feature = "hash-sha256")]
578
8
    pub async fn download_stream_with_sha256_check(
579
8
        &'a self,
580
10
    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
581
4
        let stream = self.download_stream().await?;
582
4
        let hasher = DownloadHasher::new_sha256(self.artifact.hashes.sha256.clone());
583

            
584
4
        let stream = DownloadStreamHash {
585
4
            stream: Box::new(stream),
586
4
            hasher,
587
4
        };
588

            
589
4
        Ok(stream)
590
4
    }
591
}
592

            
593
/// A downloaded file part of a [`Chunk`].
#[allow(dead_code)]
#[derive(Debug)]
pub struct DownloadedArtifact {
    // Path of the file on disk.
    file: PathBuf,
    // Expected checksums, used by the `check_*` methods.
    hashes: Hashes,
}
600

            
601
// Checksum support, compiled in only when at least one hash feature
// ("hash-digest" is implied by hash-md5/sha1/sha256) is enabled.
cfg_if::cfg_if! {
    if #[cfg(feature = "hash-digest")] {
        use std::{
            pin::Pin,
            task::Poll,
        };
        use digest::Digest;

        // Size of the read buffer used when hashing a file from disk.
        const HASH_BUFFER_SIZE: usize = 4096;

        /// Enum representing the different type of supported checksums
        #[derive(Debug, strum::Display, Clone)]
        pub enum ChecksumType {
            /// md5
            #[cfg(feature = "hash-md5")]
            Md5,
            /// sha1
            #[cfg(feature = "hash-sha1")]
            Sha1,
            /// sha256
            #[cfg(feature = "hash-sha256")]
            Sha256,
        }

        // quite complex trait bounds because of requirements so LowerHex is implemented on the output
        #[derive(Clone)]
        struct DownloadHasher<T>
        where
            T: Digest,
            <T as digest::OutputSizeUser>::OutputSize: core::ops::Add,
            <<T as digest::OutputSizeUser>::OutputSize as core::ops::Add>::Output: digest::generic_array::ArrayLength<u8>
        {
            hasher: T,
            // Expected digest, compared as a lowercase hex string.
            expected: String,
            // Which checksum type to report in the error on mismatch.
            error: ChecksumType,
        }

        impl<T> DownloadHasher<T>
        where
            T: Digest,
            <T as digest::OutputSizeUser>::OutputSize: core::ops::Add,
            <<T as digest::OutputSizeUser>::OutputSize as core::ops::Add>::Output: digest::generic_array::ArrayLength<u8>
        {
            // Feed a block of data into the underlying hasher.
            fn update(&mut self, data: impl AsRef<[u8]>) {
                self.hasher.update(data);
            }

            // Consume the hasher and compare the digest (lowercase hex)
            // against the expected checksum string.
            fn finalize(self) -> Result<(), Error> {
                let digest = self.hasher.finalize();

                if format!("{:x}", digest) == self.expected {
                    Ok(())
                } else {
                    Err(Error::ChecksumError(self.error))
                }
            }
        }

        #[cfg(feature = "hash-md5")]
        impl DownloadHasher<md5::Md5> {
            fn new_md5(expected: String) -> Self {
                Self {
                    hasher: md5::Md5::new(),
                    expected,
                    error: ChecksumType::Md5,
                }
            }
        }

        #[cfg(feature = "hash-sha1")]
        impl DownloadHasher<sha1::Sha1> {
            fn new_sha1(expected: String) -> Self {
                Self {
                    hasher: sha1::Sha1::new(),
                    expected,
                    error: ChecksumType::Sha1,
                }
            }
        }

        #[cfg(feature = "hash-sha256")]
        impl DownloadHasher<sha2::Sha256> {
            fn new_sha256(expected: String) -> Self {
                Self {
                    hasher: sha2::Sha256::new(),
                    expected,
                    error: ChecksumType::Sha256,
                }
            }
        }

        // Stream adaptor hashing the bytes as they pass through; the
        // checksum is verified once the inner stream is exhausted.
        struct DownloadStreamHash<T>
        where
            T: Digest,
            <T as digest::OutputSizeUser>::OutputSize: core::ops::Add,
            <<T as digest::OutputSizeUser>::OutputSize as core::ops::Add>::Output: digest::generic_array::ArrayLength<u8>
        {
            stream: Box<dyn Stream<Item = Result<Bytes, Error>> + Unpin + Send + Sync>,
            hasher: DownloadHasher<T>,
        }

        impl<T> Stream for DownloadStreamHash<T>
        where
            T: Digest,
            <T as digest::OutputSizeUser>::OutputSize: core::ops::Add,
            <<T as digest::OutputSizeUser>::OutputSize as core::ops::Add>::Output: digest::generic_array::ArrayLength<u8>,
            T: Unpin,
            T: Clone,
        {
            type Item = Result<Bytes, Error>;

            fn poll_next(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Option<Self::Item>> {
                let me = Pin::into_inner(self);

                match Pin::new(&mut me.stream).poll_next(cx) {
                    Poll::Ready(Some(Ok(data))) => {
                        // feed data to the hasher and then pass them back to the stream
                        me.hasher.update(&data);
                        Poll::Ready(Some(Ok(data)))
                    }
                    Poll::Ready(None) => {
                        // download is done, check the hash
                        // (the hasher is cloned because finalize() consumes
                        // it while poll_next only has `&mut self`)
                        match me.hasher.clone().finalize() {
                            Ok(_) => Poll::Ready(None),
                            Err(e) => Poll::Ready(Some(Err(e))),
                        }
                    }
                    // passthrough on errors and pendings
                    Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
                    Poll::Pending => Poll::Pending,
                }
            }
        }
    }
}
739

            
740
impl DownloadedArtifact {
741
20
    fn new(file: PathBuf, hashes: Hashes) -> Self {
742
20
        Self { file, hashes }
743
20
    }
744

            
745
    /// Path of the downloaded file.
746
16
    pub fn file(&self) -> &PathBuf {
747
16
        &self.file
748
16
    }
749

            
750
    #[cfg(feature = "hash-digest")]
751
24
    async fn hash<T>(&self, mut hasher: DownloadHasher<T>) -> Result<(), Error>
752
24
    where
753
24
        T: digest::Digest,
754
24
        <T as digest::OutputSizeUser>::OutputSize: core::ops::Add,
755
24
        <<T as digest::OutputSizeUser>::OutputSize as core::ops::Add>::Output:
756
24
            digest::generic_array::ArrayLength<u8>,
757
24
    {
758
        use tokio::io::AsyncReadExt;
759

            
760
24
        let mut file = File::open(&self.file).await?;
761
24
        let mut buffer = [0; HASH_BUFFER_SIZE];
762

            
763
        loop {
764
48
            let n = file.read(&mut buffer[..]).await?;
765
48
            if n == 0 {
766
24
                break;
767
24
            }
768
24
            hasher.update(&buffer[..n]);
769
        }
770

            
771
24
        hasher.finalize()
772
24
    }
773

            
774
    /// Check if the md5sum of the downloaded file matches the one provided by the server.
775
    #[cfg(feature = "hash-md5")]
776
20
    pub async fn check_md5(&self) -> Result<(), Error> {
777
8
        let hasher = DownloadHasher::new_md5(self.hashes.md5.clone());
778
8
        self.hash(hasher).await
779
8
    }
780

            
781
    /// Check if the sha1sum of the downloaded file matches the one provided by the server.
782
    #[cfg(feature = "hash-sha1")]
783
20
    pub async fn check_sha1(&self) -> Result<(), Error> {
784
8
        let hasher = DownloadHasher::new_sha1(self.hashes.sha1.clone());
785
8
        self.hash(hasher).await
786
8
    }
787

            
788
    /// Check if the sha256sum of the downloaded file matches the one provided by the server.
789
    #[cfg(feature = "hash-sha256")]
790
20
    pub async fn check_sha256(&self) -> Result<(), Error> {
791
8
        let hasher = DownloadHasher::new_sha256(self.hashes.sha256.clone());
792
8
        self.hash(hasher).await
793
8
    }
794
}