xref: /qemu/tests/functional/qemu_test/asset.py (revision 28ea66f6f9856c398afa75f2cabb1f21c8b04208)
1# Test utilities for fetching & caching assets
2#
3# Copyright 2024 Red Hat, Inc.
4#
5# This work is licensed under the terms of the GNU GPL, version 2 or
6# later.  See the COPYING file in the top-level directory.
7
8import hashlib
9import logging
10import os
11import stat
12import sys
13import unittest
14import urllib.request
15from time import sleep
16from pathlib import Path
17from shutil import copyfileobj
18from urllib.error import HTTPError
19
class AssetError(Exception):
    """Error raised when a test asset cannot be validated or downloaded.

    Attributes:
        url: URL of the asset the error refers to (taken from ``asset.url``).
        msg: human-readable description of the failure.
        transient: flag the caller can inspect to decide whether the
            failure should be treated as temporary rather than fatal.
    """

    def __init__(self, asset, msg, transient=False):
        # Exception.__init__ is not invoked here; self.args still holds the
        # positional constructor arguments because BaseException.__new__
        # stores them.
        self.url = asset.url
        self.msg = msg
        self.transient = transient

    def __str__(self):
        return "{}: {}".format(self.url, self.msg)
28
29# Instances of this class must be declared as class level variables
30# starting with a name "ASSET_". This enables the pre-caching logic
31# to easily find all referenced assets and download them prior to
32# execution of the tests.
class Asset:
    """A test asset fetched from a URL and cached locally by hash sum.

    The cache directory is ``$QEMU_TEST_CACHE_DIR/download`` when that
    environment variable is set, otherwise ``~/.cache/qemu/download``.
    Each asset is stored in a file named after its expected hash; when a
    hash is given, valid() additionally requires the file's SHA-256 or
    SHA-512 digest to match it.
    """

    # Maximum number of seconds to wait for a concurrent download of the
    # same asset before giving up.
    _WAIT_TIMEOUT = 600
    # A concurrent download whose temporary file has not grown for more
    # than this many seconds is treated as stale.
    _STALE_TIMEOUT = 90
    # Number of download attempts before failing with a transient error.
    _DOWNLOAD_ATTEMPTS = 3

    def __init__(self, url, hashsum):
        self.url = url
        self.hash = hashsum
        cache_dir_env = os.getenv('QEMU_TEST_CACHE_DIR')
        if cache_dir_env:
            self.cache_dir = Path(cache_dir_env, "download")
        else:
            self.cache_dir = Path(Path("~").expanduser(),
                                  ".cache", "qemu", "download")
        self.cache_file = Path(self.cache_dir, hashsum)
        self.log = logging.getLogger('qemu-test')

    def __repr__(self):
        return "Asset: url=%s hash=%s cache=%s" % (
            self.url, self.hash, self.cache_file)

    def __str__(self):
        return str(self.cache_file)

    def _check(self, cache_file):
        """Return True if the digest of *cache_file* equals ``self.hash``.

        The algorithm is picked from the length of the expected hex digest:
        64 characters selects SHA-256, 128 selects SHA-512.  A hash of None
        accepts any file.

        Raises:
            AssetError: if the hash length matches neither algorithm.
        """
        if self.hash is None:
            return True
        if len(self.hash) == 64:
            hl = hashlib.sha256()
        elif len(self.hash) == 128:
            hl = hashlib.sha512()
        else:
            raise AssetError(self, "unknown hash type")

        # Hash in 1 MiB chunks so large images are never fully in memory
        with open(cache_file, 'rb') as file:
            for chunk in iter(lambda: file.read(1 << 20), b''):
                hl.update(chunk)

        return self.hash == hl.hexdigest()

    def valid(self):
        """Return True if the cache file exists and passes _check()."""
        return self.cache_file.exists() and self._check(self.cache_file)

    def fetchable(self):
        """Return True unless QEMU_TEST_NO_DOWNLOAD is set non-empty."""
        return not os.environ.get("QEMU_TEST_NO_DOWNLOAD", False)

    def available(self):
        """Return True if the asset is cached and valid, or downloadable."""
        return self.valid() or self.fetchable()

    @staticmethod
    def _file_size(path):
        """Return the size of *path* in bytes, or None if it does not exist."""
        try:
            return path.stat().st_size
        except FileNotFoundError:
            return None

    @staticmethod
    def _discard(path):
        """Delete *path*; a file that is already gone is not an error."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    def _wait_for_other_download(self, tmp_cache_file):
        """Wait for a concurrent download of this asset to finish.

        Another thread or process owns *tmp_cache_file*.  Its size is polled
        once per second to detect progress.

        Returns:
            True if the temporary file disappeared and the cache file now
            exists (the other download completed).  False if the temporary
            file disappeared without producing the cache file, or has not
            grown for more than _STALE_TIMEOUT seconds; the caller should
            then retry the download itself.

        Raises:
            AssetError: (transient) if the other download is still making
                progress after _WAIT_TIMEOUT seconds.
        """
        current_size = self._file_size(tmp_cache_file)
        if current_size is None:
            # The other download was renamed into place, or it gave up
            return self.cache_file.exists()

        waittime = lastchange = self._WAIT_TIMEOUT
        while waittime > 0:
            sleep(1)
            waittime -= 1
            new_size = self._file_size(tmp_cache_file)
            if new_size is None:
                return self.cache_file.exists()
            if new_size != current_size:
                lastchange = waittime
                current_size = new_size
            elif lastchange - waittime > self._STALE_TIMEOUT:
                return False

        self.log.debug("Time out while waiting for %s!", tmp_cache_file)
        raise AssetError(self, "Timed out waiting for concurrent download",
                         transient=True)

    def fetch(self):
        """Return the path of the cached asset, downloading it if needed.

        A valid cached copy is used as-is.  Otherwise the asset is
        downloaded to ``<hash>.download`` beside the cache file.  That file
        is created exclusively, so a concurrent download of the same asset
        is waited for rather than duplicated.  The download is checked
        against its Content-Length header (when present) and against
        ``self.hash``, then renamed into place and made read-only.

        Raises:
            AssetError: if there is no valid cached copy and downloads are
                disabled; on HTTP 404 or an unexpected download error; on a
                hash mismatch; or (transient) when every download attempt
                fails or a concurrent download times out.
        """
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if self.valid():
            self.log.debug("Using cached asset %s for %s",
                           self.cache_file, self.url)
            return str(self.cache_file)

        if not self.fetchable():
            raise AssetError(self,
                             "Asset cache is invalid and downloads disabled")

        self.log.info("Downloading %s to %s...", self.url, self.cache_file)
        tmp_cache_file = self.cache_file.with_suffix(".download")

        for _attempt in range(self._DOWNLOAD_ATTEMPTS):
            try:
                # "x" mode raises FileExistsError if someone else is
                # already downloading this asset
                with tmp_cache_file.open("xb") as dst:
                    with urllib.request.urlopen(self.url) as resp:
                        copyfileobj(resp, dst)
                        # .headers exists on HTTP and file:// responses
                        # alike, unlike the HTTP-only .getheader()
                        length_hdr = resp.headers.get("Content-Length")

                # Verify downloaded file size against length metadata, if
                # available.
                if length_hdr is not None:
                    length = int(length_hdr)
                    fsize = tmp_cache_file.stat().st_size
                    if fsize != length:
                        self.log.error("Unable to download %s: "
                                       "connection closed before "
                                       "transfer complete (%d/%d)",
                                       self.url, fsize, length)
                        tmp_cache_file.unlink()
                        continue
                break
            except FileExistsError:
                self.log.debug("%s already exists, "
                               "waiting for other thread to finish...",
                               tmp_cache_file)
                if self._wait_for_other_download(tmp_cache_file):
                    return str(self.cache_file)
                self.log.debug("%s seems to be stale, "
                               "deleting and retrying download...",
                               tmp_cache_file)
                self._discard(tmp_cache_file)
                continue
            except HTTPError as e:
                self._discard(tmp_cache_file)
                self.log.error("Unable to download %s: HTTP error %d",
                               self.url, e.code)
                # Treat 404 as fatal, since it is highly likely to
                # indicate a broken test rather than a transient
                # server or networking problem
                if e.code == 404:
                    raise AssetError(self, "Unable to download: "
                                     "HTTP error %d" % e.code) from e
                continue
            except Exception as e:
                # The temporary file may not exist if open() itself failed;
                # never let cleanup mask the real error
                self._discard(tmp_cache_file)
                raise AssetError(self, "Unable to download: %s" % e) from e
        else:
            # Every attempt ended in "continue". Do not look at
            # tmp_cache_file here: it may belong to another downloader
            # that started in the meantime.
            raise AssetError(self, "Download retries exceeded", transient=True)

        try:
            # Set these just for informational purposes
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-url",
                        self.url.encode('utf8'))
            os.setxattr(str(tmp_cache_file), "user.qemu-asset-hash",
                        self.hash.encode('utf8'))
        except Exception as e:
            # Best effort only; any failure here is logged and ignored
            self.log.debug("Unable to set xattr on %s: %s", tmp_cache_file, e)

        if not self._check(tmp_cache_file):
            tmp_cache_file.unlink()
            raise AssetError(self, "Hash does not match %s" % self.hash)
        tmp_cache_file.replace(self.cache_file)
        # Remove write perms to stop tests accidentally modifying them
        os.chmod(self.cache_file, stat.S_IRUSR | stat.S_IRGRP)

        self.log.info("Cached %s at %s", self.url, self.cache_file)
        return str(self.cache_file)

    @staticmethod
    def precache_test(test):
        """Fetch every ``ASSET_*`` Asset attribute defined on *test*'s class.

        Only attributes found in ``vars(test.__class__)`` are considered.
        Transient AssetErrors are logged and skipped; any other AssetError
        propagates.  The 'qemu-test' logger's level is set to DEBUG, and a
        stdout handler is attached to it for the duration of the call.
        """
        log = logging.getLogger('qemu-test')
        log.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        log.addHandler(handler)
        try:
            for name, asset in vars(test.__class__).items():
                if name.startswith("ASSET_") and isinstance(asset, Asset):
                    log.info("Attempting to cache '%s'", asset)
                    try:
                        asset.fetch()
                    except AssetError as e:
                        if not e.transient:
                            raise
                        log.error("%s: skipping asset precache", e)
        finally:
            # Detach even on error so repeated calls don't stack handlers
            log.removeHandler(handler)

    @staticmethod
    def precache_suite(suite):
        """Recursively run precache_test() on every TestCase in *suite*."""
        for test in suite:
            if isinstance(test, unittest.TestSuite):
                Asset.precache_suite(test)
            elif isinstance(test, unittest.TestCase):
                Asset.precache_test(test)

    @staticmethod
    def precache_suites(path, cacheTstamp):
        """Precache the assets of the tests named by *path*.

        *path* is passed to ``unittest``'s default loader
        (loadTestsFromNames).  The stamp file *cacheTstamp* is
        created/truncated only after precaching completes, so a failed run
        does not leave a fresh stamp behind.
        """
        loader = unittest.loader.defaultTestLoader
        tests = loader.loadTestsFromNames([path], None)

        Asset.precache_suite(tests)

        with open(cacheTstamp, "w"):
            pass
231