Skip to content

Commit ae44627

Browse files
goldsboroughfacebook-github-bot
authored andcommitted
Rm test_jit.cpp (pytorch#12988)
Summary: Removes test_jit.cpp, which was supposed to have been deleted in pytorch#12030 I had to move zou3519's dynamic DAG tests into `test/cpp/jit/tests.h` too. No other changes to `test_jit.cpp` seem to have happened in the meantime. zdevito Pull Request resolved: pytorch#12988 Differential Revision: D10854320 Pulled By: goldsborough fbshipit-source-id: 7ab533e6e494e34a16ce39bbe62b1150e48fcb58
1 parent 314d95a commit ae44627

File tree

5 files changed

+196
-1346
lines changed

5 files changed

+196
-1346
lines changed

test/cpp/jit/gtest.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ JIT_TEST(CodeTemplate)
1717
JIT_TEST(ControlFlow)
1818
JIT_TEST(CreateAutodiffSubgraphs)
1919
JIT_TEST(CustomOperators)
20-
JIT_TEST(SchemaParser)
2120
JIT_TEST(Differentiate)
2221
JIT_TEST(DifferentiateWithRequiresGrad)
22+
JIT_TEST(DynamicDAG)
2323
JIT_TEST(FromQualString)
2424
JIT_TEST(InternedStrings)
2525
JIT_TEST(IValue)
26+
JIT_TEST(SchemaParser)
2627
JIT_TEST(TopologicalIndex)
2728

2829
#define JIT_TEST_CUDA(name) \

test/cpp/jit/no-gtest.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@ std::string runJITCPPTests() {
1515
testControlFlow();
1616
testCreateAutodiffSubgraphs(out);
1717
testCustomOperators();
18-
testSchemaParser();
1918
testDifferentiate(out);
2019
testDifferentiateWithRequiresGrad(out);
20+
testDynamicDAG();
2121
testFromQualString();
2222
testFusion();
2323
testGraphExecutor();
2424
testInternedStrings();
2525
testInterp();
2626
testIValue();
2727
testProto();
28+
testSchemaParser();
2829
return out.str();
2930
}
3031
} // namespace jit

test/cpp/jit/tests.h

+192
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "torch/csrc/jit/autodiff.h"
3535
#include "torch/csrc/jit/code_template.h"
3636
#include "torch/csrc/jit/custom_operator.h"
37+
#include "torch/csrc/jit/dynamic_dag.h"
3738
#include "torch/csrc/jit/fusers/interface.h"
3839
#include "torch/csrc/jit/interned_strings.h"
3940
#include "torch/csrc/jit/interpreter.h"
@@ -62,6 +63,8 @@
6263

6364
#include <ATen/ATen.h>
6465

66+
#include <c10/util/Exception.h>
67+
6568
#include <algorithm>
6669
#include <cstddef>
6770
#include <functional>
@@ -1220,6 +1223,195 @@ void testTopologicalIndex() {
12201223
}
12211224
}
12221225

