9
9
'port' : 5000 ,
10
10
}
11
11
12
+
12
13
class Handler (BaseHTTPRequestHandler ):
13
14
def do_GET (self ):
14
15
if self .path == '/api/v1/model' :
@@ -32,34 +33,34 @@ def do_POST(self):
32
33
self .end_headers ()
33
34
34
35
prompt = body ['prompt' ]
35
- prompt_lines = [l .strip () for l in prompt .split ('\n ' )]
36
+ prompt_lines = [k .strip () for k in prompt .split ('\n ' )]
36
37
37
38
max_context = body .get ('max_context_length' , 2048 )
38
39
39
40
while len (prompt_lines ) >= 0 and len (encode ('\n ' .join (prompt_lines ))) > max_context :
40
41
prompt_lines .pop (0 )
41
42
42
43
prompt = '\n ' .join (prompt_lines )
43
- generate_params = {
44
- 'max_new_tokens' : int (body .get ('max_length' , 200 )),
44
+ generate_params = {
45
+ 'max_new_tokens' : int (body .get ('max_length' , 200 )),
45
46
'do_sample' : bool (body .get ('do_sample' , True )),
46
- 'temperature' : float (body .get ('temperature' , 0.5 )),
47
- 'top_p' : float (body .get ('top_p' , 1 )),
48
- 'typical_p' : float (body .get ('typical' , 1 )),
49
- 'repetition_penalty' : float (body .get ('rep_pen' , 1.1 )),
47
+ 'temperature' : float (body .get ('temperature' , 0.5 )),
48
+ 'top_p' : float (body .get ('top_p' , 1 )),
49
+ 'typical_p' : float (body .get ('typical' , 1 )),
50
+ 'repetition_penalty' : float (body .get ('rep_pen' , 1.1 )),
50
51
'encoder_repetition_penalty' : 1 ,
51
- 'top_k' : int (body .get ('top_k' , 0 )),
52
+ 'top_k' : int (body .get ('top_k' , 0 )),
52
53
'min_length' : int (body .get ('min_length' , 0 )),
53
- 'no_repeat_ngram_size' : int (body .get ('no_repeat_ngram_size' ,0 )),
54
- 'num_beams' : int (body .get ('num_beams' ,1 )),
54
+ 'no_repeat_ngram_size' : int (body .get ('no_repeat_ngram_size' , 0 )),
55
+ 'num_beams' : int (body .get ('num_beams' , 1 )),
55
56
'penalty_alpha' : float (body .get ('penalty_alpha' , 0 )),
56
57
'length_penalty' : float (body .get ('length_penalty' , 1 )),
57
58
'early_stopping' : bool (body .get ('early_stopping' , False )),
58
59
'seed' : int (body .get ('seed' , - 1 )),
59
60
}
60
61
61
62
generator = generate_reply (
62
- prompt ,
63
+ prompt ,
63
64
generate_params ,
64
65
stopping_strings = body .get ('stopping_strings' , []),
65
66
)
@@ -84,9 +85,9 @@ def do_POST(self):
84
85
def run_server ():
85
86
server_addr = ('0.0.0.0' if shared .args .listen else '127.0.0.1' , params ['port' ])
86
87
server = ThreadingHTTPServer (server_addr , Handler )
87
- if shared .args .share :
88
+ if shared .args .share :
88
89
try :
89
- from flask_cloudflared import _run_cloudflared
90
+ from flask_cloudflared import _run_cloudflared
90
91
public_url = _run_cloudflared (params ['port' ], params ['port' ] + 1 )
91
92
print (f'Starting KoboldAI compatible api at { public_url } /api' )
92
93
except ImportError :
@@ -95,5 +96,6 @@ def run_server():
95
96
print (f'Starting KoboldAI compatible api at http://{ server_addr [0 ]} :{ server_addr [1 ]} /api' )
96
97
server .serve_forever ()
97
98
99
+
98
100
def setup():
    """Entry point called by the host application: start the KoboldAI-compatible
    API server on a background daemon thread so it does not block shutdown."""
    api_thread = Thread(target=run_server, daemon=True)
    api_thread.start()
0 commit comments