|
20 | 20 | import struct
|
21 | 21 | import subprocess
|
22 | 22 | import time
|
23 |
| -from typing import TYPE_CHECKING, List, Tuple, Union |
| 23 | +from typing import TYPE_CHECKING, Dict, List, Tuple, Union |
24 | 24 |
|
25 | 25 | import torch
|
26 | 26 | import torch_npu
|
|
51 | 51 |
|
52 | 52 | # Get all device ips using hccn_tool
|
53 | 53 | HCCN_TOOL_PATH = envs.HCCN_PATH
|
| 54 | +HCCN_CONF_PATH = envs.HCCN_CONF_PATH |
54 | 55 |
|
55 | 56 |
|
56 | 57 | class KVTransferEngine:
|
@@ -457,26 +458,53 @@ def _extract_kv_from_layer(
|
457 | 458 |
|
458 | 459 |
|
459 | 460 | def get_device_ips():
|
| 461 | + def get_ips_from_hccn_conf(): |
| 462 | + device_ips: Dict[str, str] = {} |
| 463 | + with open(HCCN_CONF_PATH, 'r') as fin: |
| 464 | + for hccn_item in fin.readlines(): |
| 465 | + if hccn_item.strip().startswith('address_'): |
| 466 | + device_id, device_ip = hccn_item.split('=') |
| 467 | + device_id = device_id.split('_')[1] |
| 468 | + device_ips[device_id] = device_ip.strip() |
| 469 | + return device_ips |
| 470 | + |
460 | 471 | world_size = 8
|
461 | 472 | npu_info = subprocess.run(['npu-smi', 'info', '-m'],
|
462 | 473 | stdout=subprocess.PIPE,
|
463 | 474 | stderr=subprocess.PIPE,
|
464 | 475 | universal_newlines=True)
|
465 |
| - if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH): |
466 |
| - raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.") |
| 476 | + if npu_info.returncode != 0: |
| 477 | + raise RuntimeError("No npu-smi tools provided for NPU.") |
| 478 | + |
467 | 479 | npu_start_idx = int(
|
468 | 480 | re.match(r'.*\n\t([0-9]+).*', npu_info.stdout).group(1))
|
| 481 | + |
469 | 482 | device_ip_list = []
|
470 |
| - for ip_offset in range(world_size): |
471 |
| - cmd = [ |
472 |
| - HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', '-g' |
473 |
| - ] |
474 |
| - device_ip_info = subprocess.run(cmd, |
475 |
| - stdout=subprocess.PIPE, |
476 |
| - stderr=subprocess.PIPE, |
477 |
| - universal_newlines=True) |
478 |
| - device_ip = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout).group(1) |
479 |
| - device_ip_list.append(device_ip) |
| 483 | + if os.path.exists(HCCN_TOOL_PATH): |
| 484 | + for ip_offset in range(world_size): |
| 485 | + cmd = [ |
| 486 | + HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', |
| 487 | + '-g' |
| 488 | + ] |
| 489 | + device_ip_info = subprocess.run(cmd, |
| 490 | + stdout=subprocess.PIPE, |
| 491 | + stderr=subprocess.PIPE, |
| 492 | + universal_newlines=True) |
| 493 | + device_ip = re.match(r'ipaddr:(.*)\n', |
| 494 | + device_ip_info.stdout).group(1) |
| 495 | + device_ip_list.append(device_ip) |
| 496 | + |
| 497 | + elif os.path.exists(HCCN_CONF_PATH): |
| 498 | + device_ips = get_ips_from_hccn_conf() |
| 499 | + for ip_offset in range(world_size): |
| 500 | + device_ip = device_ips[str(npu_start_idx + ip_offset)] |
| 501 | + device_ip_list.append(device_ip) |
| 502 | + |
| 503 | + else: |
| 504 | + raise RuntimeError( |
| 505 | + "Failed to find information for rank_table, please check the " |
| 506 | + "existence of hccn_tool and hccn.conf") |
| 507 | + |
480 | 508 | return device_ip_list
|
481 | 509 |
|
482 | 510 |
|
|
0 commit comments