|
20 | 20 | import struct
|
21 | 21 | import subprocess
|
22 | 22 | import time
|
23 |
| -from typing import TYPE_CHECKING, List, Tuple, Union |
| 23 | +from typing import TYPE_CHECKING, Dict, List, Tuple, Union |
24 | 24 |
|
25 | 25 | import torch
|
26 | 26 | import torch_npu
|
|
51 | 51 |
|
# Sources used by get_device_ips() to look up device ips:
#   * HCCN_TOOL_PATH - the hccn_tool executable, queried once per device
#     when the path exists.
#   * HCCN_CONF_PATH - the hccn.conf file, parsed for `address_<id>=<ip>`
#     entries when hccn_tool does not exist at HCCN_TOOL_PATH.
HCCN_TOOL_PATH = envs.HCCN_PATH
HCCN_CONF_PATH = envs.HCCN_CONF_PATH
54 | 55 |
|
55 | 56 |
|
56 | 57 | class KVTransferEngine:
|
@@ -457,26 +458,60 @@ def _extract_kv_from_layer(
|
457 | 458 |
|
458 | 459 |
|
def get_device_ips():
    """Collect the ip addresses of the NPU devices visible on this host.

    The first device index is parsed from the output of ``npu-smi info -m``;
    ips are then gathered for ``world_size`` (8) consecutive device indices
    starting there. Two sources are tried in order:

    1. ``hccn_tool`` (``HCCN_TOOL_PATH``), invoked once per device with
       ``-i <id> -ip -g``, if the executable path exists.
    2. The hccn.conf file (``HCCN_CONF_PATH``), parsed for
       ``address_<id>=<ip>`` lines, if the file exists.

    Returns:
        A list of device ip strings, ordered by device index.

    Raises:
        RuntimeError: If ``npu-smi`` cannot be launched or exits non-zero,
            or if neither hccn_tool nor hccn.conf exists.
        ValueError: If the npu-smi or hccn_tool output cannot be parsed.
        KeyError: If hccn.conf has no ``address_<id>`` entry for one of the
            required device indices.
    """

    def get_ips_from_hccn_conf():
        """Return a ``{device_id: ip}`` mapping parsed from hccn.conf."""
        device_ips: Dict[str, str] = {}
        with open(HCCN_CONF_PATH, 'r') as fin:
            for hccn_item in fin:
                # partition() never raises, unlike a two-name unpack of
                # split('='), and keeps any extra '=' inside the value.
                key, sep, value = hccn_item.partition('=')
                key = key.strip()
                if not sep or not key.startswith('address_'):
                    continue
                device_id = key.split('_')[1]
                device_ips[device_id] = value.strip()
        return device_ips

    world_size = 8
    try:
        npu_info = subprocess.run(['npu-smi', 'info', '-m'],
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  universal_newlines=True)
    except OSError as err:
        # A missing or non-executable npu-smi raises here instead of giving
        # a non-zero return code; report it the same way as a failed run.
        raise RuntimeError("No npu-smi tools provided for NPU.") from err
    if npu_info.returncode != 0:
        raise RuntimeError("No npu-smi tools provided for NPU.")

    # The start index is the leading integer of the tab-indented second line.
    match = re.match(r'.*\n\t([0-9]+).*', npu_info.stdout)
    if match is None:
        raise ValueError(
            "Failed to extract NPU start index from npu-smi output.")
    npu_start_idx = int(match.group(1))

    device_ip_list = []
    if os.path.exists(HCCN_TOOL_PATH):
        for ip_offset in range(world_size):
            cmd = [
                HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip',
                '-g'
            ]
            device_ip_info = subprocess.run(cmd,
                                            stdout=subprocess.PIPE,
                                            stderr=subprocess.PIPE,
                                            universal_newlines=True)
            match = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout)
            if match is None:
                raise ValueError(
                    "Failed to extract device ip from hccn_tool output.")
            device_ip_list.append(match.group(1))

    elif os.path.exists(HCCN_CONF_PATH):
        device_ips = get_ips_from_hccn_conf()
        for ip_offset in range(world_size):
            device_id = str(npu_start_idx + ip_offset)
            if device_id not in device_ips:
                # Still a KeyError, as a plain dict lookup would raise, but
                # with enough context to locate the missing entry.
                raise KeyError(
                    f"No address_{device_id} entry found in {HCCN_CONF_PATH}")
            device_ip_list.append(device_ips[device_id])

    else:
        raise RuntimeError(
            "Failed to find information for rank_table, please check the "
            "existence of hccn_tool and hccn.conf")

    return device_ip_list
|
481 | 516 |
|
482 | 517 |
|
|
0 commit comments