Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ sample(
The [`physical_steering.yaml`](./src/bioemu/config/steering/physical_steering.yaml) configuration provides potentials for physical realism:
- **ChainBreak**: Prevents backbone discontinuities
- **ChainClash**: Avoids steric clashes between non-neighboring residues
- **DisulfideBridge**: Encourages disulfide bond formation between specified cysteine pairs

You can create a custom `steering_config.yaml` YAML file instantiating your own potential to steer the system with your own potentials.

Expand Down
111 changes: 0 additions & 111 deletions notebooks/disulfide_bridge_steering_example.py

This file was deleted.

14 changes: 0 additions & 14 deletions src/bioemu/config/steering/disulfide_bridge_steering.yaml

This file was deleted.

161 changes: 0 additions & 161 deletions src/bioemu/convert_chemgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,75 +215,6 @@ def compute_backbone(
return atom37_bb_pos, atom37_mask


def batch_frames_to_atom37(
pos: torch.Tensor, rot: torch.Tensor, seq: str
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fully batched transformation of backbone frame parameterization (pos, rot, seq) into atom37 coordinates.
All samples in the batch must have the same sequence.

Args:
pos: Tensor of shape (batch, L, 3) - backbone frame positions in nm
rot: Tensor of shape (batch, L, 3, 3) - backbone frame orientations
seq: String of length L - amino acid sequence (same for all samples in batch)

Returns:
atom37: Tensor of shape (batch, L, 37, 3) - atom coordinates in Angstroms
atom37_mask: Tensor of shape (batch, L, 37) - atom masks
aatype: Tensor of shape (batch, L) - residue types (same across batch)

Example to denoise all backbone atoms from a batch of structures:
x0_t, R0_t = get_pos0_rot0(
sdes=sdes, batch=batch, t=t, score=score
) # batch -> x0_t:(batch_size, seq_length, 3), R0_t:(batch_size, seq_length, 3, 3)

# Reconstruct heavy backbone atom positions, nm to Angstrom conversion
atom37, _, _ = batch_frames_to_atom37(pos=10 * x0_t, rot=R0_t, seq=batch.sequence[0])
N_pos, Ca_pos, C_pos, O_pos = (
atom37[..., 0, :],
atom37[..., 1, :],
atom37[..., 2, :],
atom37[..., 4, :],
) # [BS, L, 4, 3] -> [BS, L, 3] for N,Ca,C,O
"""
batch_size, L, _ = pos.shape
assert rot.shape == (
batch_size,
L,
3,
3,
), f"Expected rot shape {(batch_size, L, 3, 3)}, got {rot.shape}"
assert (
isinstance(seq, str) and len(seq) == L
), f"Sequence must be a string of length {L}, got {type(seq)} of length {len(seq) if isinstance(seq, str) else 'N/A'}"
device = pos.device

# Convert sequence to aatype tensor (L,) then broadcast to (batch, L)
aatype_single = torch.tensor(
[residue_constants.restype_order.get(x, 0) for x in seq], device=device
)
aatype = aatype_single.unsqueeze(0).expand(batch_size, -1) # (batch, L)

# Create Rigid objects - these support arbitrary batch dimensions
rots = Rotation(rot_mats=rot) # (batch, L, 3, 3)
rigids = Rigid(rots=rots, trans=pos) # (batch, L, 3), (batch, L, 3, 3)

# Compute backbone atoms - this already supports batching
psi_torsions = torch.zeros(batch_size, L, 2, device=device)
atom_37, atom_37_mask = compute_backbone(
bb_rigids=rigids,
psi_torsions=psi_torsions,
aatype=aatype,
)

# atom_37 is now (batch, L, 37, 3), atom_37_mask is (batch, L, 37)

# Adjust oxygen positions using batched version
atom_37 = _batch_adjust_oxygen_pos(atom_37, pos_is_known=None)

return atom_37, atom_37_mask, aatype


def _adjust_oxygen_pos(
atom_37: torch.Tensor, pos_is_known: torch.Tensor | None = None
) -> torch.Tensor:
Expand Down Expand Up @@ -366,98 +297,6 @@ def _adjust_oxygen_pos(
return atom_37


def _batch_adjust_oxygen_pos(
atom_37: torch.Tensor, pos_is_known: torch.Tensor | None = None
) -> torch.Tensor:
"""
Batched version of _adjust_oxygen_pos that handles multiple structures simultaneously.

Imputes the position of the oxygen atom on the backbone by using adjacent frame information.
Specifically, we say that the oxygen atom is in the plane created by the Calpha and C from the
current frame and the nitrogen of the next frame. The oxygen is then placed c_o_bond_length Angstrom
away from the C in the current frame in the direction away from the Ca-C-N triangle.

For cases where the next frame is not available, for example we are at the C-terminus or the
next frame is not available in the data then we place the oxygen in the same plane as the
N-Ca-C of the current frame and pointing in the same direction as the average of the
Ca->C and Ca->N vectors.

Args:
atom_37 (torch.Tensor): (B, N, 37, 3) tensor of positions of the backbone atoms in atom_37 ordering
which is ['N', 'CA', 'C', 'CB', 'O', ...]. In Angstroms.
pos_is_known (torch.Tensor): (B, N) mask for known residues, or None.

Returns:
atom_37 (torch.Tensor): (B, N, 37, 3) with adjusted oxygen positions.
"""
B, N = atom_37.shape[0], atom_37.shape[1]
assert atom_37.shape == (B, N, 37, 3)

# Get vectors to Carbonyl from Carbon alpha and N of next residue. (B, N-1, 3)
# Note that the (N,) ordering is from N-terminal to C-terminal.

# Calpha to carbonyl both in the current frame. (B, N-1, 3)
calpha_to_carbonyl = (atom_37[:, :-1, 2, :] - atom_37[:, :-1, 1, :]) / (
torch.norm(atom_37[:, :-1, 2, :] - atom_37[:, :-1, 1, :], keepdim=True, dim=2) + 1e-7
)
# For masked positions, they are all 0 and so we add 1e-7 to avoid division by 0.
# The positions are in Angstroms and so are on the order ~1 so 1e-7 is an insignificant change.

# Nitrogen of the next frame to carbonyl of the current frame. (B, N-1, 3)
nitrogen_to_carbonyl = (atom_37[:, :-1, 2, :] - atom_37[:, 1:, 0, :]) / (
torch.norm(atom_37[:, :-1, 2, :] - atom_37[:, 1:, 0, :], keepdim=True, dim=2) + 1e-7
)

carbonyl_to_oxygen = calpha_to_carbonyl + nitrogen_to_carbonyl # (B, N-1, 3)
carbonyl_to_oxygen = carbonyl_to_oxygen / (
torch.norm(carbonyl_to_oxygen, dim=2, keepdim=True) + 1e-7
)

atom_37[:, :-1, 4, :] = atom_37[:, :-1, 2, :] + carbonyl_to_oxygen * C_O_BOND_LENGTH

# Now we deal with frames for which there is no next frame available.

# Calpha to carbonyl both in the current frame. (B, N, 3)
calpha_to_carbonyl_term = (atom_37[:, :, 2, :] - atom_37[:, :, 1, :]) / (
torch.norm(atom_37[:, :, 2, :] - atom_37[:, :, 1, :], keepdim=True, dim=2) + 1e-7
)
# Calpha to nitrogen both in the current frame. (B, N, 3)
calpha_to_nitrogen_term = (atom_37[:, :, 0, :] - atom_37[:, :, 1, :]) / (
torch.norm(atom_37[:, :, 0, :] - atom_37[:, :, 1, :], keepdim=True, dim=2) + 1e-7
)
carbonyl_to_oxygen_term = calpha_to_carbonyl_term + calpha_to_nitrogen_term # (B, N, 3)
carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
torch.norm(carbonyl_to_oxygen_term, dim=2, keepdim=True) + 1e-7
)

# Create a mask that is 1 when the next residue is not available either
# due to this frame being the C-terminus or the next residue is not
# known due to pos_is_known being false.

if pos_is_known is None:
pos_is_known = torch.ones((B, N), dtype=torch.int64, device=atom_37.device)

next_res_gone = ~pos_is_known.bool() # (B, N)
next_res_gone = torch.cat(
[next_res_gone, torch.ones((B, 1), device=pos_is_known.device).bool()], dim=1
) # (B, N+1)
next_res_gone = next_res_gone[:, 1:] # (B, N)

# Use masking to apply the terminal oxygen calculation
# next_res_gone shape: (B, N), we need to expand for broadcasting
next_res_gone_expanded = next_res_gone.unsqueeze(-1) # (B, N, 1)

# Apply the terminal calculation where needed
terminal_oxygen_pos = (
atom_37[:, :, 2, :] + carbonyl_to_oxygen_term * C_O_BOND_LENGTH
) # (B, N, 3)
atom_37[:, :, 4, :] = torch.where(
next_res_gone_expanded, terminal_oxygen_pos, atom_37[:, :, 4, :]
)

return atom_37


def _get_frames_non_clash_kdtree(
traj: mdtraj.Trajectory, clash_distance_angstrom: float
) -> np.ndarray:
Expand Down
Loading
Loading