# --- ipAE / ipSAE helpers


def _to_numpy(array_like):
    """Return *array_like* as a NumPy array, detaching torch tensors first."""
    if hasattr(array_like, "detach"):
        array_like = array_like.detach().cpu().numpy()
    return np.asarray(array_like)


def extract_cb_coords(structure, coords):
    """Return one representative coordinate per residue, in chain order.

    For every residue the first available of the following is used:
    the atom named "CB", else the atom named "CA" (e.g. glycine), else the
    centroid of all atoms of the residue.

    Parameters
    ----------
    structure : Boltz structure
        Object exposing ``chains`` (records with "res_idx"/"res_num"),
        ``residues`` (records with "atom_idx"/"atom_num") and ``atoms``
        (records with "name").
    coords : torch.Tensor or array-like, [1, N_atoms, 3]
        Atom coordinates; only the first batch element is read.

    Returns
    -------
    numpy.ndarray of shape (N_residues, 3)
    """
    atom_coords = _to_numpy(coords)[0]  # (N_atoms, 3)
    rep_coords = []

    for chain in structure.chains:
        res_start = chain["res_idx"]
        res_end = res_start + chain["res_num"]

        for res in structure.residues[res_start:res_end]:
            a0 = res["atom_idx"]
            a1 = a0 + res["atom_num"]

            cb_coord = None
            ca_coord = None
            # NOTE(review): assumes atom["name"] compares equal to a plain
            # string such as "CB" -- confirm against the structure dtype.
            for local_i, atom in enumerate(structure.atoms[a0:a1]):
                name = atom["name"]
                if name == "CB":
                    cb_coord = atom_coords[a0 + local_i]
                    break  # CB wins; no need to look further
                if name == "CA":
                    ca_coord = atom_coords[a0 + local_i]

            if cb_coord is not None:
                rep_coords.append(cb_coord)
            elif ca_coord is not None:
                rep_coords.append(ca_coord)
            else:
                # Neither CB nor CA found: fall back to the residue centroid.
                rep_coords.append(atom_coords[a0:a1].mean(axis=0))

    # reshape keeps the (N, 3) contract even when there are no residues.
    return np.asarray(rep_coords).reshape(-1, 3)


def _ipsae_d0res_asym(pae_np, dist, rows, cols, pae_cutoff, dist_cutoff):
    """Return the directional ipSAE (d0res variant) from *rows* to *cols*.

    For each residue i in *rows*, the partners j in *cols* with
    ``pae[i, j] < pae_cutoff`` and ``dist[i, j] < dist_cutoff`` are selected;
    d0 is derived from their count (clamped to at least 27 residues, d0 at
    least 1.0) and the residue score is the mean of
    ``1 / (1 + (pae / d0) ** 2)`` over those partners.  The result is the
    maximum residue score, or 0.0 when no residue has a valid partner.
    """
    if len(rows) == 0 or len(cols) == 0:
        return 0.0

    sub_pae = pae_np[np.ix_(rows, cols)]
    valid = (sub_pae < pae_cutoff) & (dist[np.ix_(rows, cols)] < dist_cutoff)
    n0res = valid.sum(axis=1)
    has_partner = n0res > 0
    if not has_partner.any():
        return 0.0

    d0 = np.maximum(
        1.0, 1.24 * (np.maximum(n0res, 27) - 15) ** (1.0 / 3.0) - 1.8
    )
    ptm = 1.0 / (1.0 + (sub_pae / d0[:, None]) ** 2)
    # np.where (not multiplication) so NaN entries outside the mask are ignored.
    ptm_sum = np.where(valid, ptm, 0.0).sum(axis=1)
    per_residue = ptm_sum[has_partner] / n0res[has_partner]
    return float(per_residue.max())


