@@ -13,7 +13,7 @@
 import unittest
 from contextlib import redirect_stdout
 
-from typing import Callable, List
+from typing import Callable, List, Union
 
 from unittest.mock import patch
 
@@ -56,7 +56,7 @@
 
 OP_TYPE = "aten::add"
 EVENT_BLOCK_NAME = "block_0"
-EVENTS_SIZE = 5
+EVENTS_SIZE = 10
 RAW_DATA_SIZE = 10
 ETDUMP_PATH = "unittest_etdump_path"
 ETRECORD_PATH = "unittest_etrecord_path"
@@ -535,17 +535,116 @@ def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self):
                 )
             )
 
+    def test_get_runtime_intermediate_outputs(self):
+        # Create a context manager to patch functions called by Inspector.__init__
+        with patch.object(
+            _inspector, "parse_etrecord", return_value=None
+        ), patch.object(
+            _inspector, "gen_etdump_object", return_value=None
+        ), patch.object(
+            EventBlock, "_gen_from_etdump"
+        ), patch.object(
+            _inspector, "gen_graphs_from_etrecord"
+        ):
+            # Call the constructor of Inspector
+            inspector_instance = Inspector(
+                etdump_path=ETDUMP_PATH,
+                etrecord=ETRECORD_PATH,
+            )
+
+            # The mocked Inspector instance starts with an empty event_blocks list.
+            # Add pre-defined event blocks to test _get_runtime_intermediate_outputs().
+            inspector_instance.event_blocks = [
+                EventBlock(name=EVENT_BLOCK_NAME, events=self._gen_random_events())
+            ]
+
+            runtime_outputs = inspector_instance._get_runtime_intermediate_outputs()
+            # The output should be a dictionary with 5 keys.
+            self.assertEqual(
+                len(runtime_outputs),
+                5,
+            )
+            # Keys (0,) and (1,) should be absent: OPERATOR_CALL events and events with empty op_types are skipped.
+            self.assertNotIn((0,), runtime_outputs)
+            self.assertNotIn((1,), runtime_outputs)
+
+            # Same debug_handle but different instruction_id: the last one should be recorded.
+            self.assertIn((4,), runtime_outputs)
+            self.assertTrue(
+                torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
+            )
+            # Keys (5,) to (8,) should be present, with values of the correct size.
+            for key in range(5, 9):
+                self.assertIn((key,), runtime_outputs)
+                self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
+
     def _gen_random_float_list(self) -> List[float]:
         return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
 
+    def _gen_random_runtime_output(
+        self,
+    ) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
+        return list(torch.randn(RAW_DATA_SIZE))
+
     def _gen_random_events(self) -> List[Event]:
         events = []
-        for i in range(EVENTS_SIZE):
+        for i in range(2):
+            events.append(
+                # OPERATOR_CALL with debug_handles/instruction_id 0 and 2
+                Event(
+                    name="OPERATOR_CALL",
+                    op_types=[OP_TYPE],
+                    perf_data=PerfData(self._gen_random_float_list()),
+                    debug_handles=i * 2,
+                    _instruction_id=i * 2,
+                    debug_data=self._gen_random_runtime_output(),
+                )
+            )
+            events.append(
+                # op_0/op_1 with empty op_types and debug_handles/instruction_id 1 and 3
+                Event(
+                    name=f"op_{i}",
+                    op_types=[],
+                    perf_data=PerfData(self._gen_random_float_list()),
+                    debug_handles=i * 2 + 1,
+                    _instruction_id=i * 2 + 1,
+                    debug_data=self._gen_random_runtime_output(),
+                )
+            )
+
+        # op_2 with debug_handles/instruction_id 4
+        events.append(
+            Event(
+                name="op_2",
+                op_types=[OP_TYPE],
+                perf_data=PerfData(self._gen_random_float_list()),
+                debug_handles=4,
+                debug_data=[torch.tensor([1.0, 2.0, 3.0])],
+                _instruction_id=4,
+            )
+        )
+        # op_3 also with debug_handles 4, but with instruction_id 5
+        events.append(
+            Event(
+                name="op_3",
+                op_types=[OP_TYPE],
+                perf_data=PerfData(self._gen_random_float_list()),
+                debug_handles=4,
+                debug_data=[torch.tensor([4.0, 5.0, 6.0])],
+                _instruction_id=5,
+            )
+        )
+
+        # op_4 to op_7 with debug_handles 5 to 8 and instruction_id 6 to 9
+        for i in range(4, EVENTS_SIZE - 2):
             events.append(
                 Event(
                     name=f"op_{i}",
                     op_types=[OP_TYPE],
                     perf_data=PerfData(self._gen_random_float_list()),
+                    debug_handles=i + 1,
+                    debug_data=self._gen_random_runtime_output(),
+                    _instruction_id=i + 2,
                 )
             )
         return events
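
For reference, the sketch below illustrates the filtering and de-duplication behavior the assertions above rely on: events named OPERATOR_CALL or with empty op_types are ignored, and when two events share a debug handle the one with the higher instruction id wins. This is a hypothetical stand-in written for this note, not the actual Inspector._get_runtime_intermediate_outputs implementation; FakeEvent and runtime_intermediate_outputs_sketch are invented names used only for illustration.

# Hypothetical sketch only -- not the real Inspector code.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class FakeEvent:  # stand-in for the Event objects used in the test above
    name: str
    op_types: List[str]
    debug_handles: int
    debug_data: List[Any]
    instruction_id: Optional[int] = None


def runtime_intermediate_outputs_sketch(
    events: List[FakeEvent],
) -> Dict[Tuple[int, ...], List[Any]]:
    outputs: Dict[Tuple[int, ...], List[Any]] = {}
    last_instruction: Dict[Tuple[int, ...], int] = {}
    for event in events:
        # Skipping these events is why the test expects (0,) and (1,) to be absent.
        if event.name == "OPERATOR_CALL" or not event.op_types:
            continue
        key = (event.debug_handles,)
        instruction_id = event.instruction_id if event.instruction_id is not None else -1
        # Duplicate debug handle: keep the later instruction, so op_3's
        # [4.0, 5.0, 6.0] replaces op_2's data for key (4,).
        if key not in outputs or instruction_id > last_instruction[key]:
            outputs[key] = event.debug_data
            last_instruction[key] = instruction_id
    return outputs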