@@ -61,41 +61,30 @@ class Objective(object):
6161
6262 def __init__ (self , objective_func , name = "" , description = "" ):
6363 self .objective_func = objective_func
64- self .name = name
6564 self .description = description
65+ self .value = None # This value is populated after a call
6666
6767 def __add__ (self , other ):
6868 if isinstance (other , (int , float )):
6969 objective_func = lambda T : other + self (T )
70- name = self .name
71- description = self .description
7270 else :
7371 objective_func = lambda T : self (T ) + other (T )
74- name = ", " .join ([self .name , other .name ])
75- description = "Sum(" + " +\n " .join ([self .description , other .description ]) + ")"
76- return Objective (objective_func , name = name , description = description )
72+ description = "(" + " + " .join ([str (self ), str (other )]) + ")"
73+ return Objective (objective_func , description = description )
7774
7875 def __neg__ (self ):
7976 return - 1 * self
8077
8178 def __sub__ (self , other ):
8279 return self + (- 1 * other )
8380
84- @staticmethod
85- def sum (objs ):
86- objective_func = lambda T : sum ([obj (T ) for obj in objs ])
87- descriptions = [obj .description for obj in objs ]
88- description = "Sum(" + " +\n " .join (descriptions ) + ")"
89- names = [obj .name for obj in objs ]
90- name = ", " .join (names )
91- return Objective (objective_func , name = name , description = description )
92-
9381 def __mul__ (self , other ):
9482 if isinstance (other , (int , float )):
9583 objective_func = lambda T : other * self (T )
9684 else :
9785 objective_func = lambda T : self (T ) * other (T )
98- return Objective (objective_func , name = self .name , description = self .description )
86+ description = str (self ) + "·" + str (other )
87+ return Objective (objective_func , description = description )
9988
10089 def __rmul__ (self , other ):
10190 return self .__mul__ (other )
@@ -104,7 +93,14 @@ def __radd__(self, other):
10493 return self .__add__ (other )
10594
10695 def __call__ (self , T ):
107- return self .objective_func (T )
96+ self .value = self .objective_func (T )
97+ return self .value
98+
99+ def __str__ (self ):
100+ return self .description
101+
102+ def __repr__ (self ):
103+ return self .description
108104
109105
110106def _make_arg_str (arg ):
@@ -124,7 +120,7 @@ def wrap_objective(f, *args, **kwds):
124120 """
125121 objective_func = f (* args , ** kwds )
126122 objective_name = f .__name__
127- args_str = " [ " + ", " .join ([_make_arg_str (arg ) for arg in args ]) + "] "
123+ args_str = "( " + ", " .join ([_make_arg_str (arg ) for arg in args ]) + ") "
128124 description = objective_name .title () + args_str
129125 return Objective (objective_func , objective_name , description )
130126
@@ -190,10 +186,10 @@ def direction(layer, vec, batch=None, cossim_pow=0):
190186 """Visualize a direction"""
191187 if batch is None :
192188 vec = vec [None , None , None ]
193- return lambda T : _dot_cossim (T (layer ), vec )
189+ return lambda T : _dot_cossim (T (layer ), vec , cossim_pow = cossim_pow )
194190 else :
195191 vec = vec [None , None ]
196- return lambda T : _dot_cossim (T (layer )[batch ], vec )
192+ return lambda T : _dot_cossim (T (layer )[batch ], vec , cossim_pow = cossim_pow )
197193
198194
199195@wrap_objective
0 commit comments