Skip to content

Commit ca191f3

Browse files
feat(charts): plot length distribution
1 parent 8e39654 commit ca191f3

File tree

3 files changed

+172
-80
lines changed

3 files changed

+172
-80
lines changed

.pylintrc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ disable=raw-checker-failed,
433433
missing-module-docstring,
434434
missing-class-docstring,
435435
missing-function-docstring,
436+
no-member,
436437
W0122, # Use of exec (exec-used)
437438
R0914, # Too many local variables (19/15) (too-many-locals)
438439
R0903, # Too few public methods (1/2)
@@ -450,7 +451,7 @@ disable=raw-checker-failed,
450451
E1120, # TODO: unbound-method-call-no-value-for-parameter
451452
R0917, # Too many positional arguments (6/5) (too-many-positional-arguments)
452453
C0103,
453-
E0401
454+
E0401,
454455

455456
# Enable the message, report, category or checker with the given id(s). You can
456457
# either give multiple identifier separated by comma (,) or put this option

charts/plot_rephrase_process.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import re
2-
2+
import plotly.express as px
3+
from collections import defaultdict
4+
import plotly.graph_objects as go
35
import pandas as pd
46
from tqdm import tqdm
7+
58
from models import Tokenizer
69
from utils.log import parse_log
7-
import plotly.express as px
8-
import plotly.graph_objects as go
9-
from collections import defaultdict
10+
1011

1112
def analyse_log(log_info: dict) -> list:
1213
"""
@@ -80,7 +81,6 @@ def plot_pre_length_distribution(stats: list[dict]):
8081
:return fig
8182
"""
8283

83-
# 使用传入的stats参数而不是全局的data
8484
if not stats:
8585
return go.Figure()
8686

@@ -134,6 +134,66 @@ def plot_pre_length_distribution(stats: list[dict]):
134134

135135
return fig
136136

137+
def plot_post_synth_length_distribution(stats: list[dict], bin_size: int = 50):
    """
    Plot the distribution of post-synthesis length.

    Every entry's ``post_length`` is placed in the bin
    ``[start, start + bin_size)`` and the number of entries per bin is drawn
    as a bar chart. Bins that received no entry are not drawn.

    :param stats: list of dicts, each providing a ``post_length`` value
    :param bin_size: width of one length bin; must be positive
    :return fig
    :raises ValueError: if ``bin_size`` is not positive
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")

    if not stats:
        return go.Figure()

    # Count entries per bin in a single pass. The counts are keyed by the
    # numeric bin start so the bins can be ordered directly, without parsing
    # the "start-end" labels back into numbers.
    length_distribution = defaultdict(int)
    for item in stats:
        bin_start = (item['post_length'] // bin_size) * bin_size
        length_distribution[bin_start] += 1

    bin_starts = sorted(length_distribution)
    sorted_bins = [f"{start}-{start + bin_size}" for start in bin_starts]
    counts = [length_distribution[start] for start in bin_starts]

    # One bar per non-empty bin, annotated with its count.
    fig = go.Figure(data=[
        go.Bar(
            x=sorted_bins,
            y=counts,
            text=counts,
            textposition='auto',
        )
    ])

    fig.update_layout(
        title='Distribution of Post-Synthesis Length',
        xaxis_title='Length Range',
        yaxis_title='Count',
        bargap=0.2,
        showlegend=False
    )

    # With many bins the x-axis labels get crowded: tilt them and label only
    # every other bin.
    if len(sorted_bins) > 10:
        fig.update_layout(
            xaxis={
                'tickangle': 45,
                'tickmode': 'array',
                'ticktext': sorted_bins[::2],
                'tickvals': list(range(0, len(sorted_bins), 2))
            }
        )

    return fig
196+
137197
if __name__ == "__main__":
138198
log = parse_log('/home/PJLAB/chenzihong/Project/graphgen/cache/logs/graphgen.log')
139199
data = analyse_log(log)

simulate.py

Lines changed: 105 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""Simulate text length distributions using input data distributions when rephrasing."""
22

3+
import copy
4+
import os
5+
import json
36
import gradio as gr
47

