diff --git a/distributed/FSDP/T5_training.py b/distributed/FSDP/T5_training.py
index 4ab136eace..762e70c436 100644
--- a/distributed/FSDP/T5_training.py
+++ b/distributed/FSDP/T5_training.py
@@ -14,6 +14,9 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from transformers.models.t5.modeling_t5 import T5Block
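+# NOTE: 'nlp' is the deprecated predecessor of HuggingFace 'datasets'; it still
+# ships the community wikihow loading script (presumably why it is imported here).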
+from nlp import load_dataset
 
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -86,11 +87,11 @@ def fsdp_main(args):
     print("Size of train dataset: ", dataset['train'].shape)
     print("Size of Validation dataset: ", dataset['validation'].shape)
 
-   
+
     #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) 
+    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
     val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
- 
+
     sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
     sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)
 
@@ -107,12 +108,12 @@ def fsdp_main(args):
 
     train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
     val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
- 
+
     torch.cuda.set_device(local_rank)
-    
+
     # Set up FSDP parameters
     mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)
-    
+
     # Apply FSDP wrapping to the model
     model = FSDP(model,
         auto_wrap_policy=t5_auto_wrap_policy,
@@ -120,7 +121,7 @@ def fsdp_main(args):
         sharding_strategy=fsdp_config.sharding_strategy,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=fsdp_config.limit_all_gathers)
-    
+
     # Enabling this causes https://github.com/pytorch/examples/issues/1210
     if fsdp_config.fsdp_activation_checkpointing:
         policies.apply_fsdp_checkpointing(model)
@@ -150,7 +151,7 @@ def fsdp_main(args):
         if args.run_validation:
             curr_val_loss = validation(model, rank, world_size, val_loader)
         scheduler.step()
-        
+
         if rank == 0:
 
             print(f"--> epoch {epoch} completed...entering save and stats zone")
@@ -170,7 +171,7 @@ def fsdp_main(args):
                 )
 
         if train_config.save_model and curr_val_loss < best_val_loss:
-            
+
             if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                 model_checkpointing.save_model_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
@@ -183,7 +184,7 @@ def fsdp_main(args):
             if fsdp_config.save_optimizer:
                 model_checkpointing.save_optimizer_checkpoint(
                     model, optimizer, rank, fsdp_config, epoch=1
-                )           
+                )
         if curr_val_loss < best_val_loss:
 
             best_val_loss = curr_val_loss
@@ -212,5 +213,5 @@ def fsdp_main(args):
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-    
+
     fsdp_main(args)
diff --git a/distributed/FSDP/model_checkpointing/checkpoint_handler.py b/distributed/FSDP/model_checkpointing/checkpoint_handler.py
index 5f6858476f..5d2ea84695 100644
--- a/distributed/FSDP/model_checkpointing/checkpoint_handler.py
+++ b/distributed/FSDP/model_checkpointing/checkpoint_handler.py
@@ -11,7 +11,9 @@
     # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
 )
 
-from torch.distributed._shard.checkpoint import (
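+# The sharded-checkpoint API left the private _shard namespace; recent PyTorch
+# releases expose the same names under torch.distributed.checkpoint.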
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
     save_state_dict,
@@ -24,7 +24,8 @@
 
 
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-import torch.distributed._shard.checkpoint as dist_cp
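+# Same rename as above; the dist_cp alias is kept so call sites below are unchanged.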
+import torch.distributed.checkpoint as dist_cp
 import torch.distributed as dist
 
 
@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-      
+
         dist_cp.load_state_dict(
             state_dict=checkpoint,
             storage_reader=reader,
@@ -108,7 +108,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
             state_dict=state_dict,
             storage_writer=distributed_writer,
             planner=DefaultSavePlanner(),
-            
+
         )
     dist.barrier()
     t1 = time.perf_counter()
@@ -117,7 +117,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
         print(
             f"Checkpoint Time = {t1-t0:.4f}\n using {cfg.save_using_num_threads=} total threads"
         )
-        
+
 def save_model_checkpoint(
     model,
     optimizer,
@@ -138,7 +138,7 @@ def save_model_checkpoint(
 
     if cfg.verbose:
         print(f"saving process: rank {rank}  done w model state_dict\n")
-   
+
 
     if rank == 0:
         print(f"--> saving model ...")
@@ -153,7 +153,7 @@ def save_model_checkpoint(
 
         if cfg.verbose:
             print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
-      
+
 
 
 def load_model_checkpoint(model, rank, cfg, verbose=True):
@@ -299,7 +299,7 @@ def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
             StateDictType.LOCAL_STATE_DICT,
         ):
             state_dict = model.state_dict()
-       
+
 
         # write out distributed checkpoint
         save_state_dict(state_dict, writer)
diff --git a/distributed/FSDP/requirements.txt b/distributed/FSDP/requirements.txt
index a59c5bacb2..904bf752db 100644
--- a/distributed/FSDP/requirements.txt
+++ b/distributed/FSDP/requirements.txt
@@ -3,3 +3,5 @@ datasets
 tqdm
 protobuf
 SentencePiece
+nlp
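+# legacy HuggingFace package, kept alongside 'datasets' for the wikihow script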
diff --git a/distributed/FSDP/summarization_dataset.py b/distributed/FSDP/summarization_dataset.py
index 679ea48ec0..b9854e4e7f 100644
--- a/distributed/FSDP/summarization_dataset.py
+++ b/distributed/FSDP/summarization_dataset.py
@@ -14,8 +14,9 @@
 import torch
 from torch.utils.data import Dataset, DataLoader
 
-from datasets import load_dataset, load_metric
-
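+# NOTE: switched to the legacy 'nlp' package (see requirements.txt); the unused
+# load_metric import is dropped in the process.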
+from nlp import load_dataset
 
 from transformers import (
     AdamW,
@@ -25,7 +24,7 @@
 )
 
 class wikihow(Dataset):
-    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):         
+    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
         self.dataset =  load_dataset('wikihow', 'all', data_dir='data/', split=type_path)
         if num_samples:
             self.dataset = self.dataset.select(list(range(0, num_samples)))
@@ -33,43 +32,43 @@ def __init__(self, tokenizer, type_path, num_samples, input_length, output_lengt
         self.tokenizer = tokenizer
         self.output_length = output_length
         self.print_text = print_text
-  
+
     def __len__(self):
         return self.dataset.shape[0]
-    
+
     def clean_text(self, text):
         text = text.replace('Example of text:', '')
         text = text.replace('Example of Summary:', '')
         text = text.replace('\n','')
         text = text.replace('``', '')
         text = text.replace('"', '')
-        
+
         return text
-    
-    
+
+
     def convert_to_features(self, example_batch):
         # Tokenize contexts and questions (as pairs of inputs)
-        
+
         if self.print_text:
             print("Input Text: ", self.clean_text(example_batch['text']))
 #         input_ = self.clean_text(example_batch['text']) + " </s>"
 #         target_ = self.clean_text(example_batch['headline']) + " </s>"
-        
+
         input_ = self.clean_text(example_batch['text'])
         target_ = self.clean_text(example_batch['headline'])
-        
-        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length, 
+
+        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
                                                      padding='max_length', truncation=True, return_tensors="pt")
-        
-        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length, 
+
+        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
                                                      padding='max_length', truncation=True, return_tensors="pt")
-    
-       
+
+
         return source, targets
-  
+
     def __getitem__(self, index):
         source, targets = self.convert_to_features(self.dataset[index])
-        
+
         source_ids = source["input_ids"].squeeze()
         target_ids = targets["input_ids"].squeeze()
 
@@ -77,7 +76,7 @@ def __getitem__(self, index):
         target_mask = targets["attention_mask"].squeeze()
 
         return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}
-        
+
 def get_dataset(tokenizer, type_path, num_samples, args):
-      return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples,  input_length=max_input_length, 
+      return wikihow(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples,  input_length=max_input_length,
                         output_length=max_output_length)
diff --git a/distributed/FSDP/utils/train_utils.py b/distributed/FSDP/utils/train_utils.py
index 24cf239e7c..60e5593ec7 100644
--- a/distributed/FSDP/utils/train_utils.py
+++ b/distributed/FSDP/utils/train_utils.py
@@ -36,7 +36,7 @@ def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler
     model.train()
     local_rank = int(os.environ['LOCAL_RANK'])
     fsdp_loss = torch.zeros(2).to(local_rank)
-  
+
     if sampler:
         sampler.set_epoch(epoch)
     if rank==0:
@@ -98,5 +98,7 @@ def validation(model, rank, world_size, val_loader):
 
 def setup_model(model_name):
         model = T5ForConditionalGeneration.from_pretrained(model_name)
-        tokenizer =  T5Tokenizer.from_pretrained(model_name)
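+        # legacy=False opts into the corrected SentencePiece tokenization behavior
+        # and avoids the transformers legacy-tokenizer warning on load.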
+        tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)
         return model, tokenizer