def compute_best_target_ip_metrics(pae, coords, structure,
                                   target_lengths, binder_len,
                                   *, pae_cutoff=15.0, dist_cutoff=15.0):
    """Compute binder-target ipAE / ipSAE_min and report one target's values.

    Residue ordering is assumed to be
    ``[binder][target1][target2]...[targetN]``.

    For every target, ipAE is the mean PAE over both off-diagonal
    binder/target blocks, and ipSAE_min is the smaller of the two directional
    ipSAE (d0res) scores.  The metrics of the target with the highest ipAE
    are returned.

    Parameters
    ----------
    pae : torch.Tensor or array-like, [1, L, L]
    coords : torch.Tensor or array-like, [1, N_atoms, 3]
    structure : Boltz structure
    target_lengths : list[int]
    binder_len : int
    pae_cutoff : float, keyword-only
        Pairs with PAE at or above this value are ignored for ipSAE.
    dist_cutoff : float, keyword-only
        Pairs whose representative atoms (CB, falling back to CA/centroid)
        are at or beyond this distance are ignored for ipSAE.

    Returns
    -------
    best_ipae : float
    best_ipsae_min : float
        Both are NaN when no target could be scored (no targets, empty
        binder, or only zero-length targets).

    Raises
    ------
    ValueError
        If the number of residues in *structure* differs from
        ``binder_len + sum(target_lengths)``.
    """
    pae_np = _to_numpy(pae)[0]

    # One representative (CB-like) coordinate per residue.
    cb = extract_cb_coords(structure, coords)
    expected_len = binder_len + sum(target_lengths)
    if cb.shape[0] != expected_len:
        raise ValueError(
            f"structure has {cb.shape[0]} residues but binder_len + "
            f"sum(target_lengths) = {expected_len}"
        )

    binder_idx = np.arange(binder_len)  # binder comes first
    dist = np.linalg.norm(cb[:, None, :] - cb[None, :, :], axis=-1)

    # NOTE(review): the reported target is the one with the HIGHEST ipAE.
    # PAE is an error estimate, so "best" would normally mean lowest --
    # confirm the intended selection rule before relying on it.
    best_ipae = float("-inf")
    best_ipsae_min = float("nan")
    found = False

    offset = binder_len  # targets start right after the binder
    for target_len in target_lengths:
        target_idx = np.arange(offset, offset + target_len)
        offset += target_len
        if binder_len == 0 or target_len == 0:
            continue  # no interface to score

        block_tb = pae_np[np.ix_(target_idx, binder_idx)]
        block_bt = pae_np[np.ix_(binder_idx, target_idx)]
        ipae = float(np.mean(np.concatenate([block_tb.ravel(), block_bt.ravel()])))

        t_to_b = _ipsae_d0res_asym(
            pae_np, dist, target_idx, binder_idx, pae_cutoff, dist_cutoff
        )
        b_to_t = _ipsae_d0res_asym(
            pae_np, dist, binder_idx, target_idx, pae_cutoff, dist_cutoff
        )
        ipsae_min = float(min(t_to_b, b_to_t))

        if ipae > best_ipae:
            best_ipae = ipae
            best_ipsae_min = ipsae_min
            found = True

    if not found:
        return float("nan"), float("nan")
    return best_ipae, best_ipsae_min
