Skip to content

Commit 996f63e

Browse files
committed
Add utf8 to chat example
1 parent 3ceb47b commit 996f63e

File tree

3 files changed

+130
-40
lines changed

3 files changed

+130
-40
lines changed

examples/low_level_api/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def gpt_params_parse(argv = None):
102102
parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)",dest="tfs_z")
103103
parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)",dest="presence_penalty")
104104
parser.add_argument("--mirostat", type=float, default=1.0, help="use Mirostat sampling.",dest="mirostat")
105-
parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau",dest="mirostat_tau")
105+
parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau represents the average surprise value",dest="mirostat_tau")
106106
parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
107107

108108
parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,7 @@
1717

1818
import llama_cpp
1919
from common import GptParams, gpt_params_parse, gpt_random_prompt
20-
21-
ANSI_COLOR_RESET = "\x1b[0m"
22-
ANSI_COLOR_YELLOW = "\x1b[33m"
23-
ANSI_BOLD = "\x1b[1m"
24-
ANSI_COLOR_GREEN = "\x1b[32m"
25-
26-
CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
27-
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
28-
CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
29-
30-
# Iterative search
31-
# Actively searches and prevents a pattern from being returned
32-
class IterSearch:
33-
def __init__(self, pattern):
34-
self.pattern = list(pattern)
35-
self.buffer = []
36-
37-
def __call__(self, char):
38-
self.buffer += [char]
39-
40-
if (self.pattern[:len(self.buffer)] == self.buffer):
41-
if (len(self.buffer) >= len(self.pattern)):
42-
self.buffer.clear()
43-
return []
44-
45-
_tmp = self.buffer[:]
46-
self.buffer.clear()
47-
return _tmp
20+
import util
4821

4922
# A LLaMA interactive session
5023
class LLaMAInteract:
@@ -82,6 +55,7 @@ def __init__(self, params: GptParams) -> None:
8255
self.first_antiprompt = []
8356
self.remaining_tokens = self.params.n_predict
8457
self.output_echo = self.params.input_echo
58+
self.multibyte_fix = []
8559

8660
# model load
8761
self.lparams = llama_cpp.llama_context_default_params()
@@ -188,7 +162,7 @@ def __init__(self, params: GptParams) -> None:
188162
self.params.interactive_start = True
189163
_ptn = self._tokenize(self.params.instruct_inp_prefix.strip(), False)
190164
self.first_antiprompt.append(_ptn)
191-
self.antiecho = IterSearch(_ptn)
165+
self.antiecho = util.IterSearch(_ptn)
192166

