Commit 5efeb10

Add tests for routing XLA device handling through distribute_tensor, to ensure proper XLA support and consistency with the PyTorch/XLA SPMD integration.
1 parent fa7f432 commit 5efeb10
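
For orientation, a minimal sketch of the behavior this commit tests (assumes a multi-device XLA runtime; it mirrors the new test file below):

import torch
from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
import torch_xla.runtime as xr

# A plain CPU tensor passed to DTensor.from_local on an "xla" mesh is
# expected to be converted for the XLA backend (per this commit, by
# routing device handling through distribute_tensor) rather than rejected.
mesh = DeviceMesh("xla", list(range(xr.global_runtime_device_count())))
local = torch.randn(8, 4)
dt = DTensor.from_local(local, device_mesh=mesh)
assert dt.shape == local.shape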

File tree

4 files changed: +152 −0 lines

test/neuron/run_tests.sh
test/run_tests.sh
test/spmd/test_xla_dtensor_from_local.py
test/tpu/run_tests.sh

test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ function run_xla_op_tests3 {
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -255,6 +255,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_xla_dtensor_from_local.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import sys
import unittest
import torch
import numpy as np

from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
import test_xla_sharding_base


class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
  """
  Test suite for the automatic conversion of regular tensors to XLAShardedTensor
  in DTensor.from_local() when using an XLA device mesh.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  def test_basic_conversion(self):
    """Test basic conversion of a regular tensor to XLAShardedTensor."""
    world_size = xr.global_runtime_device_count()

    # Create a regular tensor (not on an XLA device)
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with the regular tensor
    dt = DTensor.from_local(tensor, device_mesh=device_mesh)

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Basic conversion successful")

  def test_conversion_with_placements(self):
    """Test conversion with explicit placements."""
    world_size = xr.global_runtime_device_count()

    # Create a regular tensor (not on an XLA device)
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with explicit placements
    dt = DTensor.from_local(
        tensor, device_mesh=device_mesh, placements=[Replicate()])

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Conversion with placements successful")

  def test_conversion_with_sharding(self):
    """Test conversion with a sharding placement."""
    world_size = xr.global_runtime_device_count()
    if world_size < 2:
      self.skipTest("Need at least 2 devices for sharding test")

    # Create a tensor divisible by world_size
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with a sharding placement
    dt = DTensor.from_local(
        tensor, device_mesh=device_mesh, placements=[Shard(0)])

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Conversion with sharding successful")

  def test_conversion_with_different_dtypes(self):
    """Test conversion with different dtypes."""
    world_size = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Test with different dtypes
    for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]:
      # Create a tensor with a specific dtype
      tensor = torch.ones(100_000, 88, dtype=dtype)
      tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

      # Use DTensor.from_local with the tensor
      dt = DTensor.from_local(tensor, device_mesh=device_mesh)

      # Verify dtype is preserved
      self.assertEqual(dt.dtype, dtype)

      # Check the value of the tensor
      torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

      # Verify operations work
      if dtype.is_floating_point:
        result = dt + 1.0
      else:
        result = dt + 1

      self.assertEqual(result.shape, tensor.shape)
      self.assertEqual(result.dtype, dtype)

      print(f"Conversion with {dtype} successful")


if __name__ == "__main__":
  result = unittest.main(exit=False)
  sys.exit(0 if result.result.wasSuccessful() else 1)
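
A note on the harness idiom at the bottom: unittest.main(exit=False) returns a TestProgram whose result records pass/fail, and the explicit sys.exit converts that into a process exit code the run_tests.sh wrappers can check.

The commit message ties this path to distribute_tensor; for comparison, a hedged sketch of producing the same replicated layout through that API directly (hypothetical usage, not part of this commit; assumes a multi-device XLA runtime):

import torch
from torch.distributed.tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Replicate
import torch_xla.runtime as xr

# distribute_tensor takes the global tensor plus placements; per the
# commit message, DTensor.from_local on an "xla" mesh routes through
# the same device handling, so the two should agree for Replicate().
mesh = DeviceMesh("xla", list(range(xr.global_runtime_device_count())))
full = torch.randn(16, 8)
dt = distribute_tensor(full, mesh, placements=[Replicate()])
assert dt.shape == full.shape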

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
 run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
+run_test "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"
