Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions boltz_ph/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,8 @@ def plot_run_metrics(
run_save_dir: str, name: str, run_id: int, num_cycles: int, run_metrics: dict
):
"""Plots per-run metrics (iPTM, pLDDT, Alanine Count) over design cycles."""
fig, axs = plt.subplots(1, 4, figsize=(16, 4)) # Increased figure width
colors = ["#9B59B6", "#E94560", "#FF7F11", "#2ECC71"]
fig, axs = plt.subplots(1, 5, figsize=(20, 4)) # Increased figure width
colors = ["#9B59B6", "#E94560", "#FF7F11", "#2ECC71", "#1673EC"]

# Helper to retrieve data and format
def get_metric_data(key_suffix, label, ymin, ymax, fmt):
Expand All @@ -673,6 +673,7 @@ def get_metric_data(key_suffix, label, ymin, ymax, fmt):
max([run_metrics.get(f"cycle_{i}_alanine", 0) for i in range(num_cycles + 1)]) + 2,
"{}",
),
get_metric_data("ipsae_min", "ipSAE_min", 0, 1, "{:.3f}"),
]

design_cycles = list(range(num_cycles + 1))
Expand Down
216 changes: 210 additions & 6 deletions boltz_ph/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,157 @@
smart_split,
)

# --- ipAE / ipSAE helpers

def extract_cb_coords(structure, coords):
    """Return one representative coordinate per residue, preferring CB.

    For every residue of every chain in ``structure`` (visited in chain
    order, then residue order) the representative position is chosen as:

    1. the atom named ``"CB"`` if the residue has one;
    2. otherwise the last atom named ``"CA"`` (this covers glycine);
    3. otherwise the centroid of all atoms of the residue.

    Parameters
    ----------
    structure
        Object exposing ``chains``, ``residues`` and ``atoms`` sequences.
        Chains are indexed with ``"res_idx"`` / ``"res_num"``, residues with
        ``"atom_idx"`` / ``"atom_num"`` and atoms with ``"name"``.
        NOTE(review): atom names are compared against plain ``str`` values;
        confirm the structure stores names as strings.
    coords
        Atom coordinates with a leading sample axis; only ``coords[0]`` is
        used and it is indexed by global atom index. May be a torch tensor
        (detached and moved to CPU here) or anything ``np.asarray`` accepts.

    Returns
    -------
    numpy.ndarray
        Array with one row per residue. When there are no residues an empty
        ``(0, 3)`` array is returned so that downstream broadcasting such as
        ``cb[:, None, :]`` keeps working.
    """
    sample = coords[0]
    if hasattr(sample, "detach"):
        # torch tensor: drop autograd history and move to host memory
        sample = sample.detach().cpu().numpy()
    atom_coords = np.asarray(sample)  # indexed by global atom index

    rep_coords = []

    for chain in structure.chains:
        res_start = chain["res_idx"]
        res_end = res_start + chain["res_num"]

        for res in structure.residues[res_start:res_end]:
            a0 = res["atom_idx"]
            a1 = a0 + res["atom_num"]

            cb_idx = None
            ca_idx = None
            for atom_i, atom in enumerate(structure.atoms[a0:a1], start=a0):
                name = atom["name"]
                if name == "CB":
                    cb_idx = atom_i
                    break  # CB wins outright, no need to look further
                if name == "CA":
                    ca_idx = atom_i

            if cb_idx is not None:
                rep_coords.append(atom_coords[cb_idx])
            elif ca_idx is not None:
                rep_coords.append(atom_coords[ca_idx])
            elif a1 > a0:
                # Neither CB nor CA: fall back to the residue centroid.
                rep_coords.append(atom_coords[a0:a1].mean(axis=0))
            else:
                # Residue without atoms: keep the row (so residue indices
                # stay aligned) but mark the position as unknown instead of
                # taking the mean of an empty slice.
                rep_coords.append(np.full(atom_coords.shape[1:], np.nan))

    if not rep_coords:
        return np.zeros((0, 3), dtype=float)
    return np.array(rep_coords)




