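"""
Benchmarks distributed KernelSHAP explanations served with Ray Serve: a backend wrapping
the explainer is deployed (optionally with server-side request batching), the test
instances are sent to its HTTP endpoint in parallel via Ray tasks, and the wall-clock
times together with the explanations are pickled for later analysis.
"""
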
import argparse
import logging
import os
import pickle

import numpy as np
import ray
import requests

import explainers.wrappers as wrappers

from collections import namedtuple
from ray import serve
from timeit import default_timer as timer
from typing import Any, Dict, List, Tuple
from explainers.utils import get_filename, batch, load_data, load_model


logging.basicConfig(level=logging.INFO)

PREDICTOR_URL = 'https://storage.googleapis.com/seldon-models/alibi/distributed_kernel_shap/predictor.pkl'
PREDICTOR_PATH = 'assets/predictor.pkl'
"""
str: The file containing the predictor. The predictor can be created by running `fit_adult_model.py` or obtained by
calling `explainers.utils.load_model()`, which downloads a default predictor if `assets/` does not contain one.
"""


def endpoint_setup(tag: str, backend_tag: str, route: str = "/") -> None:
    """
    Creates an endpoint for serving explanations.

    Parameters
    ----------
    tag
        Endpoint tag.
    backend_tag
        A tag for the backend this endpoint will connect to.
    route
        The URL route where the explainer can be queried.
    """
    serve.create_endpoint(tag, backend=backend_tag, route=route, methods=["GET"])


def backend_setup(tag: str, worker_args: Tuple, replicas: int, max_batch_size: int) -> None:
    """
    Sets up the backend for the distributed explanation task.

    Parameters
    ----------
    tag
        A tag for the backend component. The same tag must be passed to `endpoint_setup`.
    worker_args
        A tuple containing the arguments for initialising the explainer and fitting it.
    replicas
        The number of backend replicas that serve explanations.
    max_batch_size
        Maximum number of requests to batch and send to a worker process.
    """

    if max_batch_size == 1:
        config = {'num_replicas': max(replicas, 1)}
        serve.create_backend(tag, wrappers.KernelShapModel, *worker_args)
    else:
        config = {'num_replicas': max(replicas, 1), 'max_batch_size': max_batch_size}
        serve.create_backend(tag, wrappers.BatchKernelShapModel, *worker_args)
    serve.update_backend_config(tag, config)

    logging.info(f"Backends: {serve.list_backends()}")

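# Example wiring (values are hypothetical; `main` below uses the "kernel_shap:b100" tags):
#   backend_setup("kernel_shap:b100", worker_args, replicas=2, max_batch_size=5)
#   endpoint_setup("kernel_shap:b100_endpoint", "kernel_shap:b100", route="/explain")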

def prepare_explainer_args(data: Dict[str, Any]) -> Tuple[Any, np.ndarray, dict, dict]:
    """
    Extracts the feature names (`group_names`) and the columns corresponding to each feature in the feature matrix
    (`groups`) from the `data` dict and defines the explainer arguments. The background data necessary to initialise
    the explainer is also extracted from the same dictionary.

    Parameters
    ----------
    data
        A dictionary that contains all the information necessary to initialise the explainer.

    Returns
    -------
    A tuple containing the positional and keyword arguments necessary for initialising the explainer.
    """

    groups = data['all']['groups']
    group_names = data['all']['group_names']
    background_data = data['background']['X']['preprocessed']
    assert background_data.shape[0] == 100  # the backend tag below advertises 100 background samples
    init_kwargs = {'link': 'logit', 'feature_names': group_names, 'seed': 0}
    fit_kwargs = {'groups': groups, 'group_names': group_names}
    predictor = load_model(PREDICTOR_URL)
    worker_args = (predictor, background_data, init_kwargs, fit_kwargs)

    return worker_args

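# For reference, `prepare_explainer_args` (and `main` below) assume `data`, as returned by
# `explainers.utils.load_data()`, contains at least:
#   data['all']['groups']                    - column indices of each feature group
#   data['all']['group_names']               - feature (group) names
#   data['all']['X']['processed']['test']    - preprocessed test set (sparse, 2560 rows)
#   data['background']['X']['preprocessed']  - background dataset (100 rows)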

@ray.remote
def distribute_request(instance: np.ndarray, url: str = "http://localhost:8000/explain") -> dict:
    """
    Ray task that sends a single explanation request to the serving backend.

    Parameters
    ----------
    instance
        Instance to be explained.
    url
        The explainer URL.

    Returns
    -------
    The explanation response, deserialised from JSON.
    """

    resp = requests.get(url, json={"array": instance.tolist()})
    return resp.json()

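# `distribute_request` is a Ray task so that `request_explanations` below can fire all
# HTTP requests concurrently and gather the responses with `ray.get`, rather than
# issuing them sequentially from the driver.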

def request_explanations(instances: List[np.ndarray], *, url: str) -> namedtuple:
    """
    Sends the instances to the explainer URL.

    Parameters
    ----------
    instances
        A list of instances (or minibatches) to be explained.
    url
        Explainer endpoint.

    Returns
    -------
    run_output
        A named tuple with a `responses` field and a `t_elapsed` field.
    """

    run_output = namedtuple('run_output', 'responses t_elapsed')
    tstart = timer()
    responses_id = [distribute_request.remote(instance, url=url) for instance in instances]
    responses = [ray.get(resp_id) for resp_id in responses_id]
    t_elapsed = timer() - tstart
    logging.info(f"Time elapsed: {t_elapsed}...")

    return run_output(responses=responses, t_elapsed=t_elapsed)


def run_explainer(X_explain: np.ndarray,
                  n_runs: int,
                  replicas: int,
                  max_batch_size: int,
                  batch_mode: str = 'ray',
                  url: str = "http://localhost:8000/explain") -> None:
    """
    Sends the instances to the explainer endpoint `n_runs` times and pickles the timings and responses to disk.

    Parameters
    ----------
    X_explain
        Instances to be explained. Each row is an instance that is explained independently of the others.
    n_runs
        Number of times to run an experiment in which the entire set of explanations is sent to the explainer
        endpoint. Used to determine the average runtime given the number of cores.
    replicas
        How many backend replicas should be used for distributing the workload.
    max_batch_size
        The maximum batch size the explainer accepts.
    batch_mode : {'ray', 'default'}
        If 'ray', Ray Serve components are leveraged to batch requests on the server side. Otherwise the input tensor
        is split into minibatches which are sent to the endpoint.
    url
        The URL of the explainer endpoint.
    """

    result = {'t_elapsed': [], 'explanations': []}
    # extract instances to be explained from the dataset
    assert X_explain.shape[0] == 2560

    # split the input into separate requests
    if batch_mode == 'ray':
        instances = np.split(X_explain, X_explain.shape[0])  # use ray serve to batch the requests
        logging.info(f"Explaining {len(instances)} instances...")
    else:
        instances = batch(X_explain, batch_size=max_batch_size)
        logging.info(f"Explaining {len(instances)} mini-batches of size {max_batch_size}...")

    # distribute the requests
    for run in range(n_runs):
        logging.info(f"Experiment run: {run}...")
        results = request_explanations(instances, url=url)
        result['t_elapsed'].append(results.t_elapsed)
        result['explanations'].append(results.responses)

    with open(get_filename(replicas, max_batch_size), 'wb') as f:
        pickle.dump(result, f)

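# The pickled results can be reloaded for analysis, e.g.:
#   with open(get_filename(replicas, max_batch_size), 'rb') as f:
#       result = pickle.load(f)
# where result['t_elapsed'][i] is the wall-clock time of run i and
# result['explanations'][i] the corresponding list of responses.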

def main():

    if not os.path.exists('results'):
        os.mkdir('results')

    data = load_data()
    X_explain = data['all']['X']['processed']['test'].toarray()

    max_batch_size = int(args.max_batch_size[0])
    batch_mode, replicas = args.batch_mode, args.replicas
    ray.init(address='auto')  # connect to the cluster
    serve.init(http_host='0.0.0.0')  # listen on 0.0.0.0 to make the endpoint accessible from other machines
    host, route = os.environ.get("RAY_HEAD_SERVICE_HOST", args.host), "explain"
    url = f"http://{host}:{args.port}/{route}"
    backend_tag = "kernel_shap:b100"  # b100 means 100 background samples
    endpoint_tag = f"{backend_tag}_endpoint"
    worker_args = prepare_explainer_args(data)
    if batch_mode == 'ray':
        backend_setup(backend_tag, worker_args, replicas, max_batch_size)
        logging.info(f"Batching with max_batch_size of {max_batch_size} ...")
    else:  # minibatches are assembled client-side and sent to the ray workers
        backend_setup(backend_tag, worker_args, replicas, 1)
        logging.info(f"Distributing minibatches of size {max_batch_size} ...")
    endpoint_setup(endpoint_tag, backend_tag, route=f"/{route}")

    run_explainer(X_explain, args.n_runs, replicas, max_batch_size, batch_mode=batch_mode, url=url)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--replicas",
        default=1,
        type=int,
        help="The number of backend replicas used to serve the explainer."
    )
    parser.add_argument(
        "-batch",
        "--max_batch_size",
        nargs='+',
        help="A list of values representing the maximum batch size of pending queries sent to the same worker. "
             "This should only contain one element as the backend is reset from `k8s_benchmark_serve.sh`.",
        required=True,
    )
    parser.add_argument(
        "-batch_mode",
        type=str,
        default='ray',
        help="If set to 'ray', batching is handled by ray serve. Otherwise, the input array is split into "
             "minibatches that are sent to the endpoint.",
        required=True,
    )
    parser.add_argument(
        "-n",
        "--n_runs",
        default=5,
        type=int,
        help="Controls how many times an experiment is run (in benchmark mode) for a given number of cores to obtain "
             "run statistics."
    )
    parser.add_argument(
        "-ho",
        "--host",
        default="localhost",
        type=str,
        help="Hostname."
    )
    parser.add_argument(
        "-p",
        "--port",
        default="8000",
        type=str,
        help="Port."
    )
    args = parser.parse_args()
    main()
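
# Example invocation (the script name is illustrative; assumes a running Ray cluster
# reachable via `ray.init(address='auto')`, e.g. started with `ray start --head`):
#   python benchmark_serve.py --replicas 4 -batch 10 -batch_mode ray --n_runs 5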