@@ -47,6 +47,41 @@ class OptimizerArgsSetItem:
4747OptimItem = OptimizerArgsSetItem
4848
4949
50+ ######################################################################
51+ ## Data Dict for the code generator script ##
52+ ######################################################################
53+ # a dict of tensor name and annotation to mark whether the tensor is mutable.
54+ # this is use to annotate the tensor in the defintion schema.
55+ annotation_dict : Dict [str , str ] = {
56+ "weights" : "(a!)" ,
57+ "weights_host" : "(a!)" ,
58+ "weights_dev" : "(b!)" ,
59+ "weights_uvm" : "(c!)" ,
60+ "weights_lxu_cache" : "(d!)" ,
61+ "aux_tensor" : "(e!)" ,
62+ "uvm_cache_stats" : "(f!)" ,
63+ "momentum1" : "(g!)" ,
64+ "momentum1_host" : "(g!)" ,
65+ "momentum1_dev" : "(h!)" ,
66+ "momentum1_uvm" : "(i!)" ,
67+ "momentum2" : "(j!)" ,
68+ "momentum2_host" : "(j!)" ,
69+ "momentum2_dev" : "(k!)" ,
70+ "momentum2_uvm" : "(l!)" ,
71+ "prev_iter" : "(m!)" ,
72+ "prev_iter_host" : "(m!)" ,
73+ "prev_iter_dev" : "(n!)" ,
74+ "prev_iter_uvm" : "(o!)" ,
75+ "row_counter" : "(p!)" ,
76+ "row_counter_host" : "(p!)" ,
77+ "row_counter_dev" : "(q!)" ,
78+ "row_counter_uvm" : "(r!)" ,
79+ "optim_tensor" : "(s!)" ,
80+ "delta_weights_host" : "(t!)" ,
81+ "delta_weights_dev" : "(u!)" ,
82+ "delta_weights_uvm" : "(v!)" ,
83+ }
84+
5085######################################################################
5186## Helper functions for the code generator script ##
5287######################################################################
@@ -146,6 +181,11 @@ def tensor_arg(name: str) -> str:
146181 return f"Tensor { name } "
147182
148183
184+ def tensor_arg_annotate (name : str ) -> str :
185+ annotate = annotation_dict [name ] if name in annotation_dict else ""
186+ return f"Tensor{ annotate } { name } "
187+
188+
149189def double_arg (name : str , default : float = 0.0 ) -> str :
150190 return f"double { name } = { default } "
151191
@@ -191,7 +231,8 @@ def schema_sym_int_arg_no_default(name: str) -> str:
191231
192232
193233def schema_tensor_list_arg_no_default (name : str ) -> str :
194- return f"Tensor[] { name } "
234+ annotate = annotation_dict [name ] if name in annotation_dict else ""
235+ return f"Tensor[]{ annotate } { name } "
195236
196237
197238def bool_arg (name : str , default : bool = False ) -> str :
@@ -409,7 +450,7 @@ def make_function_arg(
409450
410451def make_function_schema_arg (ty : ArgType , name : str , default : Union [int , float ]) -> str :
411452 return {
412- ArgType .TENSOR : tensor_arg ,
453+ ArgType .TENSOR : tensor_arg_annotate ,
413454 ArgType .INT_TENSOR : tensor_arg ,
414455 ArgType .LONG_TENSOR : tensor_arg ,
415456 ArgType .PLACEHOLDER_TENSOR : tensor_arg ,
0 commit comments