@@ -338,6 +338,129 @@ def _export_quantized_weight(
         sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)


+def _get_sparse_attention_config(model: nn.Module) -> dict[str, Any]:
+    """Extract sparse attention configuration from model for export.
+
+    Args:
+        model: Model with sparse attention modules
+
+    Returns:
+        Dictionary with sparse attention config in the format:
+        {
+            "config_groups": {
+                "group_0": {
+                    "sparse_algo": "softmax_skip",
+                    "threshold": 1e-4,  # only if not calibrated
+                    "targets": ["LlamaAttention"]
+                }
+            },
+            "threshold_scale_factor": 0.001234,  # global, if calibrated
+            "target_sparsity": 0.5,  # global, if calibrated
+            "producer": {"name": "modelopt", "version": "..."}
+        }
+    """
+    from modelopt import __version__
+    from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
+
+    # Collect all enabled sparse attention modules
+    sparse_modules = []
+    for name, module in model.named_modules():
+        if isinstance(module, SparseAttentionModule) and module.is_enabled:
+            sparse_modules.append((name, module))
+
+    if not sparse_modules:
+        return {}
+
+    sparse_config = {
+        "config_groups": {},
+        "producer": {
+            "name": "modelopt",
+            "version": __version__,
+        },
+    }
+
+    # Check the first module for global calibration parameters
+    # (all modules share the same calibration parameters)
+    first_module = sparse_modules[0][1]
+    method_instance = first_module._sparse_method_instance
+    threshold_scale_factor = getattr(method_instance, "threshold_scale_factor", None)
+
+    if threshold_scale_factor is not None:
+        # Model was calibrated: add global calibration parameters
+        sparse_config["threshold_scale_factor"] = float(threshold_scale_factor)
+
+        target_sparsity = getattr(method_instance, "target_sparsity", None)
+        if target_sparsity is not None:
+            sparse_config["target_sparsity"] = float(target_sparsity)
+
+    # Group modules by configuration.
+    # Key: (sparse_algo, threshold_repr); value: group dict holding the algo,
+    # the threshold config, and the set of target class names.
+    config_to_targets = {}
+
+    for name, module in sparse_modules:
+        method_instance = module._sparse_method_instance
+
+        # Extract the sparse algorithm name from the method name,
+        # e.g. "flash_softmax_skip" -> "softmax_skip"
+        method_name = method_instance.name
+        if method_name.startswith("flash_"):
+            sparse_algo = method_name[len("flash_"):]
+        else:
+            sparse_algo = method_name
+
+        # Use the module's original class name (before SparseAttentionModule
+        # wrapping) as the export target
+        original_cls = module.get_original_cls_by_level(level=0)
+        target_class_name = original_cls.__name__
+
+        # Build the config key for grouping
+        if threshold_scale_factor is None:
+            # Not calibrated: include the threshold in the grouping key
+            threshold_config = getattr(method_instance, "threshold_config", None)
+            if isinstance(threshold_config, dict):
+                # Convert dict to tuple for a hashable key
+                threshold_repr = tuple(sorted(threshold_config.items()))
+            else:
+                threshold_repr = threshold_config
+        else:
+            # Calibrated: no per-layer threshold in the config
+            threshold_config = None
+            threshold_repr = None
+
+        config_key = (sparse_algo, threshold_repr)
+
+        if config_key not in config_to_targets:
+            config_to_targets[config_key] = {
+                "sparse_algo": sparse_algo,
+                "threshold_config": threshold_config,
+                "targets": set(),
+            }
+
+        config_to_targets[config_key]["targets"].add(target_class_name)
+
+    # Convert the grouped configs to the config_groups format
+    for group_idx, group_data in enumerate(config_to_targets.values()):
+        group_name = f"group_{group_idx}"
+        group_config = {
+            "sparse_algo": group_data["sparse_algo"],
+            "targets": sorted(group_data["targets"]),
+        }
+
+        # Add the threshold only if not calibrated
+        if group_data["threshold_config"] is not None:
+            threshold_config = group_data["threshold_config"]
+            if isinstance(threshold_config, dict):
+                # Convert to a JSON-serializable format
+                group_config["threshold"] = {k: float(v) for k, v in threshold_config.items()}
+            else:
+                group_config["threshold"] = float(threshold_config)
+
+        sparse_config["config_groups"][group_name] = group_config
+
+    return sparse_config
+
+
 def _export_hf_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None
 ) -> tuple[dict[str, Any], dict[str, Any]]:
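
For reference, the helper groups wrapped attention modules that share the same algorithm (and, for uncalibrated models, the same threshold) into a single config_groups entry. A minimal sketch of the uncalibrated output shape, reusing the example values from the docstring above (the version string is a placeholder):

example_sparse_attention_config = {
    "config_groups": {
        "group_0": {
            "sparse_algo": "softmax_skip",
            # A per-group "threshold" appears only for uncalibrated models; a
            # calibrated model instead carries the top-level
            # "threshold_scale_factor" and "target_sparsity" fields.
            "threshold": 1e-4,
            "targets": ["LlamaAttention"],
        },
    },
    "producer": {"name": "modelopt", "version": "0.0.0"},  # placeholder
}
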
@@ -543,6 +666,11 @@ def export_hf_checkpoint(

         config_data["quantization_config"] = hf_quant_config

+        # Add sparse attention config if model has sparse attention
+        sparse_attention_config = _get_sparse_attention_config(model)
+        if sparse_attention_config:
+            config_data["sparse_attention_config"] = sparse_attention_config
+
         with open(original_config, "w") as file:
             json.dump(config_data, file, indent=4)

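
As a quick sanity check, the exported section can be read back out of the checkpoint's config.json, where it sits alongside quantization_config. A minimal sketch, assuming the checkpoint was exported to a directory named "exported_model" (the directory name is illustrative; the export call itself is not shown in this hunk):

import json
from pathlib import Path

# Load the config.json written by the export path above.
config = json.loads((Path("exported_model") / "config.json").read_text())

sparse_cfg = config.get("sparse_attention_config", {})
for group_name, group in sparse_cfg.get("config_groups", {}).items():
    # "threshold" is present only when the model was not calibrated.
    print(group_name, group["sparse_algo"], group["targets"], group.get("threshold"))

# Global calibration fields exist only for calibrated models.
print(sparse_cfg.get("threshold_scale_factor"), sparse_cfg.get("target_sparsity"))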