forked from mc2-project/muse
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathend_to_end.py
More file actions
135 lines (107 loc) · 4.85 KB
/
Copy pathend_to_end.py
File metadata and controls
135 lines (107 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import numpy as np
import sys
def evaluate_linear_layer(layer, state):
    """Apply a bias-free linear layer to an activation vector.

    Args:
        layer: Weight matrix of the layer.
        state: Current activations fed into the layer.

    Returns:
        The product ``np.dot(layer, state)``.
    """
    product = np.dot(layer, state)
    return product
def evaluate_relu(state):
    """Apply ReLU element-wise: entries that are not strictly positive become 0.

    Args:
        state: Array of activations.

    Returns:
        ``state`` multiplied element-wise by the boolean mask ``state > 0``.
    """
    is_positive = state > 0
    return is_positive * state
def evaluate_malleated_relu(state, shift=20):
    """Apply a ReLU whose input and output are both offset by ``shift``.

    The input is raised by ``shift``, rectified, and lowered by ``shift``
    again. Entries greater than ``-shift`` therefore pass through unchanged,
    and all other entries come out as ``-shift``.

    Args:
        state: Array of activations.
        shift: Offset added before and subtracted after the rectification.

    Returns:
        The shifted-and-rectified activations.
    """
    raised = state + shift
    rectified = (raised > 0) * raised
    return rectified - shift
def evaluate_masked_relu(state, mask_start, mask_stop, shift=20):
    """Apply a ReLU with a per-entry offset that singles out a window.

    Entries in ``state[mask_start:mask_stop]`` are raised by ``shift`` before
    the rectification and lowered by ``shift`` afterwards. All other entries
    are lowered by ``shift`` before the rectification and are not compensated
    afterwards. As a result, window entries greater than ``-shift`` pass
    through unchanged, and outside entries that do not exceed ``shift`` become
    zero.

    Args:
        state: Array of activations; the slice is taken along its first axis.
        mask_start: First index of the window.
        mask_stop: Index one past the end of the window.
        shift: Magnitude of the offset.

    Returns:
        The masked-and-rectified activations.
    """
    offsets = np.full(state.shape, -shift)
    offsets[mask_start:mask_stop] = shift

    shifted = state + offsets
    rectified = (shifted > 0) * shifted

    # Only the positive offsets (the window) are undone after the ReLU.
    positive_offsets = (offsets > 0) * offsets
    return rectified - positive_offsets
def evaluate_network_upto(linear_layers, state, stop_at):
    """Evaluate the first ``stop_at`` layers of the network on ``state``.

    Each evaluated layer is a linear map followed by a ReLU, except that no
    ReLU is applied after the final layer of the whole network.

    Args:
        linear_layers: Sequence of weight matrices, input side first.
        state: Activations fed into the first layer.
        stop_at: Number of leading layers to evaluate
            (``linear_layers[:stop_at]``).

    Returns:
        The activations after the last evaluated layer. If no layer is
        evaluated, ``state`` is returned unchanged.
    """
    # Index of the network's final layer. The comparison has to use the number
    # of layers; using the row count of the current matrix would skip or apply
    # the ReLU on the wrong layers.
    last_index = len(linear_layers) - 1
    for i, layer in enumerate(linear_layers[:stop_at]):
        state = evaluate_linear_layer(layer, state)
        # only perform ReLU if we're not the last layer
        if i != last_index:
            state = evaluate_relu(state)
    return state
def evaluate_network_after_malleation(linear_layers, state, start_at,
                                      shift=20):
    """Evaluate the layers from ``start_at`` onwards using malleated ReLUs.

    Each evaluated layer is a linear map followed by a malleated (shifted)
    ReLU, except that no activation is applied after the final layer of the
    network.

    Args:
        linear_layers: Sequence of weight matrices, input side first.
        state: Activations fed into layer ``start_at``.
        start_at: Index of the first layer to evaluate
            (``linear_layers[start_at:]``).
        shift: Offset passed to the malleated ReLU.

    Returns:
        The activations after the last layer. If the slice is empty,
        ``state`` is returned unchanged.
    """
    remaining_layers = linear_layers[start_at:]
    # The loop index is relative to the slice, so the network's final layer is
    # the last element of the slice. Comparing against the row count of the
    # current matrix would skip or apply the activation on the wrong layers.
    last_index = len(remaining_layers) - 1
    for i, layer in enumerate(remaining_layers):
        state = evaluate_linear_layer(layer, state)
        # only perform ReLU if we're not the last layer
        if i != last_index:
            state = evaluate_malleated_relu(state, shift)
    return state
def unit_vector(dim, i):
    """Build an all-zero array of shape ``dim`` with entry ``i`` set to 1.

    Args:
        dim: Shape passed to ``np.zeros`` (an int or a tuple).
        i: Index assigned the value 1.0. For a multi-dimensional shape this
            indexes the first axis, so the whole sub-array at ``i`` is set.

    Returns:
        The resulting float array.
    """
    basis = np.zeros(dim)
    basis[i] = 1.0
    return basis
def extract_network(linear_layers):
    """Recover every weight matrix of the network, starting from the last layer.

    The final layer is read off one column at a time by feeding unit vectors
    into it. Each earlier layer is recovered ``num_classes`` rows at a time:
    a masked ReLU keeps only a window of rows, the rest of the network is
    evaluated with malleated ReLUs, and the resulting linear system is solved
    for the window's entries.

    NOTE: the product of the downstream layers is computed from the true
    ``linear_layers`` rather than from the matrices recovered so far.

    Args:
        linear_layers: Sequence of 2-D weight matrices, input side first.

    Returns:
        A list of recovered weight matrices in the same order as
        ``linear_layers``. The number of simulated queries is printed as a
        side effect.
    """
    input_dim = linear_layers[0].shape[1]
    num_classes = linear_layers[-1].shape[0]
    initial_state = np.zeros((input_dim, 1))
    extracted_layers = []
    num_queries = 0

    # Walk the layers from the output side back towards the input side.
    for i, layer in reversed(list(enumerate(linear_layers))):
        num_rows, num_cols = layer.shape
        extracted_layer = np.zeros(layer.shape)

        if not extracted_layers:
            # Nothing recovered yet, so this is the final layer (simple case).
            for col in range(num_cols):
                num_queries += 1
                last_state = evaluate_network_upto(linear_layers,
                                                   initial_state, i)
                # The input is all-zero and the layers have no bias, so
                # `last_state` is all-zero here. Setting entry `col` to 1
                # makes the layer output equal to its column `col`.
                last_state = last_state + unit_vector(last_state.shape, col)
                result = evaluate_linear_layer(layer, last_state)
                extracted_layer[:, col] = result[:, 0]
        else:
            # Intermediate layer: first form the product of all later layers.
            downstream = np.identity(linear_layers[i + 1].shape[1])
            for later_layer in linear_layers[i + 1:]:
                downstream = np.dot(later_layer, downstream)
            assert downstream.shape[0] == num_classes

            for col in range(num_cols):
                for row in range(0, num_rows, num_classes):
                    num_queries += 1
                    state = evaluate_network_upto(linear_layers,
                                                  initial_state, i)
                    # `state` is all-zero here; select column `col` of the
                    # current layer with a unit vector.
                    state = state + unit_vector(state.shape, col)
                    state = evaluate_linear_layer(layer, state)

                    # Only `num_classes` rows can be solved for per query.
                    # The last window is shifted back so that it still spans
                    # exactly `num_classes` rows.
                    stop = min(row + num_classes, num_rows)
                    start = row
                    if row + num_classes > num_rows:
                        start = num_rows - num_classes
                    state = evaluate_masked_relu(state, start, stop)

                    # Evaluate the rest of the network, then invert the
                    # matching columns of the downstream product.
                    result = evaluate_network_after_malleation(
                        linear_layers, state, i + 1)
                    window_matrix = downstream[:, start:stop]
                    solved = np.linalg.solve(window_matrix, result)
                    extracted_layer[start:stop, col] = solved.reshape(
                        (num_classes,))

        extracted_layers.append(extracted_layer)

    # Layers were collected last-to-first; restore input-to-output order.
    extracted_layers.reverse()
    print(num_queries)
    return extracted_layers
if __name__ == '__main__':
    # Self-check: build a random network from a size spec such as "10-8-4"
    # (input width first), run the extraction, and verify that every layer is
    # recovered.
    if len(sys.argv) < 2:
        sys.exit("usage: {} <size>-<size>[-<size>...]".format(sys.argv[0]))
    sizes = [int(size) for size in sys.argv[1].split("-")]
    if len(sizes) < 2:
        sys.exit("at least two layer sizes are required, e.g. 10-8-4")

    # Layer k maps sizes[k] inputs to sizes[k + 1] outputs.
    layers = [np.random.rand(rows, cols)
              for (rows, cols) in zip(sizes[1:], sizes)]
    extracted_layers = extract_network(layers)

    # Raise explicitly instead of using `assert`, so the check still runs
    # when Python is started with -O.
    for index, (layer, extracted_layer) in enumerate(
            zip(layers, extracted_layers)):
        if not np.allclose(layer, extracted_layer):
            raise AssertionError(
                "extracted layer {} does not match the original".format(index))