'''
Helper module to download extra data from Internet
'''
from __future__ import print_function
import os
import cv2
import sys
import yaml
import argparse
import tarfile
import platform
import tempfile
import hashlib
import requests
import shutil
from pathlib import Path
from datetime import datetime
if sys.version_info[0] < 3:
    from urllib2 import urlopen
else:
    from urllib.request import urlopen
import xml.etree.ElementTree as ET

__all__ = ["downloadFile"]

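# Downloaded files are cached on disk as <save_dir>/<sha1>/<filename>, so a file
# whose SHA1 already matches the expected value is never downloaded twice.
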
class HashMismatchException(Exception):
    def __init__(self, expected, actual):
        Exception.__init__(self)
        self.expected = expected
        self.actual = actual

    def __str__(self):
        return 'Hash mismatch: expected {} vs actual of {}'.format(self.expected, self.actual)

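# Computes the SHA1 of a file in 10 MB chunks. If the file does not exist, the
# digest of the empty stream is returned, which never matches a real checksum.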
def getHashsumFromFile(filepath):
    sha = hashlib.sha1()
    if os.path.exists(filepath):
        print(' there is already a file with the same name')
        with open(filepath, 'rb') as f:
            while True:
                buf = f.read(10*1024*1024)
                if not buf:
                    break
                sha.update(buf)
    hashsum = sha.hexdigest()
    return hashsum

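# Compares the expected SHA1 against the file on disk; raises HashMismatchException
# when silent=False, otherwise just returns False on mismatch.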
def checkHashsum(expected_sha, filepath, silent=True):
    print(' expected SHA1: {}'.format(expected_sha))
    actual_sha = getHashsumFromFile(filepath)
    print(' actual SHA1: {}'.format(actual_sha))
    hashes_matched = expected_sha == actual_sha
    if not hashes_matched and not silent:
        raise HashMismatchException(expected_sha, actual_sha)
    return hashes_matched

def isArchive(filepath):
    return tarfile.is_tarfile(filepath)

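# Represents one downloadable item. get() first looks for a cached copy under
# save_dir/<sha>/<filename> and only invokes the loader when the hash check fails.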
class DownloadInstance:
    def __init__(self, **kwargs):
        self.name = kwargs.pop('name')
        self.filename = kwargs.pop('filename')
        self.loader = kwargs.pop('loader', None)
        self.save_dir = kwargs.pop('save_dir')
        self.sha = kwargs.pop('sha', None)

    def __str__(self):
        return 'DownloadInstance <{}>'.format(self.name)
    def get(self):
        print(" Working on " + self.name)
        print(" Getting file " + self.filename)
        if self.sha is None:
            print(' No expected hashsum provided, loading file')
        else:
            filepath = os.path.join(self.save_dir, self.sha, self.filename)
            if checkHashsum(self.sha, filepath):
                print(' hash match - file already exists, skipping')
                return filepath
            else:
                print(' hash didn\'t match, loading file')

        if not os.path.exists(self.save_dir):
            print(' creating directory: ' + self.save_dir)
            os.makedirs(self.save_dir)

        print(' hash check failed - loading')
        assert self.loader
        try:
            self.loader.load(self.filename, self.sha, self.save_dir)
            print(' done')
            print(' file {}'.format(self.filename))
            if self.sha is None:
                download_path = os.path.join(self.save_dir, self.filename)
                self.sha = getHashsumFromFile(download_path)
                new_dir = os.path.join(self.save_dir, self.sha)
                if not os.path.exists(new_dir):
                    os.makedirs(new_dir)
                filepath = os.path.join(new_dir, self.filename)
                if not os.path.exists(filepath):
                    shutil.move(download_path, new_dir)
                print(' No expected hashsum provided, actual SHA is {}'.format(self.sha))
            else:
                checkHashsum(self.sha, filepath, silent=False)
        except Exception as e:
            print(" There was some problem with loading file {} for {}".format(self.filename, self.name))
            print(" Exception: {}".format(e))
            return

        print(" Finished " + self.name)
        return filepath

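# Base class for loaders: load() handles the on-disk cache layout, hash verification
# and archive extraction, while subclasses implement download() for a concrete transport.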
class Loader(object):
    MB = 1024*1024
    BUFSIZE = 10*MB

    def __init__(self, download_name, download_sha, archive_member=None):
        self.download_name = download_name
        self.download_sha = download_sha
        self.archive_member = archive_member
    def load(self, requested_file, sha, save_dir):
        if self.download_sha is None:
            download_dir = save_dir
        else:
            # create a new folder in save_dir to avoid possible name conflicts
            download_dir = os.path.join(save_dir, self.download_sha)
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)
        download_path = os.path.join(download_dir, self.download_name)
        print(" Preparing to download file " + self.download_name)
        if checkHashsum(self.download_sha, download_path):
            print(' hash match - file already exists, no need to download')
        else:
            filesize = self.download(download_path)
            print(' Downloaded {} with size {} Mb'.format(self.download_name, filesize/self.MB))
            if self.download_sha is not None:
                checkHashsum(self.download_sha, download_path, silent=False)
        if self.download_name == requested_file:
            return
        else:
            if isArchive(download_path):
                if sha is not None:
                    extract_dir = os.path.join(save_dir, sha)
                else:
                    extract_dir = save_dir
                if not os.path.exists(extract_dir):
                    os.makedirs(extract_dir)
                self.extract(requested_file, download_path, extract_dir)
            else:
                raise Exception("Downloaded file has different name")
    def download(self, filepath):
        print("Warning: download is not implemented, this is a base class")
        return 0
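    # Extracts requested_file from a tar archive. If archive_member was not given,
    # it is resolved by matching the file's basename against the archive listing.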
    def extract(self, requested_file, archive_path, save_dir):
        filepath = os.path.join(save_dir, requested_file)
        try:
            with tarfile.open(archive_path) as f:
                if self.archive_member is None:
                    # map member basenames to their directories inside the archive,
                    # then rebuild the full member path of the requested file
                    pathDict = dict((os.path.split(elem)[1], os.path.split(elem)[0]) for elem in f.getnames())
                    member_dir = pathDict[requested_file]
                    self.archive_member = member_dir + '/' + requested_file if member_dir else requested_file
                assert self.archive_member in f.getnames()
                self.save(filepath, f.extractfile(self.archive_member))
        except Exception as e:
            print(' catch {}'.format(e))
    def save(self, filepath, r):
        with open(filepath, 'wb') as f:
            print(' progress ', end="")
            sys.stdout.flush()
            while True:
                buf = r.read(self.BUFSIZE)
                if not buf:
                    break
                f.write(buf)
                print('>', end="")
                sys.stdout.flush()

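# Loader for plain HTTP(S) links; streams the response from urlopen (60 s timeout)
# to disk using the inherited save() helper.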
class URLLoader(Loader):
    def __init__(self, download_name, download_sha, url, archive_member=None):
        super(URLLoader, self).__init__(download_name, download_sha, archive_member)
        self.download_name = download_name
        self.download_sha = download_sha
        self.url = url

    def download(self, filepath):
        r = urlopen(self.url, timeout=60)
        self.printRequest(r)
        self.save(filepath, r)
        return os.path.getsize(filepath)

    def printRequest(self, r):
        def getMB(r):
            d = dict(r.info())
            for c in ['content-length', 'Content-Length']:
                if c in d:
                    return int(d[c]) / self.MB
            return '<unknown>'
        print(' {} {} [{} Mb]'.format(r.getcode(), r.msg, getMB(r)))

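# Loader for Google Drive file ids. For large files Google Drive answers with a
# confirmation page instead of the payload, so the confirmation token found in the
# cookies is re-sent to obtain the actual download stream.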
class GDriveLoader(Loader):
    BUFSIZE = 1024 * 1024
    PROGRESS_SIZE = 10 * 1024 * 1024

    def __init__(self, download_name, download_sha, gid, archive_member=None):
        super(GDriveLoader, self).__init__(download_name, download_sha, archive_member)
        self.download_name = download_name
        self.download_sha = download_sha
        self.gid = gid

    def download(self, filepath):
        session = requests.Session()  # re-use cookies

        URL = "https://docs.google.com/uc?export=download"
        response = session.get(URL, params={'id': self.gid}, stream=True)

        def get_confirm_token(response):  # in case of large files
            for key, value in response.cookies.items():
                if key.startswith('download_warning'):
                    return value
            return None

        token = get_confirm_token(response)
        if token:
            params = {'id': self.gid, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        sz = 0
        progress_sz = self.PROGRESS_SIZE
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(self.BUFSIZE):
                if not chunk:
                    continue  # keep-alive
                f.write(chunk)
                sz += len(chunk)
                if sz >= progress_sz:
                    progress_sz += self.PROGRESS_SIZE
                    print('>', end='')
                    sys.stdout.flush()
        print('')
        return sz

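# Builds a DownloadInstance, picking GDriveLoader when the URL carries a Google
# Drive "id=" parameter and falling back to URLLoader for everything else.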
def produceDownloadInstance(instance_name, filename, sha, url, save_dir, download_name=None, download_sha=None, archive_member=None):
    spec_param = url
    loader = URLLoader
    if download_name is None:
        download_name = filename
    if download_sha is None:
        download_sha = sha
    if "drive.google.com" in url:
        token = ""
        token_part = url.rsplit('/', 1)[-1]
        if "&id=" not in token_part:
            token_part = url.rsplit('/', 1)[-2]
        for param in token_part.split("&"):
            if param.startswith("id="):
                token = param[3:]
        if token:
            loader = GDriveLoader
            spec_param = token
        else:
            print("Warning: possibly wrong Google Drive link")
    return DownloadInstance(
        name=instance_name,
        filename=filename,
        sha=sha,
        save_dir=save_dir,
        loader=loader(download_name, download_sha, spec_param, archive_member)
    )

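# Picks the cache directory: OPENCV_DOWNLOAD_DATA_PATH when set; otherwise a
# platform-specific base (TMPDIR or /tmp on macOS, the system temp dir on Windows,
# XDG_CACHE_HOME or ~/.cache elsewhere) with a "downloads" subdirectory appended.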
def getSaveDir():
    env_path = os.environ.get("OPENCV_DOWNLOAD_DATA_PATH", None)
    if env_path:
        save_dir = env_path
    else:
        # TODO reuse binding function cv2.utils.fs.getCacheDirectory when issue #19011 is fixed
        if platform.system() == "Darwin":
            # on Apple devices
            temp_env = os.environ.get("TMPDIR", None)
            if temp_env is None or not os.path.isdir(temp_env):
                temp_dir = Path("/tmp")
                print("Using world accessible cache directory. This may not be secure: ", temp_dir)
            else:
                temp_dir = temp_env
        elif platform.system() == "Windows":
            temp_dir = tempfile.gettempdir()
        else:
            xdg_cache_env = os.environ.get("XDG_CACHE_HOME", None)
            if xdg_cache_env and xdg_cache_env[0] and os.path.isdir(xdg_cache_env):
                temp_dir = xdg_cache_env
            else:
                home_env = os.environ.get("HOME", None)
                home_path = os.path.join(home_env, ".cache/") if home_env else None
                if home_env and os.path.isdir(home_env) and os.path.isdir(home_path):
                    temp_dir = home_path
                else:
                    # fall back to the system temp dir if no usable cache directory is found
                    temp_dir = tempfile.gettempdir()
                    print("Using world accessible cache directory. This may not be secure: ", temp_dir)
        save_dir = os.path.join(temp_dir, "downloads")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    return save_dir

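# Public helper exposed via __all__. A minimal usage sketch; the URL and SHA1 below
# are placeholders, and the import assumes this file is saved as download_models.py:
#
#   from download_models import downloadFile
#   path = downloadFile("https://example.com/model.onnx",
#                       sha="0123456789abcdef0123456789abcdef01234567",
#                       filename="model.onnx")
#   # path is None if the download or the hash check failed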
def downloadFile(url, sha=None, save_dir=None, filename=None):
    if save_dir is None:
        save_dir = getSaveDir()
    if filename is None:
        # timestamp-based name without characters that are invalid in file names
        filename = "download_" + datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    name = filename
    return produceDownloadInstance(name, filename, sha, url, save_dir).get()

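# Parses a Metalink XML file (urn:ietf:params:xml:ns:metalink namespace) and returns
# one DownloadInstance per <file> element, using its name, identity, hash and url.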
def parseMetalinkFile(metalink_filepath, save_dir):
    NS = {'ml': 'urn:ietf:params:xml:ns:metalink'}
    models = []
    for file_elem in ET.parse(metalink_filepath).getroot().findall('ml:file', NS):
        url = file_elem.find('ml:url', NS).text
        fname = file_elem.attrib['name']
        name = file_elem.find('ml:identity', NS).text
        hash_sum = file_elem.find('ml:hash', NS).text
        models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir))
    return models

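# Parses a models.yml-style file: every entry that has a "load_info" section (url,
# sha1 and optional download_name/download_sha/member keys) becomes a DownloadInstance.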
def parseYAMLFile(yaml_filepath, save_dir):
    models = []
    with open(yaml_filepath, 'r') as stream:
        data_loaded = yaml.safe_load(stream)
        for name, params in data_loaded.items():
            load_info = params.get("load_info", None)
            if load_info:
                fname = os.path.basename(params.get("model"))
                hash_sum = load_info.get("sha1")
                url = load_info.get("url")
                download_sha = load_info.get("download_sha")
                download_name = load_info.get("download_name")
                archive_member = load_info.get("member")
                models.append(produceDownloadInstance(name, fname, hash_sum, url, save_dir,
                                                      download_name=download_name, download_sha=download_sha,
                                                      archive_member=archive_member))
    return models

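# Command-line entry point. Example invocations (model names here are illustrative,
# and the script is assumed to be run from the directory that contains models.yml
# and face_detector/weights.meta4):
#
#   python download_models.py                          # download every known model into the CWD
#   python download_models.py yolo                     # only models whose name starts with "yolo"
#   python download_models.py --save_dir=/data/models googlenet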
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='This is a utility script for downloading DNN models for samples.')
    parser.add_argument('--save_dir', action="store", default=os.getcwd(),
                        help='Path to the directory to store downloaded files')
    parser.add_argument('model_name', type=str, default="", nargs='?', action="store",
                        help='name of the model to download')
    args = parser.parse_args()

    models = []
    save_dir = args.save_dir
    selected_model_name = args.model_name
    models.extend(parseMetalinkFile('face_detector/weights.meta4', save_dir))
    models.extend(parseYAMLFile('models.yml', save_dir))

    for m in models:
        print(m)
        if selected_model_name and not m.name.startswith(selected_model_name):
            continue
        print('Model: ' + selected_model_name)
        m.get()