193167
# enable interactive mode if reverse prompt or interactive start is specified
194168
if (len(self.params.antiprompt) != 0 or self.params.interactive_start):
@@ -256,14 +230,14 @@ def __init__(self, params: GptParams) -> None:
256230
- If you want to submit another line, end your input in '\\'.
257231
258232
""", file=sys.stderr)
259-
self.set_color(CONSOLE_COLOR_PROMPT)
233+
self.set_color(util.CONSOLE_COLOR_PROMPT)
260234

261235
self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
262236

263237

264238
# tokenize a prompt
265239
def _tokenize(self, prompt, bos=True):
266-
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
240+
_arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
267241
_n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
268242
return _arr[:_n]
269243

@@ -295,7 +269,6 @@ def generate(self):
295269
self.params.path_session = ""
296270

297271
# try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
298-
# REVIEW
299272
if self.n_session_consumed < len(self.session_tokens):
300273
for i in range(len(self.embd)):
301274
if self.embd[i] != self.session_tokens[self.n_session_consumed]:
@@ -445,7 +418,7 @@ def generate(self):
445418

446419
# reset color to default if we there is no pending user input
447420
if (self.params.input_echo and len(self.embd_inp) == self.input_consumed):
448-
self.set_color(CONSOLE_COLOR_DEFAULT)
421+
self.set_color(util.CONSOLE_COLOR_DEFAULT)
449422

450423
if (self.params.interactive and len(self.embd_inp) <= self.input_consumed):
451424
# if antiprompt is present, stop
@@ -486,12 +459,12 @@ def __exit__(self, type, value, tb):
486459

487460
def exit(self):
488461
llama_cpp.llama_free(self.ctx)
489-
self.set_color(CONSOLE_COLOR_DEFAULT)
462+
self.set_color(util.CONSOLE_COLOR_DEFAULT)
490463

491464
# return past text
492465
def past(self):
493466
for id in self.last_n_tokens[-self.n_past:]:
494-
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")
467+
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
495468

496469
# write input
497470
def input(self, prompt: str):
@@ -505,7 +478,29 @@ def input(self, prompt: str):
505478
def output(self):
506479
self.remaining_tokens = self.params.n_predict
507480
for id in self.generate():
508-
yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
481+
cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
482+
483+
# Add remainder of missing bytes
484+
if None in self.multibyte_fix:
485+
self.multibyte_fix[self.multibyte_fix.index(None)] = cur_char
486+
487+
# Return completed utf char
488+
if len(self.multibyte_fix) > 0 and not None in self.multibyte_fix:
489+
yield (b"".join(self.multibyte_fix)).decode("utf8")
490+
self.multibyte_fix = []
491+
continue
492+
493+
# Contains multi-byte UTF8
494+
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
495+
# Bitwise AND check
496+
if pattern & int.from_bytes(cur_char) == pattern:
497+
self.multibyte_fix = [cur_char] + ([None] * (num-1))
498+
499+
# Stop incomplete bytes from passing
500+
if len(self.multibyte_fix) > 0:
501+
continue
502+
503+
yield cur_char.decode("utf8")
509504

510505
# read user input
511506
def read_input(self):
@@ -521,21 +516,21 @@ def interact(self):
521516
self.params.input_echo = False
522517

523518
while self.params.interactive:
524-
self.set_color(CONSOLE_COLOR_USER_INPUT)
519+
self.set_color(util.CONSOLE_COLOR_USER_INPUT)
525520
if (self.params.instruct):
526521
print('\n> ', end="")
527522
self.input(self.read_input())
528523
else:
529524
print(self.params.input_prefix, end="")
530525
self.input(f"{self.params.input_prefix}{self.read_input()}{self.params.input_suffix}")
531526
print(self.params.input_suffix,end="")
532-
self.set_color(CONSOLE_COLOR_DEFAULT)
527+
self.set_color(util.CONSOLE_COLOR_DEFAULT)
533528

534529
try:
535530
for i in self.output():
536531
print(i,end="",flush=True)
537532
except KeyboardInterrupt:
538-
self.set_color(CONSOLE_COLOR_DEFAULT)
533+
self.set_color(util.CONSOLE_COLOR_DEFAULT)
539534
if not self.params.instruct:
540535
print(self.params.fix_prefix,end="")
541536
self.input(self.params.fix_prefix)

examples/low_level_api/util.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
2+
ANSI_COLOR_RESET = "\x1b[0m"
3+
ANSI_COLOR_YELLOW = "\x1b[33m"
4+
ANSI_BOLD = "\x1b[1m"
5+
ANSI_COLOR_GREEN = "\x1b[32m"
6+
7+
CONSOLE_COLOR_DEFAULT = ANSI_COLOR_RESET
8+
CONSOLE_COLOR_PROMPT = ANSI_COLOR_YELLOW
9+
CONSOLE_COLOR_USER_INPUT = ANSI_BOLD + ANSI_COLOR_GREEN
10+
11+
# Iterative search
12+
# Actively searches and prevents a pattern from being returned
13+
class IterSearch:
14+
def __init__(self, pattern):
15+
self.pattern = list(pattern)
16+
self.buffer = []
17+
18+
def __call__(self, char):
19+
self.buffer += [char]
20+
21+
if (self.pattern[:len(self.buffer)] == self.buffer):
22+
if (len(self.buffer) >= len(self.pattern)):
23+
self.buffer.clear()
24+
return []
25+
26+
_tmp = self.buffer[:]
27+
self.buffer.clear()
28+
return _tmp
29+
30+
class Circle:
31+
def __init__(self, size, default=0):
32+
self.list = [default] * size
33+
self.maxsize = size
34+
self.size = 0
35+
self.offset = 0
36+
37+
def append(self, elem):
38+
if self.size < self.maxsize:
39+
self.list[self.size] = elem
40+
self.size += 1
41+
else:
42+
self.list[self.offset] = elem
43+
self.offset = (self.offset + 1) % self.maxsize
44+
45+
def __getitem__(self, val):
46+
if isinstance(val, int):
47+
if 0 > val or val >= self.size:
48+
raise IndexError('Index out of range')
49+
return self.list[val] if self.size < self.maxsize else self.list[(self.offset + val) % self.maxsize]
50+
elif isinstance(val, slice):
51+
start, stop, step = val.start, val.stop, val.step
52+
if step is None:
53+
step = 1
54+
if start is None:
55+
start = 0
56+
if stop is None:
57+
stop = self.size
58+
if start < 0:
59+
start = self.size + start
60+
if stop < 0:
61+
stop = self.size + stop
62+
63+
indices = range(start, stop, step)
64+
return [self.list[(self.offset + i) % self.maxsize] for i in indices if i < self.size]
65+
else:
66+
raise TypeError('Invalid argument type')
67+
68+
69+
70+
71+
if __name__ == "__main__":
72+
c = Circle(5)
73+
74+
c.append(1)
75+
print(c.list)
76+
print(c[:])
77+
assert c[0] == 1
78+
assert c[:5] == [1]
79+
80+
for i in range(2,5+1):
81+
c.append(i)
82+
print(c.list)
83+
print(c[:])
84+
assert c[0] == 1
85+
assert c[:5] == [1,2,3,4,5]
86+
87+
for i in range(5+1,9+1):
88+
c.append(i)
89+
print(c.list)
90+
print(c[:])
91+
assert c[0] == 5
92+
assert c[:5] == [5,6,7,8,9]
93+
#assert c[:-5] == [5,6,7,8,9]
94+
assert c[:10] == [5,6,7,8,9]
95+

0 commit comments

Comments
 (0)