11"""Simulate text length distributions using input data distributions when rephrasing."""
22
3+ import copy
4+ import os
5+ import json
36import gradio as gr
47
58from models import TraverseStrategy , NetworkXStorage
6- from charts .plot_rephrase_process import plot_pre_length_distribution
9+ from charts .plot_rephrase_process import plot_pre_length_distribution , plot_post_synth_length_distribution
710from graphgen .operators .split_graph import get_batches_with_strategy
811from utils import create_event_loop
9- import copy
12+ from models import Tokenizer
1013
1114if __name__ == "__main__" :
1215 networkx_storage = NetworkXStorage (
@@ -32,22 +35,22 @@ async def get_batches(traverse_strategy: TraverseStrategy):
3235 return await get_batches_with_strategy (nodes , edges , networkx_storage , traverse_strategy )
3336
3437 def traverse_graph (
35- bidirectional : bool ,
36- expand_method : str ,
37- max_extra_edges : int ,
38- max_tokens : int ,
39- max_depth : int ,
40- edge_sampling : str ,
41- isolated_node_strategy : str
38+ ts_bidirectional : bool ,
39+ ts_expand_method : str ,
40+ ts_max_extra_edges : int ,
41+ ts_max_tokens : int ,
42+ ts_max_depth : int ,
43+ ts_edge_sampling : str ,
44+ ts_isolated_node_strategy : str
4245 ) -> str :
4346 traverse_strategy = TraverseStrategy (
44- bidirectional = bidirectional ,
45- expand_method = expand_method ,
46- max_extra_edges = max_extra_edges ,
47- max_tokens = max_tokens ,
48- max_depth = max_depth ,
49- edge_sampling = edge_sampling ,
50- isolated_node_strategy = isolated_node_strategy
47+ bidirectional = ts_bidirectional ,
48+ expand_method = ts_expand_method ,
49+ max_extra_edges = ts_max_extra_edges ,
50+ max_tokens = ts_max_tokens ,
51+ max_depth = ts_max_depth ,
52+ edge_sampling = ts_edge_sampling ,
53+ isolated_node_strategy = ts_isolated_node_strategy
5154 )
5255
5356 loop = create_event_loop ()
@@ -56,8 +59,8 @@ def traverse_graph(
5659
5760 data = []
5861 for _process_batch in batches :
59- pre_length = sum ([ node ['length' ] for node in _process_batch [0 ] ]) + sum (
60- [ edge [2 ]['length' ] for edge in _process_batch [1 ] ])
62+ pre_length = sum (node ['length' ] for node in _process_batch [0 ]) + sum (
63+ edge [2 ]['length' ] for edge in _process_batch [1 ])
6164 data .append ({
6265 'pre_length' : pre_length
6366 })
@@ -66,60 +69,88 @@ def traverse_graph(
6669 return fig
6770
6871
69- def update_sliders (expand_method ):
70- if expand_method == "max_tokens" :
72+ def update_sliders (method_name ):
73+ if method_name == "max_tokens" :
7174 return gr .update (visible = True ), gr .update (visible = False ) # Show max_tokens, hide max_extra_edges
72- else :
73- return gr .update (visible = False ), gr .update (visible = True ) # Hide max_tokens, show max_extra_edges
74-
75-
76- with gr .Blocks () as iface :
77- gr .Markdown ("# Graph Traversal Interface" )
78-
79- with gr .Row ():
80- with gr .Column ():
81- bidirectional = gr .Checkbox (label = "Bidirectional" , value = False )
82- expand_method = gr .Dropdown (
83- choices = ["max_width" , "max_tokens" ],
84- value = "max_tokens" ,
85- label = "Expand Method" ,
86- interactive = True
87- )
88-
89- # Initialize sliders
90- max_extra_edges = gr .Slider (minimum = 1 , maximum = 50 , value = 5 , step = 1 , label = "Max Extra Edges" ,
91- visible = False )
92- max_tokens = gr .Slider (minimum = 128 , maximum = 8 * 1024 , value = 1024 , step = 128 , label = "Max Tokens" )
93- max_depth = gr .Slider (minimum = 1 , maximum = 10 , value = 3 , step = 1 , label = "Max Depth" )
94- edge_sampling = gr .Dropdown (
95- choices = ["max_loss" , "random" , "min_loss" ],
96- value = "max_loss" ,
97- label = "Edge Sampling Strategy"
98- )
99- isolated_node_strategy = gr .Dropdown (
100- choices = ["add" , "ignore" , "connect" ],
101- value = "add" ,
102- label = "Isolated Node Strategy"
103- )
104- submit_btn = gr .Button ("Traverse Graph" )
105-
106- with gr .Row ():
107- output_plot = gr .Plot (label = "Graph Visualization" )
108-
109- # Set up event listener for expand_method dropdown
110- expand_method .change (fn = update_sliders , inputs = expand_method , outputs = [max_tokens , max_extra_edges ])
111-
112- submit_btn .click (
113- fn = traverse_graph ,
114- inputs = [
115- bidirectional ,
116- expand_method ,
117- max_extra_edges ,
118- max_tokens ,
119- max_depth ,
120- edge_sampling ,
121- isolated_node_strategy
122- ],
123- outputs = [output_plot ]
124- )
125- iface .launch ()
75+ return gr .update (visible = False ), gr .update (visible = True ) # Hide max_tokens, show max_extra_edges
76+
77+
78+ with gr .Blocks () as app :
79+ with gr .Tab ("Before Traversal" ):
80+ with gr .Row ():
81+ with gr .Column ():
82+ bidirectional = gr .Checkbox (label = "Bidirectional" , value = False )
83+ expand_method = gr .Dropdown (
84+ choices = ["max_width" , "max_tokens" ],
85+ value = "max_tokens" ,
86+ label = "Expand Method" ,
87+ interactive = True
88+ )
89+
90+ # Initialize sliders
91+ max_extra_edges = gr .Slider (minimum = 1 , maximum = 50 , value = 5 , step = 1 , label = "Max Extra Edges" ,
92+ visible = False )
93+ max_tokens = gr .Slider (minimum = 128 , maximum = 8 * 1024 , value = 1024 , step = 128 , label = "Max Tokens" )
94+ max_depth = gr .Slider (minimum = 1 , maximum = 10 , value = 3 , step = 1 , label = "Max Depth" )
95+ edge_sampling = gr .Dropdown (
96+ choices = ["max_loss" , "random" , "min_loss" ],
97+ value = "max_loss" ,
98+ label = "Edge Sampling Strategy"
99+ )
100+ isolated_node_strategy = gr .Dropdown (
101+ choices = ["add" , "ignore" , "connect" ],
102+ value = "add" ,
103+ label = "Isolated Node Strategy"
104+ )
105+ submit_btn = gr .Button ("Traverse Graph" )
106+
107+ with gr .Row ():
108+ output_plot = gr .Plot (label = "Graph Visualization" )
109+
110+ # Set up event listener for expand_method dropdown
111+ expand_method .change (fn = update_sliders , inputs = expand_method , outputs = [max_tokens , max_extra_edges ])
112+
113+ submit_btn .click (
114+ fn = traverse_graph ,
115+ inputs = [
116+ bidirectional ,
117+ expand_method ,
118+ max_extra_edges ,
119+ max_tokens ,
120+ max_depth ,
121+ edge_sampling ,
122+ isolated_node_strategy
123+ ],
124+ outputs = [output_plot ]
125+ )
126+
127+ with gr .Tab ("After Synthesis" ):
128+ with gr .Row ():
129+ with gr .Column ():
130+ file_list = os .listdir ("cache/data/graphgen" )
131+ input_file = gr .Dropdown (choices = file_list , label = "Input File" )
132+ file_button = gr .Button ("Submit File" )
133+
134+ with gr .Row ():
135+ output_plot = gr .Plot (label = "Graph Visualization" )
136+
def synthesize_text(file):
    """Plot the token-length distribution of the answers in a synthesized file.

    Args:
        file: Name of a JSON file inside ``cache/data/graphgen``, as chosen in
            the "Input File" dropdown. The file is read as a mapping whose
            values each carry an ``'answer'`` entry.

    Returns:
        Whatever ``plot_post_synth_length_distribution`` returns for the
        collected ``{'post_length': ...}`` records; it is wired to a
        ``gr.Plot`` output.

    Raises:
        gr.Error: If no file has been selected yet.
        KeyError: If an entry in the file has no ``'answer'`` key.
    """
    # The dropdown has no default value, so a click before any selection
    # passes None; report that to the user instead of trying to open
    # "cache/data/graphgen/None".
    if not file:
        raise gr.Error("Please select an input file first.")

    # Keep only the base name so a crafted value cannot point outside the
    # data directory.
    path = os.path.join("cache/data/graphgen", os.path.basename(file))
    with open(path, "r", encoding='utf-8') as f:
        data = json.load(f)

    # Build the tokenizer only after the file has been loaded successfully.
    tokenizer = Tokenizer()
    # Collect the lengths directly; the loaded items are not modified.
    stats = [
        {'post_length': len(tokenizer.encode_string(item['answer']))}
        for item in data.values()
    ]
    return plot_post_synth_length_distribution(stats)
150+ file_button .click (
151+ fn = synthesize_text ,
152+ inputs = [input_file ],
153+ outputs = [output_plot ]
154+ )
155+
156+ app .launch ()
0 commit comments