Commit 6709c29

fix: vectorize in batches.

1 parent 5525fc8 · commit 6709c29

File tree

5 files changed: +199 −116 lines changed


pypi/data-processing/src/data_store_process/minio_store_process.py
Lines changed: 14 additions & 3 deletions

@@ -407,7 +407,7 @@ async def text_manipulate(
         return {"status": 400, "message": str(ex), "data": traceback.format_exc()}


-def text_manipulate_retry(req_json, pool):
+async def text_manipulate_retry(req_json, pool):
     task_id = req_json.get("id")
     creator = req_json.get("creator")
     log_id = ulid.ulid()
@@ -470,7 +470,7 @@ def text_manipulate_retry(req_json, pool):
                 ]
             )
         )
-        result = _text_manipulate_retry_for_document(
+        result = await _text_manipulate_retry_for_document(
            document=document,
            task_info=task_info_dict,
            log_id=log_id,
@@ -937,7 +937,7 @@ def _insert_log_info(id, task_id, execute_type, creator, pool):
         return {"status": 400, "message": str(ex), "data": traceback.format_exc()}


-def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
+async def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
     file_name = document.get("file_name")
     task_id = task_info.get("id")
     document_id = document.get("id")
@@ -1025,6 +1025,16 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
             task_id=task_id,
             create_user=creator,
         )
+    elif file_extension == "web":
+        # Handle .web files
+        result = await web_handle.web_manipulate(
+            file_name=file_name,
+            document_id=item.get("document_id"),
+            support_type=support_type,
+            conn_pool=pool,
+            task_id=id,
+            create_user=req_json["creator"],
+        )

     # Delete the downloaded local file
     _remove_local_file(file_name)
@@ -1042,6 +1052,7 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
         file_name=file_name,
         all_document_for_process=document_chunk_dict.get("data"),
         support_type=support_type,
+        progress=int(document.get("progress")),
         conn_pool=pool,
         create_user=creator,
     )
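
Since text_manipulate_retry and _text_manipulate_retry_for_document are now coroutines, every caller must await them or schedule them on an event loop. A minimal sketch of driving the retry path, assuming an asyncio-based caller and an already-built connection pool (the payload values are illustrative, not from this commit):

import asyncio

async def retry_task(pool):
    # Example payload; real requests carry the task id and creator.
    req_json = {"id": "task-123", "creator": "admin"}
    # Must be awaited now: calling the coroutine without await would
    # only create a coroutine object and never run the retry body.
    return await text_manipulate_retry(req_json, pool)

# From synchronous code:
# asyncio.run(retry_task(pool))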

pypi/data-processing/src/database_operate/data_process_document_db_operate.py
Lines changed: 29 additions & 0 deletions

@@ -160,6 +160,35 @@ def update_document_progress(req_json, pool):
     return res


+def update_document_status_and_progress(req_json, pool):
+    """Update the status and progress with id"""
+    now = date_time_utils.now_str()
+    program = "文件处理完成-修改"
+
+    params = {
+        "id": req_json["id"],
+        "status": req_json["status"],
+        "end_time": now,
+        "progress": req_json["progress"],
+        "update_datetime": now,
+        "update_program": program,
+    }
+
+    sql = """
+        update public.data_process_task_document set
+          status = %(status)s,
+          end_time = %(end_time)s,
+          progress = %(progress)s,
+          update_datetime = %(update_datetime)s,
+          update_program = %(update_program)s
+        where
+          id = %(id)s
+    """.strip()
+
+    res = postgresql_pool_client.execute_update(pool, sql, params)
+    return res
+
+
 def list_file_by_task_id(req_json, pool):
     """info with id"""
     params = {"task_id": req_json["task_id"]}
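
The new helper binds values through psycopg2-style named placeholders rather than interpolating them into the SQL string. A hedged usage sketch (the pool variable and "doc-42" are assumptions; the payload keys mirror the params dict above):

# Illustrative call only; the pool is assumed to come from
# postgresql_pool_client elsewhere in the service.
req_json = {
    "id": "doc-42",       # data_process_task_document primary key
    "status": "success",  # terminal status for the document
    "progress": 100,      # completion percentage
}
res = update_document_status_and_progress(req_json, pool)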

pypi/data-processing/src/file_handle/common_handle.py
Lines changed: 153 additions & 17 deletions

@@ -37,7 +37,12 @@


 def text_manipulate(
-    all_document_for_process, file_name, support_type, conn_pool, create_user
+    all_document_for_process,
+    file_name,
+    support_type,
+    conn_pool,
+    create_user,
+    progress=0
 ):
     """Manipulate the text content.
@@ -63,7 +68,7 @@ def text_manipulate(
             conn_pool=conn_pool,
         )

-        text_process_success_num = 0
+        text_process_success_num = progress
         for document in all_document_for_process:
             document_chunk_id = document.get("id")
             # Clean the data such as removing invisible characters.
@@ -116,11 +121,6 @@ def text_manipulate(
             if qa_response.get("status") != 200:
                 return qa_response

-            # File processed successfully; update its status in data_process_task_document
-            _updata_document_status_and_end_time(
-                id=document_id, status="success", conn_pool=conn_pool
-            )
-
         if support_type_map.get("qa_split"):
             # Check whether QA split was selected
             qa_list_dict = support_type_map.get("qa_split")
@@ -196,6 +196,13 @@ def text_manipulate(
                 file_name=file_name_csv, phase_value="final", data=qa_data_dict
             )

+            _update_document_status_and_progress(
+                id=document_id,
+                status="success",
+                progress=100,
+                conn_pool=conn_pool
+            )
+
             logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text")
             return {
                 "status": 200,
@@ -225,13 +232,25 @@ def text_manipulate(
                 file_name=file_name_csv, phase_value="final", data=chunk_data_dict
             )

+            _update_document_status_and_progress(
+                id=document_id,
+                status="success",
+                progress=100,
+                conn_pool=conn_pool
+            )
+
             logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text")
             return {
                 "status": 200,
                 "message": "",
                 "data": "",
             }

+        # File processed successfully; update its status in data_process_task_document
+        _updata_document_status_and_end_time(
+            id=document_id, status="success", conn_pool=conn_pool
+        )
+
         return {"status": 200, "message": "", "data": ""}
     except Exception as ex:
         logger.error(
@@ -914,6 +933,7 @@ def _qa_split(
 ):
     qa_list_dict = support_type_map.get("qa_split")
     llm_config = qa_list_dict.get("llm_config")
+    remove_duplicate_config = qa_list_dict.get("remove_duplicate_config")

     # Update the chunk status to "start"
     _update_document_chunk_status_and_start_time(
@@ -937,6 +957,7 @@ def _qa_split(
             id=document_id, status="fail", conn_pool=conn_pool
         )
     else:
+        qa_list = []
         # Store the QA data in the table
        qa_data = qa_response.get("data")
        for _, item in enumerate(qa_data):
@@ -955,6 +976,34 @@ def _qa_split(
                 qa_insert_item, pool=conn_pool
             )

+            qa_list.append(qa_insert_item)
+
+        # Check whether deduplication is needed
+        if remove_duplicate_config:
+            for qa in qa_list:
+                embedding_response = _embedding_qa(
+                    qa_list=[qa],
+                    remove_duplicate_config=remove_duplicate_config,
+                    conn_pool=conn_pool
+                )
+
+                if embedding_response.get("status") != 200:
+                    # Processing failed:
+                    # update the status in data_process_task_document_chunk
+                    _updata_document_chunk_status_and_end_time(
+                        id=document_chunk_id,
+                        update_user=create_user,
+                        status="fail",
+                        conn_pool=conn_pool,
+                    )
+
+                    # Update the file status in data_process_task_document
+                    _updata_document_status_and_end_time(
+                        id=document_id, status="fail", conn_pool=conn_pool
+                    )
+
+                    return embedding_response
+
         # Update the status in data_process_task_document_chunk
         _updata_document_chunk_status_and_end_time(
             id=document_chunk_id,
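
This loop is the batching fix named in the commit title: _embedding_qa now receives a single-pair list (qa_list=[qa]), so each vectorization request carries one QA pair instead of the document's entire QA set, and any failure aborts early with per-chunk and per-document status updates. A generic sketch of the same bounded-batch pattern, assuming only an embed function that follows the status-dict convention above (the helper name and batch size are illustrative):

def embed_in_batches(items, embed_fn, batch_size=1):
    """Embed items in fixed-size batches; stop on the first failure.

    embed_fn is assumed to return {"status": int, ...} like _embedding_qa.
    """
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        response = embed_fn(batch)
        if response.get("status") != 200:
            # Surface the failing batch so the caller can mark it failed.
            return response
    return {"status": 200, "message": "", "data": ""}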
@@ -965,6 +1014,9 @@ def _qa_split(

         # Update the file processing progress
         progress = int(text_process_success_num / document_chunk_size * 100)
+        if text_process_success_num == document_chunk_size:
+            progress = 99
+
         _updata_document_progress(
             id=document_id,
             progress=progress,
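
Capping the per-chunk progress at 99 keeps a document from reporting completion while its output files and final status write are still pending; only _update_document_status_and_progress later moves it to 100. An illustrative trace for a hypothetical four-chunk document:

# Assumed: document_chunk_size == 4; one chunk finishes per iteration.
for done in range(1, 5):
    progress = int(done / 4 * 100)  # 25, 50, 75, 100
    if done == 4:
        progress = 99               # hold at 99 until the final write
    print(progress)                 # prints 25, 50, 75, 99
# The jump to 100 happens only in _update_document_status_and_progress.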
@@ -994,7 +1046,7 @@ def _generate_qa_list(content, llm_config):

     # Generate the QA list.
     qa_list = []
-    if llm_spec_info.get("data").get("provider").get("worker"):
+    if llm_config.get("provider") == "worker":
         # get base url for configmap
         base_url = model_cr.get_worker_base_url_k8s_configmap(
             name=config.k8s_default_config, namespace=config.k8s_pod_namespace
@@ -1190,6 +1242,26 @@ def _updata_document_progress(id, progress, update_user, conn_pool):
         return {"status": 1000, "message": str(ex), "data": traceback.format_exc()}


+def _update_document_status_and_progress(id, status, progress, conn_pool):
+    try:
+        document_update_item = {"id": id, "status": status, "progress": progress}
+        data_process_document_db_operate.update_document_status_and_progress(
+            document_update_item, pool=conn_pool
+        )
+
+        return {"status": 200, "message": "", "data": ""}
+    except Exception as ex:
+        logger.error(
+            "".join(
+                [
+                    f"{log_tag_const.COMMON_HANDLE} update document status ",
+                    f"\n{traceback.format_exc()}",
+                ]
+            )
+        )
+        return {"status": 1000, "message": str(ex), "data": traceback.format_exc()}
+
+
 def _update_document_chunk_status_and_start_time(id, update_user, conn_pool):
     try:
         now = date_time_utils.now_str()
@@ -1292,8 +1364,8 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
     provider = remove_duplicate_config.get("embedding_provider")
     similarity = float(remove_duplicate_config.get("similarity"))

-    # Model info from the llms CR
-    llm_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)
+    # Model info from the embedding CR
+    embedding_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)

     if provider == "worker":
         # get base url for configmap
@@ -1319,11 +1391,11 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
         )

         remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
-        return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity)
+        return remove_duplicate_loader.remove_duplicate_qa_data(qa_list, similarity)
     else:
-        endpoint = llm_spec_info.get("data").get("provider").get("endpoint")
+        endpoint = embedding_spec_info.get("data").get("provider").get("endpoint")
         base_url = endpoint.get("url")
-        llm_type = llm_spec_info.get("data").get("type")
+        embedding_type = embedding_spec_info.get("data").get("type")

         logger.debug(
             "".join(
@@ -1332,19 +1404,83 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
                 f"name: {name}\n",
                 f"namespace: {namespace}\n",
                 f"model: {model}\n",
-                f"llm_type: {llm_type}\n",
+                f"embedding_type: {embedding_type}\n",
+            ]
+        )
+    )
+
+        if embedding_type == "openai":
+            qa_embeddings = OpenAIEmbeddings(
+                api_key="fake",
+                base_url=base_url,
+                model=model,
+            )
+
+            remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
+            return remove_duplicate_loader.remove_duplicate_qa_data(qa_list, similarity)
+        else:
+            return {"status": 1000, "message": f"暂时不支持{embedding_type}类型的向量化模型模型", "data": ""}
+
+
+def _embedding_qa(qa_list, remove_duplicate_config, conn_pool):
+    name = remove_duplicate_config.get("embedding_name")
+    namespace = remove_duplicate_config.get("embedding_namespace")
+    model = remove_duplicate_config.get("embedding_model")
+    provider = remove_duplicate_config.get("embedding_provider")
+
+    # Model info from the embeddings CR
+    embedding_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)
+
+    if provider == "worker":
+        # get base url for configmap
+        base_url = model_cr.get_worker_base_url_k8s_configmap(
+            name=config.k8s_default_config, namespace=config.k8s_pod_namespace
+        )
+        logger.debug(
+            "".join(
+                [
+                    f"worker embedding \n",
+                    f"name: {name}\n",
+                    f"namespace: {namespace}\n",
+                    f"model: {model}\n",
+                    f"base_url: {base_url}\n",
+                ]
+            )
+        )
+
+        qa_embeddings = OpenAIEmbeddings(
+            api_key="fake",
+            base_url=base_url,
+            model=model,
+        )
+
+        remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
+        return remove_duplicate_loader.embedding_qa_data(qa_list)
+    else:
+        endpoint = embedding_spec_info.get("data").get("provider").get("endpoint")
+        base_url = endpoint.get("url")
+        embedding_type = embedding_spec_info.get("data").get("type")
+
+        logger.debug(
+            "".join(
+                [
+                    f"3rd_party embedding \n",
+                    f"name: {name}\n",
+                    f"namespace: {namespace}\n",
+                    f"model: {model}\n",
+                    f"embedding_type: {embedding_type}\n",
                 ]
             )
         )

-        if llm_type == "openai":
+        if embedding_type == "openai":
             qa_embeddings = OpenAIEmbeddings(
                 api_key="fake",
                 base_url=base_url,
                 model=model,
             )

             remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
-            return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity)
+            return remove_duplicate_loader.embedding_qa_data(qa_list)
         else:
-            return {"status": 1000, "message": f"暂时不支持{llm_type}类型的向量化模型模型", "data": ""}
+            return {"status": 1000, "message": f"暂时不支持{embedding_type}类型的向量化模型模型", "data": ""}
