25
25
api = BigML()
26
26
27
27
model = api.get_model('model/5026965515526876630001b2')
28
+ model = api.get_model('model/5026a3c315526876630001b5')
28
29
29
30
tree = Tree(model['object']['model']['root'], model['object']['model']['fields'], model['object']['objective_fields'])
30
31
tree.predict({"000002": 2.46, "000003": 1})
31
32
tree.rules()
33
+ tree.python()
32
34
33
35
"""
34
36
import logging
35
37
LOGGER = logging .getLogger ('BigML' )
36
38
37
39
import operator
40
+ import unidecode
41
+ import re
38
42
39
43
OPERATOR = {
40
44
"<" : operator .lt ,
@@ -51,11 +55,16 @@ def __init__(self, operator, field, value):
51
55
self .field = field
52
56
self .value = value
53
57
58
+ def slugify (str ):
59
+ str = unidecode .unidecode (str ).lower ()
60
+ return re .sub (r'\W+' , '_' , str )
61
+
54
62
class Tree (object ):
55
63
56
64
def __init__ (self , tree , fields , objective_field = None ):
57
65
58
66
self .fields = fields
67
+
59
68
if objective_field and isinstance (objective_field , list ):
60
69
self .objective_field = objective_field [0 ]
61
70
else :
@@ -72,7 +81,7 @@ def __init__(self, tree, fields, objective_field=None):
72
81
children = []
73
82
if 'children' in tree :
74
83
for child in tree ['children' ]:
75
- children .append (Tree (child , fields , objective_field ))
84
+ children .append (Tree (child , self . fields , objective_field ))
76
85
self .children = children
77
86
self .count = tree ['count' ]
78
87
self .distribution = tree ['distribution' ]
@@ -112,3 +121,33 @@ def rules(self, depth=0):
112
121
' ' * depth ,
113
122
self .fields [self .objective_field ]['name' ] if self .objective_field else "Prediction" ,
114
123
self .output ))
124
+
125
+
126
+ def python_body (self , depth = 1 ):
127
+ if self .children :
128
+ for child in self .children :
129
+ print ("%sif (%s %s %s)%s" %
130
+ (' ' * depth ,
131
+ self .fields [child .predicate .field ]['slug' ],
132
+ child .predicate .operator ,
133
+ child .predicate .value ,
134
+ ":" if child .children else ":" ))
135
+ child .python_body (depth + 1 )
136
+ else :
137
+ if self .fields [self .objective_field ]['optype' ] == 'numeric' :
138
+ print ("%s return %s" % (' ' * depth , self .output ))
139
+ else :
140
+ print ("%s return '%s'" % (' ' * depth , self .output ))
141
+
142
+ def python (self ):
143
+ args = []
144
+ for key in self .fields .iterkeys ():
145
+ slug = slugify (self .fields [key ]['name' ])
146
+ self .fields [key ].update (slug = slug )
147
+ if key != self .objective_field :
148
+ args .append (slug )
149
+ print ("def predict_%s(%s):" % (self .fields [self .objective_field ]['slug' ], ", " .join (args )))
150
+ self .python_body ()
151
+
152
+
153
+
0 commit comments