|
34 | 34 | #include "torch/csrc/jit/autodiff.h"
|
35 | 35 | #include "torch/csrc/jit/code_template.h"
|
36 | 36 | #include "torch/csrc/jit/custom_operator.h"
|
| 37 | +#include "torch/csrc/jit/dynamic_dag.h" |
37 | 38 | #include "torch/csrc/jit/fusers/interface.h"
|
38 | 39 | #include "torch/csrc/jit/interned_strings.h"
|
39 | 40 | #include "torch/csrc/jit/interpreter.h"
|
|
62 | 63 |
|
63 | 64 | #include <ATen/ATen.h>
|
64 | 65 |
|
| 66 | +#include <c10/util/Exception.h> |
| 67 | + |
65 | 68 | #include <algorithm>
|
66 | 69 | #include <cstddef>
|
67 | 70 | #include <functional>
|
@@ -1220,6 +1223,195 @@ void testTopologicalIndex() {
|
1220 | 1223 | }
|
1221 | 1224 | }
|
1222 | 1225 |
|
| 1226 | + |
| 1227 | +std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() { |
| 1228 | + return std::unique_ptr<detail::DynamicDAG<std::string>>(new detail::DynamicDAG<std::string>()); |
| 1229 | +} |
| 1230 | + |
| 1231 | +void testNewVertex() { |
| 1232 | + auto graph = newDynamicDAG(); |
| 1233 | + JIT_ASSERT(graph->debugNumVertices() == 0); |
| 1234 | + auto a = graph->newVertex("a"); |
| 1235 | + JIT_ASSERT(graph->debugNumVertices() == 1); |
| 1236 | + JIT_ASSERT(a->ord == 0); |
| 1237 | + JIT_ASSERT(a->data.size() == 1); |
| 1238 | + JIT_ASSERT(a->data[0] == "a"); |
| 1239 | + JIT_ASSERT(a->in_edges().size() == 0); |
| 1240 | + JIT_ASSERT(a->out_edges().size() == 0); |
| 1241 | + auto b = graph->newVertex("b"); |
| 1242 | + auto c = graph->newVertex("c"); |
| 1243 | + JIT_ASSERT(graph->debugNumVertices() == 3); |
| 1244 | + JIT_ASSERT(b->ord == 1); |
| 1245 | + JIT_ASSERT(c->ord == 2); |
| 1246 | +} |
| 1247 | + |
| 1248 | +void testAddEdgeBasic() { |
| 1249 | + // a -> b -> c |
| 1250 | + // \---------^ |
| 1251 | + auto graph = newDynamicDAG(); |
| 1252 | + auto a = graph->newVertex("a"); |
| 1253 | + auto b = graph->newVertex("b"); |
| 1254 | + auto c = graph->newVertex("c"); |
| 1255 | + graph->addEdge(a, b); |
| 1256 | + graph->addEdge(b, c); |
| 1257 | + graph->addEdge(a, c); |
| 1258 | + JIT_ASSERT(a->in_edges().size() == 0); |
| 1259 | + JIT_ASSERT(a->out_edges().size() == 2); |
| 1260 | + JIT_ASSERT(a->out_edges().contains(b)); |
| 1261 | + JIT_ASSERT(a->out_edges().contains(c)); |
| 1262 | + JIT_ASSERT(b->in_edges().size() == 1); |
| 1263 | + JIT_ASSERT(b->out_edges().size() == 1); |
| 1264 | + JIT_ASSERT(b->in_edges().contains(a)); |
| 1265 | + JIT_ASSERT(b->out_edges().contains(c)); |
| 1266 | + JIT_ASSERT(c->in_edges().size() == 2); |
| 1267 | + JIT_ASSERT(c->out_edges().size() == 0); |
| 1268 | + JIT_ASSERT(c->in_edges().contains(a)); |
| 1269 | + JIT_ASSERT(c->in_edges().contains(b)); |
| 1270 | +} |
| 1271 | + |
| 1272 | +void testAddEdgeCycleDetection() { |
| 1273 | + // a -> b -> c |
| 1274 | + // ^---------/ |
| 1275 | + auto graph = newDynamicDAG(); |
| 1276 | + auto a = graph->newVertex("a"); |
| 1277 | + auto b = graph->newVertex("b"); |
| 1278 | + auto c = graph->newVertex("c"); |
| 1279 | + graph->addEdge(a, b); |
| 1280 | + graph->addEdge(b, c); |
| 1281 | + bool erred = false; |
| 1282 | + try { |
| 1283 | + graph->addEdge(c, a); |
| 1284 | + } catch (c10::Error& err) { |
| 1285 | + erred = true; |
| 1286 | + } |
| 1287 | + JIT_ASSERT(erred); |
| 1288 | +} |
| 1289 | + |
| 1290 | +void testAddEdgeReordersBasic() { |
| 1291 | + // a, b => b -> a |
| 1292 | + auto graph = newDynamicDAG(); |
| 1293 | + auto a = graph->newVertex("a"); |
| 1294 | + auto b = graph->newVertex("b"); |
| 1295 | + JIT_ASSERT(a->ord == 0); |
| 1296 | + JIT_ASSERT(b->ord == 1); |
| 1297 | + graph->addEdge(b, a); |
| 1298 | + JIT_ASSERT(a->ord == 1); |
| 1299 | + JIT_ASSERT(b->ord == 0); |
| 1300 | +} |
| 1301 | + |
| 1302 | +void testAddEdgeReordersComplicated() { |
| 1303 | + // a -> b c -> d with addEdge(d, b) ==> |
| 1304 | + // c -> d -> a -> b |
| 1305 | + auto graph = newDynamicDAG(); |
| 1306 | + auto a = graph->newVertex("a"); |
| 1307 | + auto b = graph->newVertex("b"); |
| 1308 | + auto c = graph->newVertex("c"); |
| 1309 | + auto d = graph->newVertex("d"); |
| 1310 | + graph->addEdge(a, b); |
| 1311 | + graph->addEdge(c, d); |
| 1312 | + JIT_ASSERT(a->ord == 0); |
| 1313 | + JIT_ASSERT(b->ord == 1); |
| 1314 | + JIT_ASSERT(c->ord == 2); |
| 1315 | + JIT_ASSERT(d->ord == 3); |
| 1316 | + graph->addEdge(d, a); |
| 1317 | + JIT_ASSERT(c->ord == 0); |
| 1318 | + JIT_ASSERT(d->ord == 1); |
| 1319 | + JIT_ASSERT(a->ord == 2); |
| 1320 | + JIT_ASSERT(b->ord == 3); |
| 1321 | + JIT_ASSERT(c->in_edges().size() == 0); |
| 1322 | + JIT_ASSERT(c->out_edges().size() == 1); |
| 1323 | + JIT_ASSERT(c->out_edges().contains(d)); |
| 1324 | + JIT_ASSERT(d->in_edges().size() == 1); |
| 1325 | + JIT_ASSERT(d->out_edges().size() == 1); |
| 1326 | + JIT_ASSERT(d->in_edges().contains(c)); |
| 1327 | + JIT_ASSERT(d->out_edges().contains(a)); |
| 1328 | + JIT_ASSERT(a->in_edges().size() == 1); |
| 1329 | + JIT_ASSERT(a->out_edges().size() == 1); |
| 1330 | + JIT_ASSERT(a->in_edges().contains(d)); |
| 1331 | + JIT_ASSERT(a->out_edges().contains(b)); |
| 1332 | + JIT_ASSERT(b->in_edges().size() == 1); |
| 1333 | + JIT_ASSERT(b->out_edges().size() == 0); |
| 1334 | + JIT_ASSERT(b->in_edges().contains(a)); |
| 1335 | +} |
| 1336 | + |
| 1337 | +void testRemoveEdgeBasic() { |
| 1338 | + // a -> b |
| 1339 | + auto graph = newDynamicDAG(); |
| 1340 | + auto a = graph->newVertex("a"); |
| 1341 | + auto b = graph->newVertex("b"); |
| 1342 | + graph->addEdge(a, b); |
| 1343 | + JIT_ASSERT(graph->debugNumVertices() == 2); |
| 1344 | + graph->removeEdge(a, b); |
| 1345 | + JIT_ASSERT(graph->debugNumVertices() == 2); |
| 1346 | + JIT_ASSERT(a->out_edges().size() == 0); |
| 1347 | + JIT_ASSERT(b->in_edges().size() == 0); |
| 1348 | +} |
| 1349 | + |
| 1350 | +void testRemoveVertexBasic() { |
| 1351 | + // a -> b |
| 1352 | + auto graph = newDynamicDAG(); |
| 1353 | + auto a = graph->newVertex("a"); |
| 1354 | + auto b = graph->newVertex("b"); |
| 1355 | + auto c = graph->newVertex("c"); |
| 1356 | + graph->addEdge(a, b); |
| 1357 | + graph->addEdge(b, c); |
| 1358 | + JIT_ASSERT(graph->debugNumVertices() == 3); |
| 1359 | + graph->removeVertex(b); |
| 1360 | + JIT_ASSERT(graph->debugNumVertices() == 2); |
| 1361 | + JIT_ASSERT(a->out_edges().size() == 0); |
| 1362 | + JIT_ASSERT(c->in_edges().size() == 0); |
| 1363 | +} |
| 1364 | + |
| 1365 | +void testContractEdgeBasic() { |
| 1366 | + // a -> b -> c -> d |
| 1367 | + auto graph = newDynamicDAG(); |
| 1368 | + auto a = graph->newVertex("a"); |
| 1369 | + auto b = graph->newVertex("b"); |
| 1370 | + auto c = graph->newVertex("c"); |
| 1371 | + auto d = graph->newVertex("d"); |
| 1372 | + graph->addEdge(a, b); |
| 1373 | + graph->addEdge(b, c); |
| 1374 | + graph->addEdge(c, d); |
| 1375 | + graph->contractEdge(b, c); |
| 1376 | + JIT_ASSERT(graph->debugNumVertices() == 3); |
| 1377 | + JIT_ASSERT(a->out_edges().size() == 1); |
| 1378 | + JIT_ASSERT(d->in_edges().size() == 1); |
| 1379 | + JIT_ASSERT(*a->out_edges().begin() == *d->in_edges().begin()); |
| 1380 | + auto* contracted = *a->out_edges().begin(); |
| 1381 | + JIT_ASSERT(contracted->data.size() == 2); |
| 1382 | + JIT_ASSERT(contracted->data[0] == "b"); |
| 1383 | + JIT_ASSERT(contracted->data[1] == "c"); |
| 1384 | + JIT_ASSERT(contracted->out_edges().size() == 1); |
| 1385 | + JIT_ASSERT(contracted->in_edges().size() == 1); |
| 1386 | + JIT_ASSERT(contracted->in_edges().contains(a)); |
| 1387 | + JIT_ASSERT(contracted->out_edges().contains(d)); |
| 1388 | +} |
| 1389 | + |
| 1390 | +void testContractEdgeCycleDetection() { |
| 1391 | + // a -> b -> c |
| 1392 | + // `---------^ |
| 1393 | + // contractEdge(a, c) will cause a cycle |
| 1394 | + auto graph = newDynamicDAG(); |
| 1395 | + auto a = graph->newVertex("a"); |
| 1396 | + auto b = graph->newVertex("b"); |
| 1397 | + auto c = graph->newVertex("c"); |
| 1398 | + graph->addEdge(a, b); |
| 1399 | + graph->addEdge(b, c); |
| 1400 | + graph->addEdge(a, c); |
| 1401 | + JIT_ASSERT(!graph->contractEdge(a, c)); |
| 1402 | +} |
| 1403 | + |
| 1404 | +void testDynamicDAG() { |
| 1405 | + testNewVertex(); |
| 1406 | + testAddEdgeBasic(); |
| 1407 | + testAddEdgeCycleDetection(); |
| 1408 | + testAddEdgeReordersBasic(); |
| 1409 | + testAddEdgeReordersComplicated(); |
| 1410 | + testRemoveEdgeBasic(); |
| 1411 | + testRemoveVertexBasic(); |
| 1412 | + testContractEdgeBasic(); |
| 1413 | + testContractEdgeCycleDetection(); |
| 1414 | +} |
1223 | 1415 | } // namespace
|
1224 | 1416 | } // namespace jit
|
1225 | 1417 | } // namespace torch
|
0 commit comments