Commit eae0f47

Added MLP; FIX: Respect bias argument
1 parent 81dedfb · commit eae0f47

File tree

deepxml/libs/parameters.py
deepxml/models/mlp.py
deepxml/models/network.py
deepxml/models/transform_layer.py

4 files changed: +92 −8 lines changed

deepxml/libs/parameters.py

Lines changed: 6 additions & 0 deletions
@@ -329,6 +329,12 @@ def _construct(self):
             '--validate',
             action='store_true',
             help='Validate or just train')
+        self.parser.add_argument(
+            '--bias',
+            action='store',
+            default=True,
+            type=bool,
+            help='Use bias term or not!')
         self.parser.add_argument(
             '--shuffle',
             action='store',
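One caveat with this hunk, worth knowing when passing the flag: argparse's type=bool does not parse boolean strings; it calls bool() on the raw value, so --bias False still yields True. A minimal sketch of an explicit converter (str2bool is a hypothetical helper, not part of this commit):

import argparse

def str2bool(v):
    # argparse's type=bool applies bool() to the raw string, so any
    # non-empty value (including "False") becomes True. This helper
    # parses common spellings explicitly instead.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument('--bias', action='store', default=True, type=str2bool,
                    help='Use bias term or not!')
print(parser.parse_args(['--bias', 'False']).bias)  # False (type=bool would give True)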

deepxml/models/mlp.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn


__author__ = 'KD'


class MLP(nn.Module):
    """
    A multi-layer perceptron with a configurable non-linearity
    * no non-linearity after the last layer
    * supports 2D or 3D inputs

    Parameters:
    -----------
    input_size: int
        input size of embeddings
    hidden_size: int or list of ints or str (comma separated)
        dimensionality of the layers in the MLP, e.g.,
        512: a single hidden layer with 512 neurons
        "512": a single hidden layer with 512 neurons
        "512,300": 512 -> nnl -> 300
        [512, 300]: 512 -> nnl -> 300
    nnl: str, optional, default='relu'
        which non-linearity to use
    device: str, default="cuda:0"
        keep on this device
    """
    def __init__(self, input_size, hidden_size, nnl='relu', device="cuda:0"):
        super(MLP, self).__init__()
        hidden_size = self.parse_hidden_size(hidden_size)
        assert len(hidden_size) >= 1, "Should contain at least 1 hidden layer"
        hidden_size = [input_size] + hidden_size
        self.device = torch.device(device)
        layers = []
        for i, (i_s, o_s) in enumerate(zip(hidden_size[:-1], hidden_size[1:])):
            layers.append(nn.Linear(i_s, o_s, bias=True))
            if i < len(hidden_size) - 2:  # no non-linearity after last layer
                layers.append(self._get_nnl(nnl))
        self.transform = torch.nn.Sequential(*layers)

    def parse_hidden_size(self, hidden_size):
        if isinstance(hidden_size, int):
            return [hidden_size]
        elif isinstance(hidden_size, str):
            _hidden_size = []
            for item in hidden_size.split(","):
                _hidden_size.append(int(item))
            return _hidden_size
        elif isinstance(hidden_size, list):
            return hidden_size
        else:
            raise NotImplementedError("hidden_size must be an int, str or list")

    def _get_nnl(self, nnl):
        if nnl == 'sigmoid':
            return torch.nn.Sigmoid()
        elif nnl == 'relu':
            return torch.nn.ReLU()
        elif nnl == 'gelu':
            return torch.nn.GELU()
        elif nnl == 'tanh':
            return torch.nn.Tanh()
        else:
            raise NotImplementedError(f"{nnl} not implemented!")

    def forward(self, x):
        return self.transform(x)

    def to(self):
        """Transfer to device
        """
        super().to(self.device)

    @property
    def sparse(self):
        return False
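A minimal usage sketch for the new module (assuming it is importable as models.mlp; the sizes here are illustrative):

import torch
from models.mlp import MLP

# 512, "512" and [512] are equivalent; "512,128" stacks two layers.
mlp = MLP(input_size=300, hidden_size="512,128", nnl='relu', device="cpu")
mlp.to()  # note: takes no arguments; it moves the module to self.device

x2d = torch.randn(32, 300)      # (batch, features)
x3d = torch.randn(32, 10, 300)  # (batch, sequence, features)
print(mlp(x2d).shape)  # torch.Size([32, 128])
print(mlp(x3d).shape)  # torch.Size([32, 10, 128]); nn.Linear maps the last dim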

deepxml/models/network.py

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import numpy as np
 import math
 import os
 import models.transform_layer as transform_layer

@@ -154,6 +153,7 @@ def __init__(self, params):
         trans_config_coarse = transform_config_dict['transform_coarse']
         self.representation_dims = int(
             transform_config_dict['representation_dims'])
+        self._bias = params.bias
         super(DeepXMLf, self).__init__(trans_config_coarse)
         if params.freeze_intermediate:
             print("Freezing intermediate model parameters!")

@@ -226,7 +226,7 @@ def forward(self, batch_data, bypass_coarse=False):

     def _construct_classifier(self):
         if self.num_clf_partitions > 1:  # Run the distributed version
-            _bias = [True for _ in range(self.num_clf_partitions)]
+            _bias = [self._bias for _ in range(self.num_clf_partitions)]
             _clf_devices = ["cuda:{}".format(
                 idx) for idx in range(self.num_clf_partitions)]
             return linear_layer.ParallelLinear(

@@ -239,7 +239,7 @@ def _construct_classifier(self):
         return linear_layer.Linear(
             input_size=self.representation_dims,
             output_size=self.num_labels,  # last one is padding index
-            bias=True
+            bias=self._bias
         )

     def get_token_embeddings(self):

@@ -297,6 +297,7 @@ def __init__(self, params):
         trans_config_coarse = transform_config_dict['transform_coarse']
         self.representation_dims = int(
             transform_config_dict['representation_dims'])
+        self._bias = params.bias
         super(DeepXMLs, self).__init__(trans_config_coarse)
         if params.freeze_intermediate:
             print("Freezing intermediate model parameters!")

@@ -383,7 +384,7 @@ def _construct_classifier(self):
         # last one is padding index for each partition
         _num_labels = self.num_labels + offset
         _padding_idx = [None for _ in range(self.num_clf_partitions)]
-        _bias = [True for _ in range(self.num_clf_partitions)]
+        _bias = [self._bias for _ in range(self.num_clf_partitions)]
         _clf_devices = ["cuda:{}".format(
             idx) for idx in range(self.num_clf_partitions)]
         return linear_layer.ParallelSparseLinear(

@@ -399,7 +400,7 @@ def _construct_classifier(self):
             input_size=self.representation_dims,
             output_size=self.num_labels + offset,
             padding_idx=self.label_padding_index,
-            bias=True)
+            bias=self._bias)

     def to(self):
         """Send layers to respective devices

deepxml/models/transform_layer.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,9 @@
-import sys
 import re
 import torch.nn as nn
 import models.residual_layer as residual_layer
 import models.astec as astec
 import json
-from collections import OrderedDict
+import models.mlp as mlp


 class _Identity(nn.Module):
@@ -38,7 +37,8 @@ def initialize(self, *args, **kwargs):
     'residual': residual_layer.Residual,
     'identity': Identity,
     '_identity': _Identity,
-    'astec': astec.Astec
+    'astec': astec.Astec,
+    'mlp': mlp.MLP
 }
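With the registry entry in place, a transform configuration can name an 'mlp' block and have it resolved to the class. A hypothetical sketch of the lookup (the registry variable and config keys are illustrative; only the key -> class mapping comes from the diff):

import models.mlp as mlp

# Illustrative name -> class registry, mirroring the mapping in the diff.
elements = {'mlp': mlp.MLP}

# A hypothetical config entry selecting the new block.
config = {'name': 'mlp', 'input_size': 300, 'hidden_size': '512,300'}
cls = elements[config.pop('name')]
block = cls(**config)  # MLP(input_size=300, hidden_size='512,300')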