def _asym_ipsae_d0res(pae_block, dist_block, pae_cutoff=15.0, dist_cutoff=15.0):
    """Directional ipSAE (d0res variant) from source residues to partner residues.

    ``pae_block[i, j]`` and ``dist_block[i, j]`` describe source residue ``i``
    versus partner residue ``j``. For each source residue:

    - valid partners are those with ``pae < pae_cutoff`` and
      ``dist < dist_cutoff``;
    - ``n0res`` is the number of valid partners;
    - ``d0 = max(1.0, 1.24 * (max(n0res, 27) - 15) ** (1/3) - 1.8)``;
    - the residue score is the mean of ``1 / (1 + (pae / d0) ** 2)`` over
      its valid partners.

    Returns the maximum residue score, or ``0.0`` when no source residue has
    a valid partner.
    """
    mask = (pae_block < pae_cutoff) & (dist_block < dist_cutoff)
    n0res = mask.sum(axis=1)
    has_partner = n0res > 0
    if not np.any(has_partner):
        return 0.0

    n0res = n0res[has_partner]
    l_eff = np.maximum(n0res, 27)
    d0 = np.maximum(1.0, 1.24 * (l_eff - 15) ** (1 / 3) - 1.8)

    ptm = 1.0 / (1.0 + (pae_block[has_partner] / d0[:, None]) ** 2)
    # np.where (rather than multiplying by the mask) keeps NaN entries that
    # failed the cutoffs from leaking into the sums.
    per_residue = np.where(mask[has_partner], ptm, 0.0).sum(axis=1) / n0res
    return float(per_residue.max())


def compute_best_target_ip_metrics(pae, coords, structure,
                                   target_lengths, binder_len):
    """Compute interface PAE metrics for the best binder/target chain pair.

    Assumes the residue ordering is ``[binder][target1][target2]...[targetN]``.

    For every target the function computes

    - ``ipAE``: the mean of the two off-diagonal PAE blocks between binder
      and target;
    - ``ipSAE_min``: the minimum of the two directional ipSAE (d0res) scores
      (target->binder and binder->target), using a 15.0 cutoff on both PAE
      and the representative-atom distance (CB, falling back to CA or the
      residue centroid, as returned by ``extract_cb_coords``).

    The target with the LOWEST ipAE is reported: PAE is an error estimate
    (the ipSAE term itself only trusts pairs with ``pae < 15``), so the
    smallest mean interface error identifies the most confident interface.

    Parameters
    ----------
    pae : array-like, indexed as ``[0, i, j]``
        Predicted aligned error with a leading sample axis. May be a torch
        tensor or anything ``np.asarray`` accepts.
    coords : array-like, indexed as ``[0, atom, xyz]``
        Atom coordinates, forwarded to ``extract_cb_coords``.
    structure
        Structure object, forwarded to ``extract_cb_coords``.
    target_lengths : list[int]
        Residue count of each target chain, in chain order.
    binder_len : int
        Residue count of the binder chain.

    Returns
    -------
    best_ipae : float
        ipAE of the selected target; ``nan`` if no binder/target pair could
        be scored (empty binder or no non-empty target).
    best_ipsae_min : float
        ipSAE_min of that same target; ``0.0`` if nothing could be scored.

    Raises
    ------
    ValueError
        If the number of residues found in ``structure`` differs from
        ``binder_len + sum(target_lengths)``.
    """
    pae_np = pae.detach().cpu().numpy() if hasattr(pae, "detach") else np.asarray(pae)
    pae_np = pae_np[0]

    # One representative coordinate per residue (CB preferred).
    cb = extract_cb_coords(structure, coords)
    expected = binder_len + sum(target_lengths)
    if cb.shape[0] != expected:
        # Explicit check instead of `assert`, which is stripped under -O and
        # would let mis-aligned indices produce silently wrong metrics.
        raise ValueError(
            f"Residue count mismatch: structure yields {cb.shape[0]} residues "
            f"but binder_len + sum(target_lengths) = {expected}"
        )

    if binder_len == 0 or not any(target_lengths):
        return float("nan"), 0.0

    binder = slice(0, binder_len)

    # Only binder-vs-everything distances are needed; the reverse direction
    # is the transpose because distances are symmetric.
    dist_from_binder = np.linalg.norm(
        cb[binder, None, :] - cb[None, :, :], axis=-1
    )

    best_ipae = float("inf")
    best_ipsae_min = 0.0

    offset = binder_len  # targets start right after the binder
    for length in target_lengths:
        target = slice(offset, offset + length)
        offset += length
        if length == 0:
            continue  # nothing to score for an empty chain

        pae_tb = pae_np[target, binder]  # rows: target, cols: binder
        pae_bt = pae_np[binder, target]  # rows: binder, cols: target

        # Both blocks have the same number of entries, so the mean over the
        # concatenation equals the mean of the two block means.
        ipae = float(0.5 * (pae_tb.mean() + pae_bt.mean()))

        dist_bt = dist_from_binder[:, target]
        target_to_binder = _asym_ipsae_d0res(pae_tb, dist_bt.T)
        binder_to_target = _asym_ipsae_d0res(pae_bt, dist_bt)
        ipsae_min = float(min(target_to_binder, binder_to_target))

        if ipae < best_ipae:
            best_ipae = ipae
            best_ipsae_min = ipsae_min

    if best_ipae == float("inf"):
        # Every candidate was skipped or produced NaN.
        return float("nan"), 0.0
    return best_ipae, best_ipsae_min


