Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 14d1c8c

Browse files
Merge pull request #571 from lvdmaaten/patch-1
Add ReLU + masked convolution
2 parents bd7c15b + 7e56a94 commit 14d1c8c

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

python/examples/masked_conv.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) 2017-present, Facebook, Inc.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
##############################################################################
17+
import tensor_comprehensions as tc
18+
19+
import argparse
20+
import torch
21+
import torch.nn as nn
22+
import torch.nn.functional as functional
23+
24+
torch.backends.cudnn.benchmark = True
25+
26+
27+
def GetArgumentParser():
28+
parser = argparse.ArgumentParser(
29+
description='Lengths Cosine Coherence benchmark.'
30+
)
31+
parser.add_argument(
32+
'--tuner_threads', type=int, default=16,
33+
help='Number of CPU tuning threads.',
34+
)
35+
parser.add_argument(
36+
'--tuner_generations', type=int, default=25,
37+
help='Number of tuning generations.',
38+
)
39+
parser.add_argument(
40+
'--tuner_pop_size', type=int, default=100,
41+
help='Number candidates per tuning generations.',
42+
)
43+
parser.add_argument(
44+
'--tuner_number_elites', type=int, default=5,
45+
help='Number of best tuning candidates that survive each generation.',
46+
)
47+
parser.add_argument(
48+
'--tuner_devices', type=str, default='0',
49+
help='Comma separated list of tuning devices.',
50+
)
51+
parser.add_argument(
52+
'--tuner_cache_file',
53+
type=str,
54+
default='/tmp/cache_condensenet',
55+
help='File to store tuned mapping options',
56+
)
57+
return parser
58+
59+
60+
parser = GetArgumentParser()
61+
args, extra_args = parser.parse_known_args()
62+
63+
64+
###############################################################################
65+
# TC equivalent converting control-flow to data dependencies
66+
###############################################################################
67+
MASKED_CONVOLVE = '''
68+
def masked_convolve(float(B, C, H, W) Input,
69+
float(F, C, K, K) Weights,
70+
uint8(F, C) Mask) -> (Output) {
71+
Output(b, f, h, w) +=! (Mask(f, r_c) == 1) ?
72+
fmax(0.0, Input(b, r_c, h + r_k1, w + r_k2)) *
73+
Weights(f, r_c, r_k1, r_k2) :
74+
0.0
75+
}
76+
'''
77+
78+
###############################################################################
79+
# Implicit compilation and tuning behavior
80+
###############################################################################
81+
tuner_config = (
82+
tc.TunerConfig()
83+
.threads(args.tuner_threads)
84+
.generations(args.tuner_generations)
85+
.pop_size(args.tuner_pop_size)
86+
.number_elites(args.tuner_number_elites)
87+
.devices(args.tuner_devices))
88+
reinforce_list = ['']
89+
90+
91+
def generate_options(tc_str: str,
92+
entry_point: str,
93+
*inputs: torch.Tensor) -> tc.MappingOptions:
94+
global reinforce
95+
96+
# TODO: comment the line below which serves the purpose of not blowing up
97+
# CI time
98+
return tc.make_naive_options_factory()(tc_str, entry_point, *inputs)
99+
100+
if entry_point == 'make_idx':
101+
return tc.make_naive_options_factory()(tc_str, entry_point, *inputs)
102+
103+
loaded = tc.make_load_from_cache_options_factory(args.tuner_cache_file)(
104+
tc_str, entry_point, *inputs)
105+
106+
if loaded is None or entry_point in reinforce_list or '*' in reinforce_list:
107+
start = loaded if loaded is not None else 'naive'
108+
return tc.make_autotuned_options_factory(
109+
starting_options=start,
110+
tuner_config=tuner_config,
111+
cache_filename=args.tuner_cache_file,
112+
store_to_cache=True,)(tc_str, entry_point, *inputs)
113+
114+
assert loaded is not None, 'None found'
115+
116+
return loaded
117+
118+
119+
###############################################################################
120+
# Define the TC for MASKED_CONVOLVE
121+
###############################################################################
122+
TC = tc.define(MASKED_CONVOLVE, generate_options)
123+
124+
###############################################################################
125+
# Run with implicit compilation and tuning
126+
###############################################################################
127+
128+
# sizes:
129+
H, W, C, B, F, K = 56, 56, 128, 32, 32, 1
130+
131+
# Pytorch:
132+
conv = nn.Conv2d(C, F, K, 1, 0, 1, groups=1, bias=False).cuda()
133+
relu = nn.ReLU(inplace=True).cuda()
134+
input_data = torch.zeros(B, C, H, W).cuda(non_blocking=True)
135+
mask = torch.randn(F, C, K, K).gt_(0.).cuda(non_blocking=True)
136+
torch.cuda.synchronize()
137+
138+
weight = conv.weight * mask
139+
rectified_input = relu(input_data)
140+
output = functional.conv2d(rectified_input, weight, None, conv.stride,
141+
conv.padding, conv.dilation, 1)
142+
143+
# TC:
144+
InputData = input_data
145+
Weights = conv.weight
146+
Mask = mask.view(F, C).byte()
147+
torch.cuda.synchronize()
148+
Output = TC.masked_convolve(InputData, Weights, Mask)
149+
150+
151+
###############################################################################
152+
# Check
153+
###############################################################################
154+
tc.assert_almost_equal(
155+
output.cpu(),
156+
Output.cpu(),
157+
input_data.cpu(), conv.weight.cpu(), mask.cpu(),
158+
operations=C * K * K,
159+
precision=1e-7)
160+
161+
print('SUCCESS')

0 commit comments

Comments
 (0)