forked from arneschneuing/DiffSBDD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
geometry_utils.py
141 lines (113 loc) · 4.14 KB
/
geometry_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
from constants import CA_C_DIST, N_CA_DIST, N_CA_C_ANGLE
def rotation_matrix(angle, axis):
"""
Args:
angle: (n,)
axis: 0=x, 1=y, 2=z
Returns:
(n, 3, 3)
"""
n = len(angle)
R = np.eye(3)[None, :, :].repeat(n, axis=0)
axis = 2 - axis
start = axis // 2
step = axis % 2 + 1
s = slice(start, start + step + 1, step)
R[:, s, s] = np.array(
[[np.cos(angle), (-1) ** (axis + 1) * np.sin(angle)],
[(-1) ** axis * np.sin(angle), np.cos(angle)]]
).transpose(2, 0, 1)
return R
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
"""
Compute translation and rotation of the canoncical backbone frame (triangle N-Ca-C) from a position with
Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame
Args:
n_xyz: (n, 3)
ca_xyz: (n, 3)
c_xyz: (n, 3)
Returns:
quaternion represented as array of shape (n, 4)
translation vector which is an array of shape (n, 3)
"""
translation = ca_xyz
n_xyz = n_xyz - translation
c_xyz = c_xyz - translation
# Find rotation matrix that aligns the coordinate systems
# rotate around y-axis to move N into the xy-plane
theta_y = np.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
Ry = rotation_matrix(theta_y, 1)
n_xyz = np.einsum('noi,ni->no', Ry.transpose(0, 2, 1), n_xyz)
# rotate around z-axis to move N onto the x-axis
theta_z = np.arctan2(n_xyz[:, 1], n_xyz[:, 0])
Rz = rotation_matrix(theta_z, 2)
# n_xyz = np.einsum('noi,ni->no', Rz.transpose(0, 2, 1), n_xyz)
# rotate around x-axis to move C into the xy-plane
c_xyz = np.einsum('noj,nji,ni->no', Rz.transpose(0, 2, 1),
Ry.transpose(0, 2, 1), c_xyz)
theta_x = np.arctan2(c_xyz[:, 2], c_xyz[:, 1])
Rx = rotation_matrix(theta_x, 0)
# Final rotation matrix
R = np.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)
# Convert to quaternion
# q = w + i*u_x + j*u_y + k * u_z
quaternion = rotation_matrix_to_quaternion(R)
return quaternion, translation
def get_bb_coords_from_transform(ca_coords, quaternion):
"""
Args:
ca_coords: (n, 3)
quaternion: (n, 4)
Returns:
backbone coords (n*3, 3), order is [N, CA, C]
backbone atom types as a list of length n*3
"""
R = quaternion_to_rotation_matrix(quaternion)
bb_coords = np.tile(np.array(
[[N_CA_DIST, 0, 0],
[0, 0, 0],
[CA_C_DIST * np.cos(N_CA_C_ANGLE), CA_C_DIST * np.sin(N_CA_C_ANGLE), 0]]),
[len(ca_coords), 1])
bb_coords = np.einsum('noi,ni->no', R.repeat(3, axis=0), bb_coords) + ca_coords.repeat(3, axis=0)
bb_atom_types = [t for _ in range(len(ca_coords)) for t in ['N', 'C', 'C']]
return bb_coords, bb_atom_types
def quaternion_to_rotation_matrix(q):
"""
x_rot = R x
Args:
q: (n, 4)
Returns:
R: (n, 3, 3)
"""
# Normalize
q = q / (q ** 2).sum(1, keepdims=True) ** 0.5
# https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
R = np.stack([
np.stack([1 - 2 * y ** 2 - 2 * z ** 2, 2 * x * y - 2 * z * w,
2 * x * z + 2 * y * w], axis=1),
np.stack([2 * x * y + 2 * z * w, 1 - 2 * x ** 2 - 2 * z ** 2,
2 * y * z - 2 * x * w], axis=1),
np.stack([2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w,
1 - 2 * x ** 2 - 2 * y ** 2], axis=1)
], axis=1)
return R
def rotation_matrix_to_quaternion(R):
"""
https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
Args:
R: (n, 3, 3)
Returns:
q: (n, 4)
"""
t = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
r = np.sqrt(1 + t)
w = 0.5 * r
x = np.sign(R[:, 2, 1] - R[:, 1, 2]) * np.abs(
0.5 * np.sqrt(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2]))
y = np.sign(R[:, 0, 2] - R[:, 2, 0]) * np.abs(
0.5 * np.sqrt(1 - R[:, 0, 0] + R[:, 1, 1] - R[:, 2, 2]))
z = np.sign(R[:, 1, 0] - R[:, 0, 1]) * np.abs(
0.5 * np.sqrt(1 - R[:, 0, 0] - R[:, 1, 1] + R[:, 2, 2]))
return np.stack((w, x, y, z), axis=1)