class InputDataBuilder:
"""Handles parsing command-line arguments and constructing the base Boltz input data dictionary."""

Expand Down Expand Up @@ -114,7 +265,6 @@ def _build_conditional_data(self):

# Assign chain IDs to proteins first
protein_chain_ids = [chr(ord('B') + i) for i in range(len(protein_seqs_list))]

# Find next available chain letter
next_chain_idx = len(protein_chain_ids)

Expand Down Expand Up @@ -310,7 +460,7 @@ def _load_boltz_model(self):
"sampling_steps": self.args.diffuse_steps,
"diffusion_samples": 1,
"write_confidence_summary": True,
"write_full_pae": False,
"write_full_pae": True,
"write_full_pde": False,
"max_parallel_samples": 1,
}
Expand Down Expand Up @@ -343,6 +493,8 @@ def _run_design_cycle(self, data_cp, run_id, pocket_conditioning):
best_pdb_filename = None
best_cycle_idx = -1
best_alanine_percentage = None
best_ipae = float("-inf")
best_ipsae_min = float("-inf")
run_metrics = {"run_id": run_id}


Expand Down Expand Up @@ -437,9 +589,27 @@ def update_binder_sequence(new_seq):

clean_memory() # <-- ADD THIS CALL HERE
# Capture Cycle 0 metrics

# -------- Determine target_lengths and binder_len --------
protein_seqs = []
binder_seq = None

for entry in data_cp["sequences"]:
if "protein" in entry:
cid = entry["protein"]["id"][0]
seq = entry["protein"]["sequence"]
if cid == self.binder_chain:
binder_len = len(seq)
else:
protein_seqs.append(seq)

target_lengths = [len(s) for s in protein_seqs]
# ---------------------------------------------------------


binder_chain_idx = CHAIN_TO_NUMBER[self.binder_chain]
pair_chains = output["pair_chains_iptm"]

