
Commit a221b2a

Make scalar bias term optional in _batchnorm.py (e3nn#416)
1 parent bb71b3c commit a221b2a

File tree

6 files changed: +17 -16 lines

ChangeLog.md

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Added
+- Optional scalar bias term in `_batchnorm.py`
 
 ## [0.5.1] - 2022-12-12
 ### Added

e3nn/nn/_batchnorm.py

Lines changed: 15 additions & 4 deletions
@@ -32,6 +32,12 @@ class BatchNorm(nn.Module):
 
     instance : bool
         apply instance norm instead of batch norm
+
+    include_bias : bool
+        include a bias term for batch norm of scalars
+
+    normalization : str
+        which normalization method to apply (i.e., `norm` or `component`)
     """
 
     __constants__ = ["instance", "normalization", "irs", "affine"]
@@ -44,6 +50,7 @@ def __init__(
         affine: bool = True,
         reduce: str = "mean",
         instance: bool = False,
+        include_bias: bool = True,
         normalization: str = "component",
     ) -> None:
         super().__init__()
@@ -53,6 +60,7 @@ def __init__(
         self.momentum = momentum
         self.affine = affine
         self.instance = instance
+        self.include_bias = include_bias
 
         num_scalar = sum(mul for mul, ir in self.irreps if ir.is_scalar())
         num_features = self.irreps.num_irreps
@@ -67,10 +75,12 @@ def __init__(
 
         if affine:
             self.weight = nn.Parameter(torch.ones(num_features))
-            self.bias = nn.Parameter(torch.zeros(num_scalar))
+            if self.include_bias:
+                self.bias = nn.Parameter(torch.zeros(num_scalar))
         else:
             self.register_parameter("weight", None)
-            self.register_parameter("bias", None)
+            if self.include_bias:
+                self.register_parameter("bias", None)
 
         assert isinstance(reduce, str), "reduce should be passed as a string value"
         assert reduce in ["mean", "max"], "reduce needs to be 'mean' or 'max'"
@@ -171,7 +181,7 @@ def forward(self, input) -> torch.Tensor:
 
             field = field * field_norm.reshape(-1, 1, mul, 1)  # [batch, sample, mul, repr]
 
-            if self.affine and is_scalar:
+            if self.affine and self.include_bias and is_scalar:
                 bias = self.bias[ib : ib + mul]  # [mul]
                 ib += mul
                 field += bias.reshape(mul, 1)  # [batch, sample, mul, repr]
@@ -185,7 +195,8 @@ def forward(self, input) -> torch.Tensor:
             torch._assert(irv == self.running_var.size(0), "irv == self.running_var.size(0)")
         if self.affine:
            torch._assert(iw == self.weight.size(0), "iw == self.weight.size(0)")
-            torch._assert(ib == self.bias.numel(), "ib == self.bias.numel()")
+            if self.include_bias:
+                torch._assert(ib == self.bias.numel(), "ib == self.bias.numel()")
 
         if self.training and not self.instance:
             if len(new_means) > 0:
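A minimal usage sketch of the new flag (the irreps string and tensor shapes below are hypothetical; it assumes `BatchNorm` is imported from `e3nn.nn` as in this file):

    import torch
    from e3nn import o3
    from e3nn.nn import BatchNorm

    # Hypothetical irreps: 4 scalar (0e) channels and 2 vector (1o) channels.
    irreps = o3.Irreps("4x0e + 2x1o")

    # With include_bias=False, no bias parameter is created, so the scalar
    # channels are normalized and (if affine=True) rescaled, but never shifted.
    bn = BatchNorm(irreps, include_bias=False)

    x = torch.randn(10, irreps.dim)  # [batch, irreps.dim] = [10, 4 + 2*3]
    out = bn(x)
    print(out.shape)  # torch.Size([10, 10])

The default `include_bias=True` keeps the previous behavior, so existing code is unaffected.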

e3nn/o3/_tensor_product/_sub.py

Lines changed: 0 additions & 7 deletions
@@ -102,7 +102,6 @@ class ElementwiseTensorProduct(TensorProduct):
     """
 
     def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, irrep_normalization: str = None, **kwargs) -> None:
-
         irreps_in1 = o3.Irreps(irreps_in1).simplify()
         irreps_in2 = o3.Irreps(irreps_in2).simplify()
         if filter_ir_out is not None:
@@ -135,7 +134,6 @@ def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, irrep_normalizati
         for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)):
             assert mul == mul_2
             for ir in ir_1 * ir_2:
-
                 if filter_ir_out is not None and ir not in filter_ir_out:
                     continue
 
@@ -179,7 +177,6 @@ def __init__(
         irrep_normalization: str = None,
         **kwargs,
     ) -> None:
-
        irreps_in1 = o3.Irreps(irreps_in1).simplify()
        irreps_in2 = o3.Irreps(irreps_in2).simplify()
        if filter_ir_out is not None:
@@ -193,7 +190,6 @@ def __init__(
        for i_1, (mul_1, ir_1) in enumerate(irreps_in1):
            for i_2, (mul_2, ir_2) in enumerate(irreps_in2):
                for ir_out in ir_1 * ir_2:
-
                    if filter_ir_out is not None and ir_out not in filter_ir_out:
                        continue
 
@@ -238,7 +234,6 @@ def _square_instructions_full(irreps_in, filter_ir_out=None, irrep_normalization
    for i_1, (mul_1, ir_1) in enumerate(irreps_in):
        for i_2, (mul_2, ir_2) in enumerate(irreps_in):
            for ir_out in ir_1 * ir_2:
-
                if filter_ir_out is not None and ir_out not in filter_ir_out:
                    continue
 
@@ -311,7 +306,6 @@ def _square_instructions_fully_connected(irreps_in, irreps_out, irrep_normalizat
        for i_2, (_mul_2, ir_2) in enumerate(irreps_in):
            for i_out, (_mul_out, ir_out) in enumerate(irreps_out):
                if ir_out in ir_1 * ir_2:
-
                    if irrep_normalization == "component":
                        alpha = ir_out.dim
                    if irrep_normalization == "norm":
@@ -374,7 +368,6 @@ def __init__(
        irrep_normalization: str = None,
        **kwargs,
    ) -> None:
-
        if irrep_normalization is None:
            irrep_normalization = "component"
 

examples/s2cnn/mnist/gendata.py

Lines changed: 0 additions & 2 deletions
@@ -199,15 +199,13 @@ def main() -> None:
    no_rotate = {"train": args.no_rotate_train, "test": args.no_rotate_test}
 
    for label, data in zip(["train", "test"], [mnist_train, mnist_test]):
-
        print(f"projecting {label} data set")
        current = 0
        signals = data["images"].reshape(-1, 28, 28).astype(np.float64)
        n_signals = signals.shape[0]
        projections = np.ndarray((signals.shape[0], 2 * args.bandwidth, 2 * args.bandwidth), dtype=np.uint8)
 
        while current < n_signals:
-
            if not no_rotate[label]:
                rot = rand_rotation_matrix(deflection=args.noise)
                rotated_grid = rotate_grid(rot, grid)

examples/s2cnn/mnist/train.py

Lines changed: 0 additions & 2 deletions
@@ -135,7 +135,6 @@ def forward(self, x):
 
 
 def load_data(path, batch_size):
-
    with gzip.open(path, "rb") as f:
        dataset = pickle.load(f)
 
@@ -192,7 +191,6 @@ def main() -> None:
    correct = 0
    total = 0
    for images, labels in test_loader:
-
        classifier.eval()
 
        with torch.no_grad():

tests/o3/cartesian_spherical_harmonics_test.py

Lines changed: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ def func(pos):
 
 @pytest.mark.parametrize("l", range(10 + 1))
 def test_normalization(float_tolerance, l) -> None:
-
    n = o3.spherical_harmonics(l, torch.randn(3), normalize=True, normalization="integral").pow(2).mean()
    assert abs(n - 1 / (4 * math.pi)) < float_tolerance
 