CHAIN_TO_NUMBER[self.binder_chain] pair_chains = output["pair_chains_iptm"] - + # Calculate i-pTM if len(pair_chains) > 1: values = [ @@ -454,6 +624,15 @@ def update_binder_sequence(new_seq): cycle_0_iptm = float(np.mean(values) if values else 0.0) else: cycle_0_iptm = 0.0 + + cycle_0_ipae, cycle_0_ipsae_min = compute_best_target_ip_metrics( + output["pae"], + output["coords"], + structure, + target_lengths, + binder_len, + ) + run_metrics["cycle_0_iptm"] = cycle_0_iptm run_metrics["cycle_0_plddt"] = float( @@ -464,6 +643,9 @@ def update_binder_sequence(new_seq): ) run_metrics["cycle_0_alanine"] = 0 + run_metrics["cycle_0_ipae"] = cycle_0_ipae + run_metrics["cycle_0_ipsae_min"] = cycle_0_ipsae_min + # --- Optimization Cycles --- for cycle in range(a.num_cycles): print(f"\n--- Run {run_id}, Cycle {cycle + 1} ---") @@ -531,6 +713,13 @@ def update_binder_sequence(new_seq): current_iptm = float(np.mean(values) if values else 0.0) else: current_iptm = 0.0 + curr_ipae, curr_ipsae_min = compute_best_target_ip_metrics( + output["pae"], + output["coords"], + structure, + target_lengths, + binder_len, + ) # Update best structure (only if alanine content is acceptable) if alanine_percentage <= 0.20 and current_iptm > best_iptm: @@ -570,6 +759,13 @@ def update_binder_sequence(new_seq): run_metrics[f"cycle_{cycle + 1}_iplddt"] = curr_iplddt run_metrics[f"cycle_{cycle + 1}_alanine"] = alanine_count run_metrics[f"cycle_{cycle + 1}_seq"] = seq + run_metrics[f"cycle_{cycle + 1}_ipae"] = curr_ipae + run_metrics[f"cycle_{cycle + 1}_ipsae_min"] = curr_ipsae_min + + if curr_ipae > best_ipae: + best_ipae = curr_ipae + if curr_ipsae_min > best_ipsae_min: + best_ipsae_min = curr_ipsae_min pdb_filename = ( f"{run_save_dir}/{a.name}_run_{run_id}_predicted_cycle_{cycle + 1}.pdb" @@ -579,7 +775,7 @@ def update_binder_sequence(new_seq): clean_memory() print( - f"ipTM: {current_iptm:.2f} pLDDT: {curr_plddt:.2f} iPLDDT: {curr_iplddt:.2f} Alanine count: {alanine_count}" + f"ipTM: 
{current_iptm:.2f} ipAE: {curr_ipae:.2f} ipSAE_min: {curr_ipsae_min:.2f} pLDDT: {curr_plddt:.2f} iPLDDT: {curr_iplddt:.2f} Alanine count: {alanine_count}" ) # 4. Save YAML for High ipTM @@ -648,7 +844,9 @@ def update_binder_sequence(new_seq): "alanine_count": alanine_count, "sequence": seq, "pdb_filename": pdb_base_name, # Log the *new* filename - "yaml_filename": yaml_base_name + "yaml_filename": yaml_base_name, + "ipae": curr_ipae, + "ipsae_min": curr_ipsae_min, } # Check if file exists to write header only once @@ -676,11 +874,15 @@ def update_binder_sequence(new_seq): .numpy()[0] ) run_metrics["best_seq"] = best_seq + run_metrics["best_ipae"] = float(best_ipae) + run_metrics["best_ipsae_min"] = float(best_ipsae_min) else: run_metrics["best_iptm"] = float("nan") run_metrics["best_cycle"] = None run_metrics["best_plddt"] = float("nan") run_metrics["best_seq"] = None + run_metrics["best_ipae"] = float("nan") + run_metrics["best_ipsae_min"] = float("nan") if a.plot: plot_run_metrics(run_save_dir, a.name, run_id, a.num_cycles, run_metrics) @@ -700,10 +902,12 @@ def _save_summary_metrics(self, all_run_metrics): f"cycle_{i}_iplddt", f"cycle_{i}_alanine", f"cycle_{i}_seq", + f"cycle_{i}_ipae", + f"cycle_{i}_ipsae_min", ] ) # Best metric columns - columns.extend(["best_iptm", "best_cycle", "best_plddt", "best_seq"]) + columns.extend(["best_iptm", "best_cycle", "best_plddt", "best_seq", "best_ipae", "best_ipsae_min",]) summary_csv = os.path.join(self.save_dir, "summary_all_runs.csv") df = pd.DataFrame(all_run_metrics)