# Calculate i-pTM
if len(pair_chains) > 1:
values = [
Expand All @@ -454,6 +624,15 @@ def update_binder_sequence(new_seq):
cycle_0_iptm = float(np.mean(values) if values else 0.0)
else:
cycle_0_iptm = 0.0

cycle_0_ipae, cycle_0_ipsae_min = compute_best_target_ip_metrics(
output["pae"],
output["coords"],
structure,
target_lengths,
binder_len,
)


run_metrics["cycle_0_iptm"] = cycle_0_iptm
run_metrics["cycle_0_plddt"] = float(
Expand All @@ -464,6 +643,9 @@ def update_binder_sequence(new_seq):
)
run_metrics["cycle_0_alanine"] = 0

run_metrics["cycle_0_ipae"] = cycle_0_ipae
run_metrics["cycle_0_ipsae_min"] = cycle_0_ipsae_min

# --- Optimization Cycles ---
for cycle in range(a.num_cycles):
print(f"\n--- Run {run_id}, Cycle {cycle + 1} ---")
Expand Down Expand Up @@ -531,6 +713,13 @@ def update_binder_sequence(new_seq):
current_iptm = float(np.mean(values) if values else 0.0)
else:
current_iptm = 0.0
curr_ipae, curr_ipsae_min = compute_best_target_ip_metrics(
output["pae"],
output["coords"],
structure,
target_lengths,
binder_len,
)

# Update best structure (only if alanine content is acceptable)
if alanine_percentage <= 0.20 and current_iptm > best_iptm:
Expand Down Expand Up @@ -570,6 +759,13 @@ def update_binder_sequence(new_seq):
run_metrics[f"cycle_{cycle + 1}_iplddt"] = curr_iplddt
run_metrics[f"cycle_{cycle + 1}_alanine"] = alanine_count
run_metrics[f"cycle_{cycle + 1}_seq"] = seq
run_metrics[f"cycle_{cycle + 1}_ipae"] = curr_ipae
run_metrics[f"cycle_{cycle + 1}_ipsae_min"] = curr_ipsae_min

if curr_ipae > best_ipae:
best_ipae = curr_ipae
if curr_ipsae_min > best_ipsae_min:
best_ipsae_min = curr_ipsae_min

pdb_filename = (
f"{run_save_dir}/{a.name}_run_{run_id}_predicted_cycle_{cycle + 1}.pdb"
Expand All @@ -579,7 +775,7 @@ def update_binder_sequence(new_seq):
clean_memory()

print(
f"ipTM: {current_iptm:.2f} pLDDT: {curr_plddt:.2f} iPLDDT: {curr_iplddt:.2f} Alanine count: {alanine_count}"
f"ipTM: {current_iptm:.2f} ipAE: {curr_ipae:.2f} ipSAE_min: {curr_ipsae_min:.2f} pLDDT: {curr_plddt:.2f} iPLDDT: {curr_iplddt:.2f} Alanine count: {alanine_count}"
)

# 4. Save YAML for High ipTM
Expand Down Expand Up @@ -648,7 +844,9 @@ def update_binder_sequence(new_seq):
"alanine_count": alanine_count,
"sequence": seq,
"pdb_filename": pdb_base_name, # Log the *new* filename
"yaml_filename": yaml_base_name
"yaml_filename": yaml_base_name,
"ipae": curr_ipae,
"ipsae_min": curr_ipsae_min,
}

# Check if file exists to write header only once
Expand Down Expand Up @@ -676,11 +874,15 @@ def update_binder_sequence(new_seq):
.numpy()[0]
)
run_metrics["best_seq"] = best_seq
run_metrics["best_ipae"] = float(best_ipae)
run_metrics["best_ipsae_min"] = float(best_ipsae_min)
else:
run_metrics["best_iptm"] = float("nan")
run_metrics["best_cycle"] = None
run_metrics["best_plddt"] = float("nan")
run_metrics["best_seq"] = None
run_metrics["best_ipae"] = float("nan")
run_metrics["best_ipsae_min"] = float("nan")

if a.plot:
plot_run_metrics(run_save_dir, a.name, run_id, a.num_cycles, run_metrics)
Expand All @@ -700,10 +902,12 @@ def _save_summary_metrics(self, all_run_metrics):
f"cycle_{i}_iplddt",
f"cycle_{i}_alanine",
f"cycle_{i}_seq",
f"cycle_{i}_ipae",
f"cycle_{i}_ipsae_min",
]
)
# Best metric columns
columns.extend(["best_iptm", "best_cycle", "best_plddt", "best_seq"])
columns.extend(["best_iptm", "best_cycle", "best_plddt", "best_seq", "best_ipae", "best_ipsae_min",])

summary_csv = os.path.join(self.save_dir, "summary_all_runs.csv")
df = pd.DataFrame(all_run_metrics)
Expand Down