diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py
index 0db2f86..42e280e 100644
--- a/libmultilabel/linear/tree.py
+++ b/libmultilabel/linear/tree.py
@@ -46,13 +46,14 @@ def __init__(
         self,
         root: Node,
         flat_model: linear.FlatModel,
-        weight_map: np.ndarray,
+        node_ptr: np.ndarray,
     ):
         self.name = "tree"
         self.root = root
         self.flat_model = flat_model
-        self.weight_map = weight_map
+        self.node_ptr = node_ptr
         self.multiclass = False
+        self._model_separated = False  # Whether the flat model has been separated into per-subtree models for tree pruning.
 
     def predict_values(
         self,
@@ -68,10 +69,93 @@ def predict_values(
         Returns:
             np.ndarray: A matrix with dimension number of instances * number of classes.
         """
-        # number of instances * number of labels + total number of metalabels
-        all_preds = linear.predict_values(self.flat_model, x)
+        if beam_width >= len(self.root.children):
+            # beam_width covers all subtrees, so nothing can be pruned;
+            # calculate decision values for all nodes at once.
+            all_preds = linear.predict_values(self.flat_model, x)  # number of instances * (number of labels + total number of metalabels)
+        else:
+            # beam_width is small; prune the tree to avoid computing subtrees that cannot enter the beam.
+            if not self._model_separated:
+                self._separate_model_for_pruning_tree()
+                self._model_separated = True
+            all_preds = self._prune_tree_and_predict_values(x, beam_width)  # number of instances * (number of labels + total number of metalabels)
         return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])
 
+    def _separate_model_for_pruning_tree(self):
+        """
+        Separates the weights of the root and of each of its K children from the flattened
+        model into (K+1) FlatModels for efficient beam search traversal in Python.
+        """
+        tree_flat_model_params = {
+            "bias": self.root.model.bias,
+            "thresholds": 0,
+            "multiclass": False,
+        }
+        slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
+        self.root_model = linear.FlatModel(
+            name="root-flattened-tree",
+            weights=self.flat_model.weights[slice].tocsr(),
+            **tree_flat_model_params,
+        )
+
+        self.subtree_models = []
+        for i in range(len(self.root.children)):
+            subtree_weights_start = self.node_ptr[self.root.children[i].index]
+            subtree_weights_end = self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else self.node_ptr[-1]  # the last subtree extends to the final column
+            slice = np.s_[:, subtree_weights_start:subtree_weights_end]
+            subtree_flatmodel = linear.FlatModel(
+                name="subtree-flattened-tree",
+                weights=self.flat_model.weights[slice].tocsr(),
+                **tree_flat_model_params,
+            )
+            self.subtree_models.append(subtree_flatmodel)
+
+    def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray:
+        """Calculates the decision values associated with instances x, evaluating only the most relevant subtrees.
+
+        Only subtrees corresponding to the top beam_width candidates at the root are evaluated,
+        skipping the rest to avoid unnecessary computation.
+
+        Args:
+            x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
+            beam_width (int): Number of top candidate branches considered for prediction.
+
+        Returns:
+            np.ndarray: A matrix with dimension number of instances * (number of labels + total number of metalabels).
+ """ + # Initialize space for all predictions with negative infinity + num_instances, num_labels = x.shape[0], self.node_ptr[-1] + all_preds = np.full((num_instances, num_labels), -np.inf) + + # Calculate root decision values and scores + root_preds = linear.predict_values(self.root_model, x) + children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds)) + + slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]] + all_preds[slice] = root_preds + + # Select indices of the top beam_width subtrees for each instance + top_beam_width_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width] + + # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance + mask = np.zeros_like(children_scores, dtype=np.bool_) + np.put_along_axis(mask, top_beam_width_indices, True, axis=1) + + # Calculate predictions for each subtree with its corresponding instances + for subtree_idx in range(len(self.root.children)): + subtree_model = self.subtree_models[subtree_idx] + instances_mask = mask[:, subtree_idx] + reduced_instances = x[np.s_[instances_mask], :] + + # Locate the position of the subtree root in the weight mapping of all nodes + subtree_weights_start = self.node_ptr[self.root.children[subtree_idx].index] + subtree_weights_end = subtree_weights_start + subtree_model.weights.shape[1] + + slice = np.s_[instances_mask, subtree_weights_start:subtree_weights_end] + all_preds[slice] = linear.predict_values(subtree_model, reduced_instances) + + return all_preds + def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray: """Predict with beam search using cached probability estimates for a single instance. @@ -93,7 +177,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra if node.isLeaf(): next_level.append((node, score)) continue - slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]] + slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]] pred = instance_preds[slice] children_score = score - np.square(np.maximum(0, 1 - pred)) next_level.extend(zip(node.children, children_score.tolist())) @@ -102,9 +186,9 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra next_level = [] num_labels = len(self.root.label_map) - scores = np.full(num_labels, 0.0) + scores = np.zeros(num_labels) for node, score in cur_level: - slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]] + slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]] pred = instance_preds[slice] scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred))) return scores @@ -130,7 +214,7 @@ def train_tree( verbose (bool, optional): Output extra progress information. Defaults to True. Returns: - A model which can be used in predict_values. + TreeModel: A model which can be used in predict_values. 
""" label_representation = (y.T * x).tocsr() label_representation = sklearn.preprocessing.normalize(label_representation, norm="l2", axis=1) @@ -173,8 +257,8 @@ def visit(node): root.dfs(visit) pbar.close() - flat_model, weight_map = _flatten_model(root) - return TreeModel(root, flat_model, weight_map) + flat_model, node_ptr = _flatten_model(root) + return TreeModel(root, flat_model, node_ptr) def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node: @@ -188,7 +272,7 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, dmax (int): Maximum depth of the tree. Returns: - Node: root of the (sub)tree built from label_representation. + Node: Root of the (sub)tree built from label_representation. """ if d >= dmax or label_representation.shape[0] <= K: return Node(label_map=label_map, children=[]) @@ -261,11 +345,10 @@ def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]: """Flattens tree weight matrices into a single weight matrix. The flattened weight matrix is used to predict all possible values, which is cached for beam search. This pessimizes complexity but is faster in practice. - Consecutive values of the returned map denotes the start and end indices of the - weights of each node. Conceptually, given root and node: - flat_model, weight_map = _flatten_model(root) - slice = np.s_[weight_map[node.index]: - weight_map[node.index+1]] + Consecutive values of the returned array denote the start and end indices of each node in the tree. + To extract a node's classifiers: + slice = np.s_[node_ptr[node.index]: + node_ptr[node.index+1]] node.model.weights == flat_model.weights[:, slice] Args: @@ -296,6 +379,6 @@ def visit(node): ) # w.shape[1] is the number of labels/metalabels of each node - weight_map = np.cumsum([0] + list(map(lambda w: w.shape[1], weights))) + node_ptr = np.cumsum([0] + list(map(lambda w: w.shape[1], weights))) - return model, weight_map + return model, node_ptr