58
from models import TraverseStrategy, NetworkXStorage
6-
from charts.plot_rephrase_process import plot_pre_length_distribution
9+
from charts.plot_rephrase_process import plot_pre_length_distribution, plot_post_synth_length_distribution
710
from graphgen.operators.split_graph import get_batches_with_strategy
811
from utils import create_event_loop
9-
import copy
12+
from models import Tokenizer
1013

1114
if __name__ == "__main__":
1215
networkx_storage = NetworkXStorage(
@@ -32,22 +35,22 @@ async def get_batches(traverse_strategy: TraverseStrategy):
3235
return await get_batches_with_strategy(nodes, edges, networkx_storage, traverse_strategy)
3336

3437
def traverse_graph(
35-
bidirectional: bool,
36-
expand_method: str,
37-
max_extra_edges: int,
38-
max_tokens: int,
39-
max_depth: int,
40-
edge_sampling: str,
41-
isolated_node_strategy: str
38+
ts_bidirectional: bool,
39+
ts_expand_method: str,
40+
ts_max_extra_edges: int,
41+
ts_max_tokens: int,
42+
ts_max_depth: int,
43+
ts_edge_sampling: str,
44+
ts_isolated_node_strategy: str
4245
) -> str:
4346
traverse_strategy = TraverseStrategy(
44-
bidirectional=bidirectional,
45-
expand_method=expand_method,
46-
max_extra_edges=max_extra_edges,
47-
max_tokens=max_tokens,
48-
max_depth=max_depth,
49-
edge_sampling=edge_sampling,
50-
isolated_node_strategy=isolated_node_strategy
47+
bidirectional=ts_bidirectional,
48+
expand_method=ts_expand_method,
49+
max_extra_edges=ts_max_extra_edges,
50+
max_tokens=ts_max_tokens,
51+
max_depth=ts_max_depth,
52+
edge_sampling=ts_edge_sampling,
53+
isolated_node_strategy=ts_isolated_node_strategy
5154
)
5255

5356
loop = create_event_loop()
@@ -56,8 +59,8 @@ def traverse_graph(
5659

5760
data = []
5861
for _process_batch in batches:
59-
pre_length = sum([node['length'] for node in _process_batch[0]]) + sum(
60-
[edge[2]['length'] for edge in _process_batch[1]])
62+
pre_length = sum(node['length'] for node in _process_batch[0]) + sum(
63+
edge[2]['length'] for edge in _process_batch[1])
6164
data.append({
6265
'pre_length': pre_length
6366
})
@@ -66,60 +69,88 @@ def traverse_graph(
6669
return fig
6770

6871

69-
def update_sliders(expand_method):
70-
if expand_method == "max_tokens":
72+
def update_sliders(method_name):
73+
if method_name == "max_tokens":
7174
return gr.update(visible=True), gr.update(visible=False) # Show max_tokens, hide max_extra_edges
72-
else:
73-
return gr.update(visible=False), gr.update(visible=True) # Hide max_tokens, show max_extra_edges
74-
75-
76-
with gr.Blocks() as iface:
77-
gr.Markdown("# Graph Traversal Interface")
78-
79-
with gr.Row():
80-
with gr.Column():
81-
bidirectional = gr.Checkbox(label="Bidirectional", value=False)
82-
expand_method = gr.Dropdown(
83-
choices=["max_width", "max_tokens"],
84-
value="max_tokens",
85-
label="Expand Method",
86-
interactive=True
87-
)
88-
89-
# Initialize sliders
90-
max_extra_edges = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Max Extra Edges",
91-
visible=False)
92-
max_tokens = gr.Slider(minimum=128, maximum=8 * 1024, value=1024, step=128, label="Max Tokens")
93-
max_depth = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Depth")
94-
edge_sampling = gr.Dropdown(
95-
choices=["max_loss", "random", "min_loss"],
96-
value="max_loss",
97-
label="Edge Sampling Strategy"
98-
)
99-
isolated_node_strategy = gr.Dropdown(
100-
choices=["add", "ignore", "connect"],
101-
value="add",
102-
label="Isolated Node Strategy"
103-
)
104-
submit_btn = gr.Button("Traverse Graph")
105-
106-
with gr.Row():
107-
output_plot = gr.Plot(label="Graph Visualization")
108-
109-
# Set up event listener for expand_method dropdown
110-
expand_method.change(fn=update_sliders, inputs=expand_method, outputs=[max_tokens, max_extra_edges])
111-
112-
submit_btn.click(
113-
fn=traverse_graph,
114-
inputs=[
115-
bidirectional,
116-
expand_method,
117-
max_extra_edges,
118-
max_tokens,
119-
max_depth,
120-
edge_sampling,
121-
isolated_node_strategy
122-
],
123-
outputs=[output_plot]
124-
)
125-
iface.launch()
75+
return gr.update(visible=False), gr.update(visible=True) # Hide max_tokens, show max_extra_edges
76+
77+
78+
with gr.Blocks() as app:
79+
with gr.Tab("Before Traversal"):
80+
with gr.Row():
81+
with gr.Column():
82+
bidirectional = gr.Checkbox(label="Bidirectional", value=False)
83+
expand_method = gr.Dropdown(
84+
choices=["max_width", "max_tokens"],
85+
value="max_tokens",
86+
label="Expand Method",
87+
interactive=True
88+
)
89+
90+
# Initialize sliders
91+
max_extra_edges = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Max Extra Edges",
92+
visible=False)
93+
max_tokens = gr.Slider(minimum=128, maximum=8 * 1024, value=1024, step=128, label="Max Tokens")
94+
max_depth = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Max Depth")
95+
edge_sampling = gr.Dropdown(
96+
choices=["max_loss", "random", "min_loss"],
97+
value="max_loss",
98+
label="Edge Sampling Strategy"
99+
)
100+
isolated_node_strategy = gr.Dropdown(
101+
choices=["add", "ignore", "connect"],
102+
value="add",
103+
label="Isolated Node Strategy"
104+
)
105+
submit_btn = gr.Button("Traverse Graph")
106+
107+
with gr.Row():
108+
output_plot = gr.Plot(label="Graph Visualization")
109+
110+
# Set up event listener for expand_method dropdown
111+
expand_method.change(fn=update_sliders, inputs=expand_method, outputs=[max_tokens, max_extra_edges])
112+
113+
submit_btn.click(
114+
fn=traverse_graph,
115+
inputs=[
116+
bidirectional,
117+
expand_method,
118+
max_extra_edges,
119+
max_tokens,
120+
max_depth,
121+
edge_sampling,
122+
isolated_node_strategy
123+
],
124+
outputs=[output_plot]
125+
)
126+
127+
with gr.Tab("After Synthesis"):
128+
with gr.Row():
129+
with gr.Column():
130+
file_list = os.listdir("cache/data/graphgen")
131+
input_file = gr.Dropdown(choices=file_list, label="Input File")
132+
file_button = gr.Button("Submit File")
133+
134+
with gr.Row():
135+
output_plot = gr.Plot(label="Graph Visualization")
136+
137+
def synthesize_text(file):
    """
    Build the post-synthesis length distribution figure for one result file.

    Loads ``cache/data/graphgen/<file>`` as JSON, measures the length of
    ``tokenizer.encode_string(answer)`` for the ``answer`` field of every
    entry and plots the distribution of those lengths.

    :param file: file name chosen in the dropdown; may be empty/None when
        nothing has been selected yet
    :return fig
    """
    # Nothing selected yet: return the empty figure instead of trying to
    # open "cache/data/graphgen/None".
    if not file:
        return plot_post_synth_length_distribution([])

    tokenizer = Tokenizer()
    with open(f"cache/data/graphgen/{file}", "r", encoding='utf-8') as f:
        data = json.load(f)

    # Only the entries themselves are needed (not their keys), and the
    # loaded entries are left unmodified.
    stats = [
        {'post_length': len(tokenizer.encode_string(item['answer']))}
        for item in data.values()
    ]
    return plot_post_synth_length_distribution(stats)
150+
file_button.click(
151+
fn=synthesize_text,
152+
inputs=[input_file],
153+
outputs=[output_plot]
154+
)
155+
156+
app.launch()

0 commit comments

Comments
 (0)