1226+
1227+
std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() {
1228+
return std::unique_ptr<detail::DynamicDAG<std::string>>(new detail::DynamicDAG<std::string>());
1229+
}
1230+
1231+
void testNewVertex() {
1232+
auto graph = newDynamicDAG();
1233+
JIT_ASSERT(graph->debugNumVertices() == 0);
1234+
auto a = graph->newVertex("a");
1235+
JIT_ASSERT(graph->debugNumVertices() == 1);
1236+
JIT_ASSERT(a->ord == 0);
1237+
JIT_ASSERT(a->data.size() == 1);
1238+
JIT_ASSERT(a->data[0] == "a");
1239+
JIT_ASSERT(a->in_edges().size() == 0);
1240+
JIT_ASSERT(a->out_edges().size() == 0);
1241+
auto b = graph->newVertex("b");
1242+
auto c = graph->newVertex("c");
1243+
JIT_ASSERT(graph->debugNumVertices() == 3);
1244+
JIT_ASSERT(b->ord == 1);
1245+
JIT_ASSERT(c->ord == 2);
1246+
}
1247+
1248+
void testAddEdgeBasic() {
1249+
// a -> b -> c
1250+
// \---------^
1251+
auto graph = newDynamicDAG();
1252+
auto a = graph->newVertex("a");
1253+
auto b = graph->newVertex("b");
1254+
auto c = graph->newVertex("c");
1255+
graph->addEdge(a, b);
1256+
graph->addEdge(b, c);
1257+
graph->addEdge(a, c);
1258+
JIT_ASSERT(a->in_edges().size() == 0);
1259+
JIT_ASSERT(a->out_edges().size() == 2);
1260+
JIT_ASSERT(a->out_edges().contains(b));
1261+
JIT_ASSERT(a->out_edges().contains(c));
1262+
JIT_ASSERT(b->in_edges().size() == 1);
1263+
JIT_ASSERT(b->out_edges().size() == 1);
1264+
JIT_ASSERT(b->in_edges().contains(a));
1265+
JIT_ASSERT(b->out_edges().contains(c));
1266+
JIT_ASSERT(c->in_edges().size() == 2);
1267+
JIT_ASSERT(c->out_edges().size() == 0);
1268+
JIT_ASSERT(c->in_edges().contains(a));
1269+
JIT_ASSERT(c->in_edges().contains(b));
1270+
}
1271+
1272+
void testAddEdgeCycleDetection() {
1273+
// a -> b -> c
1274+
// ^---------/
1275+
auto graph = newDynamicDAG();
1276+
auto a = graph->newVertex("a");
1277+
auto b = graph->newVertex("b");
1278+
auto c = graph->newVertex("c");
1279+
graph->addEdge(a, b);
1280+
graph->addEdge(b, c);
1281+
bool erred = false;
1282+
try {
1283+
graph->addEdge(c, a);
1284+
} catch (c10::Error& err) {
1285+
erred = true;
1286+
}
1287+
JIT_ASSERT(erred);
1288+
}
1289+
1290+
void testAddEdgeReordersBasic() {
1291+
// a, b => b -> a
1292+
auto graph = newDynamicDAG();
1293+
auto a = graph->newVertex("a");
1294+
auto b = graph->newVertex("b");
1295+
JIT_ASSERT(a->ord == 0);
1296+
JIT_ASSERT(b->ord == 1);
1297+
graph->addEdge(b, a);
1298+
JIT_ASSERT(a->ord == 1);
1299+
JIT_ASSERT(b->ord == 0);
1300+
}
1301+
1302+
void testAddEdgeReordersComplicated() {
1303+
// a -> b c -> d with addEdge(d, b) ==>
1304+
// c -> d -> a -> b
1305+
auto graph = newDynamicDAG();
1306+
auto a = graph->newVertex("a");
1307+
auto b = graph->newVertex("b");
1308+
auto c = graph->newVertex("c");
1309+
auto d = graph->newVertex("d");
1310+
graph->addEdge(a, b);
1311+
graph->addEdge(c, d);
1312+
JIT_ASSERT(a->ord == 0);
1313+
JIT_ASSERT(b->ord == 1);
1314+
JIT_ASSERT(c->ord == 2);
1315+
JIT_ASSERT(d->ord == 3);
1316+
graph->addEdge(d, a);
1317+
JIT_ASSERT(c->ord == 0);
1318+
JIT_ASSERT(d->ord == 1);
1319+
JIT_ASSERT(a->ord == 2);
1320+
JIT_ASSERT(b->ord == 3);
1321+
JIT_ASSERT(c->in_edges().size() == 0);
1322+
JIT_ASSERT(c->out_edges().size() == 1);
1323+
JIT_ASSERT(c->out_edges().contains(d));
1324+
JIT_ASSERT(d->in_edges().size() == 1);
1325+
JIT_ASSERT(d->out_edges().size() == 1);
1326+
JIT_ASSERT(d->in_edges().contains(c));
1327+
JIT_ASSERT(d->out_edges().contains(a));
1328+
JIT_ASSERT(a->in_edges().size() == 1);
1329+
JIT_ASSERT(a->out_edges().size() == 1);
1330+
JIT_ASSERT(a->in_edges().contains(d));
1331+
JIT_ASSERT(a->out_edges().contains(b));
1332+
JIT_ASSERT(b->in_edges().size() == 1);
1333+
JIT_ASSERT(b->out_edges().size() == 0);
1334+
JIT_ASSERT(b->in_edges().contains(a));
1335+
}
1336+
1337+
void testRemoveEdgeBasic() {
1338+
// a -> b
1339+
auto graph = newDynamicDAG();
1340+
auto a = graph->newVertex("a");
1341+
auto b = graph->newVertex("b");
1342+
graph->addEdge(a, b);
1343+
JIT_ASSERT(graph->debugNumVertices() == 2);
1344+
graph->removeEdge(a, b);
1345+
JIT_ASSERT(graph->debugNumVertices() == 2);
1346+
JIT_ASSERT(a->out_edges().size() == 0);
1347+
JIT_ASSERT(b->in_edges().size() == 0);
1348+
}
1349+
1350+
void testRemoveVertexBasic() {
1351+
// a -> b
1352+
auto graph = newDynamicDAG();
1353+
auto a = graph->newVertex("a");
1354+
auto b = graph->newVertex("b");
1355+
auto c = graph->newVertex("c");
1356+
graph->addEdge(a, b);
1357+
graph->addEdge(b, c);
1358+
JIT_ASSERT(graph->debugNumVertices() == 3);
1359+
graph->removeVertex(b);
1360+
JIT_ASSERT(graph->debugNumVertices() == 2);
1361+
JIT_ASSERT(a->out_edges().size() == 0);
1362+
JIT_ASSERT(c->in_edges().size() == 0);
1363+
}
1364+
1365+
void testContractEdgeBasic() {
1366+
// a -> b -> c -> d
1367+
auto graph = newDynamicDAG();
1368+
auto a = graph->newVertex("a");
1369+
auto b = graph->newVertex("b");
1370+
auto c = graph->newVertex("c");
1371+
auto d = graph->newVertex("d");
1372+
graph->addEdge(a, b);
1373+
graph->addEdge(b, c);
1374+
graph->addEdge(c, d);
1375+
graph->contractEdge(b, c);
1376+
JIT_ASSERT(graph->debugNumVertices() == 3);
1377+
JIT_ASSERT(a->out_edges().size() == 1);
1378+
JIT_ASSERT(d->in_edges().size() == 1);
1379+
JIT_ASSERT(*a->out_edges().begin() == *d->in_edges().begin());
1380+
auto* contracted = *a->out_edges().begin();
1381+
JIT_ASSERT(contracted->data.size() == 2);
1382+
JIT_ASSERT(contracted->data[0] == "b");
1383+
JIT_ASSERT(contracted->data[1] == "c");
1384+
JIT_ASSERT(contracted->out_edges().size() == 1);
1385+
JIT_ASSERT(contracted->in_edges().size() == 1);
1386+
JIT_ASSERT(contracted->in_edges().contains(a));
1387+
JIT_ASSERT(contracted->out_edges().contains(d));
1388+
}
1389+
1390+
void testContractEdgeCycleDetection() {
1391+
// a -> b -> c
1392+
// `---------^
1393+
// contractEdge(a, c) will cause a cycle
1394+
auto graph = newDynamicDAG();
1395+
auto a = graph->newVertex("a");
1396+
auto b = graph->newVertex("b");
1397+
auto c = graph->newVertex("c");
1398+
graph->addEdge(a, b);
1399+
graph->addEdge(b, c);
1400+
graph->addEdge(a, c);
1401+
JIT_ASSERT(!graph->contractEdge(a, c));
1402+
}
1403+
1404+
void testDynamicDAG() {
1405+
testNewVertex();
1406+
testAddEdgeBasic();
1407+
testAddEdgeCycleDetection();
1408+
testAddEdgeReordersBasic();
1409+
testAddEdgeReordersComplicated();
1410+
testRemoveEdgeBasic();
1411+
testRemoveVertexBasic();
1412+
testContractEdgeBasic();
1413+
testContractEdgeCycleDetection();
1414+
}
12231415
} // namespace
12241416
} // namespace jit
12251417
} // namespace torch

tools/run-clang-tidy-in-ci.sh

-1
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,4 @@ time python tools/clang_tidy.py \
4545
-g"-torch/csrc/jit/init.cpp" \
4646
-g"-torch/csrc/jit/export.cpp" \
4747
-g"-torch/csrc/jit/import.cpp" \
48-
-g"-torch/csrc/jit/test_jit.cpp" \
4948
"$@"

0 commit comments

Comments
 (0)