diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index ea65b65a..f88cd7fc 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -120,6 +120,18 @@ This will remove all traces of the DiffusionBee application (cache, generated im Open your Home folder, and press the cmd + shift + . (command + shift + period) keys. This will show the hidden files and directories. You can safely drag and drop the .diffusionbee folder in the Trash. +## Options + +### Proxy Support +Some environments may need proxy support to obtain and download models. Custom proxies can be set in DiffusionBee's config file. Just create or add the following content in: +`~/.diffusionbee/config.ini`: +``` +[proxy] +host: http://localhost:8000 +``` +Afterwards, DiffusionBee will use this host as `http` and `https` proxy for downloading new models. + + ## Extra tips * For prompt ideas and help, check out: diff --git a/backends/stable_diffusion/downloader.py b/backends/stable_diffusion/downloader.py index ad2d541f..4fef6674 100644 --- a/backends/stable_diffusion/downloader.py +++ b/backends/stable_diffusion/downloader.py @@ -10,6 +10,7 @@ if not os.path.isdir(projects_root_path): os.mkdir(projects_root_path) +import configparser import requests import os import hashlib @@ -72,6 +73,33 @@ def is_already_downloaded(self, out_fname=None, md5_checksum=None): return False return True + def get_proxy_from_config(config_path): + """ Parse diffusionbee's dotfile and evaluate for optional + proxy settings to use for downloading further objects + Args: + config_path (str): The path to the dotfile. + Returns: + proxies (dict): Containing the http/https proxies to use. + """ + # Read proxy host from dotfile if key is present + proxy_host = False + + try: + config = configparser.ConfigParser() + config.read(config_path) + proxy_host = config['proxy']['host'] + except configparser.ParsingError: + print("Error parsing config file: Impossible to parse file.") + except KeyError: + pass + + if proxy_host: + print("Setting proxy") + proxies = {} + proxies['http'] = proxy_host + proxies['https'] = proxy_host + return proxies + def download(self, url, out_fname=None, md5_checksum=None, verify_ssl=True, extract_zip=False, dont_use_cache=False): """Download the file @@ -98,6 +126,13 @@ def download(self, url, out_fname=None, md5_checksum=None, print("sdbk mlpr %d"%int(-1) ) print("sdbk mltl Checking Model") + # Proxy support + home = os.path.expanduser("~") + diffusionbee_config = f"{home}/.diffusionbee/config.ini" + if os.path.exists(diffusionbee_config): + proxies = self.get_proxy_from_config(diffusionbee_config) + + if (not dont_use_cache) and self.is_already_downloaded( out_fname=out_fname, md5_checksum=md5_checksum): if extract_zip: @@ -112,7 +147,10 @@ def download(self, url, out_fname=None, md5_checksum=None, with open(out_abs_path, "wb") as f: print("sdbk mltl " + self.title) - response = requests.get(url, stream=True, verify=verify_ssl) + if proxies is not None: + response = requests.get(url, stream=True, verify=verify_ssl, proxies=proxies) + else: + response = requests.get(url, stream=True, verify=verify_ssl) total_length = response.headers.get('content-length')