import sys
import unittest
import torch
import numpy as np

from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
import test_xla_sharding_base


class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
  """
  Test suite for the automatic conversion of regular tensors to
  XLAShardedTensor in DTensor.from_local() when using an XLA device mesh.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
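
  # A minimal sketch of the behavior under test (hypothetical usage, not one
  # of the assertions below): on an "xla" DeviceMesh, DTensor.from_local() is
  # expected to wrap a plain CPU tensor in an XLAShardedTensor, whose global
  # value is reachable via .global_tensor:
  #
  #   mesh = DeviceMesh("xla", list(range(xr.global_runtime_device_count())))
  #   dt = DTensor.from_local(torch.randn(8, 4), device_mesh=mesh)
  #   assert isinstance(dt, XLAShardedTensor)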

  def test_basic_conversion(self):
    """Test basic conversion of a regular tensor to XLAShardedTensor."""
    world_size = xr.global_runtime_device_count()

    # Create a regular tensor (not on an XLA device)
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with the regular tensor
    dt = DTensor.from_local(tensor, device_mesh=device_mesh)

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Basic conversion successful")

  def test_conversion_with_placements(self):
    """Test conversion with explicit placements."""
    world_size = xr.global_runtime_device_count()

    # Create a regular tensor (not on an XLA device)
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with explicit placements
    dt = DTensor.from_local(
        tensor, device_mesh=device_mesh, placements=[Replicate()])

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Conversion with placements successful")

  def test_conversion_with_sharding(self):
    """Test conversion with a sharding placement."""
    world_size = xr.global_runtime_device_count()
    if world_size < 2:
      self.skipTest("Need at least 2 devices for sharding test")

    # Create a tensor whose first dimension is divisible by world_size
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with a sharding placement
    dt = DTensor.from_local(
        tensor, device_mesh=device_mesh, placements=[Shard(0)])

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Conversion with sharding successful")
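
  # Hypothetical follow-up, not asserted above: with a Shard(0) placement,
  # each device should hold a 100_000 // world_size slice along dim 0.
  # Assuming the result is an XLAShardedTensor, the per-device pieces would
  # be reachable via its local_shards property, e.g.:
  #
  #   shards = dt.local_shards  # one XLAShard per addressable device
  #   print(shards[0].data.shape)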

  def test_conversion_with_different_dtypes(self):
    """Test conversion with different dtypes."""
    world_size = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Test with different dtypes
    for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]:
      # Create a tensor with a specific dtype
      tensor = torch.ones(100_000, 88, dtype=dtype)
      tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

      # Use DTensor.from_local with the tensor
      dt = DTensor.from_local(tensor, device_mesh=device_mesh)

      # Verify the dtype is preserved
      self.assertEqual(dt.dtype, dtype)

      # Check the value of the tensor
      torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

      # Verify operations work, using an integer addend for integer dtypes
      if dtype.is_floating_point:
        result = dt + 1.0
      else:
        result = dt + 1

      self.assertEqual(result.shape, tensor.shape)
      self.assertEqual(result.dtype, dtype)

      print(f"Conversion with {dtype} successful")


if __name__ == "__main__":
  result = unittest.main(exit=False)
  sys.exit(0 if result.result.wasSuccessful() else 1)