File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed
Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -229,6 +229,14 @@ class DatasetArguments(CustomDatasetArguments):
229229 "definition"
230230 },
231231 )
232+ offload_sequential_activations : bool = field (
233+ default = True ,
234+ metadata = {
235+ "help" : "Whether to offload intermediate activations between sequential "
236+ "layers to the CPU. Disabling offloading is much faster, but uses "
237+ "significantly more memory. Default is True."
238+ },
239+ )
232240 quantization_aware_calibration : bool = field (
233241 default = True ,
234242 metadata = {
Original file line number Diff line number Diff line change @@ -66,7 +66,6 @@ def __call__(
6666 # prepare to trace subgraphs
6767 modifiers = session .lifecycle .recipe .modifiers
6868 sequential_targets = get_sequential_targets (modifiers , model , dataset_args )
69-
7069 ignore = dataset_args .tracing_ignore
7170
7271 # trace subgraphs
@@ -90,7 +89,11 @@ def __call__(
9089 stack .enter_context (DisableQuantization (model ))
9190
9291 # prepare intermediates cache
93- activations = IntermediatesCache .from_dataloader (dataloader , model_device )
92+ cache_offload = dataset_args .offload_sequential_activations
93+ offload_device = torch .device ("cpu" ) if cache_offload else None
94+ activations = IntermediatesCache .from_dataloader (
95+ dataloader , model_device , offload_device = offload_device
96+ )
9497
9598 for subgraph_index , subgraph in enumerate (subgraphs ):
9699 # prepare tqdm description texts
You can’t perform that action at this time.
0 commit comments