From eeb17bdbccd0154c0feca5224d124257bdcf5ba5 Mon Sep 17 00:00:00 2001 From: kyoto7250 <50972773+kyoto7250@users.noreply.github.com> Date: Sun, 16 Jun 2024 15:57:13 +0900 Subject: [PATCH] merge my PRs --- Cargo.lock | 13 ++ Cargo.toml | 1 + README.md | 4 +- examples/config.toml | 26 +++ examples/key_bind.ron | 44 +++++ src/app.rs | 87 +++++++- src/components/command.rs | 27 +++ src/components/completion.rs | 7 +- src/components/properties.rs | 4 + src/components/record_table.rs | 6 +- src/components/sql_editor.rs | 2 +- src/components/table.rs | 286 ++++++++++++++++++++++++++- src/config.rs | 351 +++++++++++++++++++++++++++++---- src/database/mod.rs | 1 + src/database/mysql.rs | 30 ++- src/database/postgres.rs | 58 +++++- src/database/sqlite.rs | 24 ++- src/event/events.rs | 1 + src/key_bind.rs | 179 +++++++++++++++++ src/main.rs | 1 + 20 files changed, 1088 insertions(+), 64 deletions(-) create mode 100644 examples/config.toml create mode 100644 examples/key_bind.ron create mode 100644 src/key_bind.rs diff --git a/Cargo.lock b/Cargo.lock index b275a3f..3ab4c74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1731,6 +1731,18 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ron" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" +dependencies = [ + "base64", + "bitflags 2.5.0", + "serde", + "serde_derive", +] + [[package]] name = "rsa" version = "0.9.6" @@ -3193,6 +3205,7 @@ dependencies = [ "itertools 0.13.0", "pretty_assertions", "ratatui", + "ron", "rust_decimal", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index e002639..6ebae82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ clap = "4.5.7" structopt = "0.3.26" syntect = { version = "5.0", default-features = false, features = ["metadata", "default-fancy"]} unicode-segmentation = "1.11.0" +ron = "0.8.1" [target.'cfg(all(target_family="unix",not(target_os="macos")))'.dependencies] which = "6.0.1" diff --git a/README.md b/README.md index 938573a..e3d5276 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# dolce +# zhobo -`dolce` is the rebaked [gobang project](https://github.com/TaKO8Ki/gobang). +`zhobo` is the rebaked [gobang project](https://github.com/TaKO8Ki/gobang). ## Features - Cross-platform support (macOS, Windows, Linux) diff --git a/examples/config.toml b/examples/config.toml new file mode 100644 index 0000000..f2a4505 --- /dev/null +++ b/examples/config.toml @@ -0,0 +1,26 @@ +[[conn]] +type = "mysql" +user = "root" +host = "localhost" +port = 3306 + +[[conn]] +type = "mysql" +user = "root" +host = "localhost" +port = 3306 +password = "password" +database = "foo" +name = "mysql Foo DB" + +[[conn]] +type = "postgres" +user = "root" +host = "localhost" +port = 5432 +database = "bar" +name = "postgres Bar DB" + +[[conn]] +type = "sqlite" +path = "/path/to/baz.db" \ No newline at end of file diff --git a/examples/key_bind.ron b/examples/key_bind.ron new file mode 100644 index 0000000..407acfc --- /dev/null +++ b/examples/key_bind.ron @@ -0,0 +1,44 @@ +/* + * This file is a custom key configuration file. + * Place this file in `$HOME/.config/zhobo/key_bind.ron`. +*/ +( + scroll_up: Some(Char('k')), + scroll_down: Some(Char('j')), + scroll_right: Some(Char('l')), + scroll_left: Some(Char('h')), + sort_by_column: Some(Char('s')), + move_up: Some(Up), + move_down: Some(Down), + copy: Some(Char('y')), + enter: Some(Enter), + exit: Some(Ctrl('c')), + quit: Some(Char('q')), + exit_popup: Some(Esc), + focus_right: Some(Right), + focus_left: Some(Left), + focus_above: Some(Up), + focus_connections: Some(Char('c')), + open_help: Some(Char('?')), + filter: Some(Char('/')), + scroll_down_multiple_lines: Some(Ctrl('d')), + scroll_up_multiple_lines: Some(Ctrl('u')), + scroll_to_top: Some(Char('g')), + scroll_to_bottom: Some(Char('G')), + move_to_head_of_line: Some(Char('^')), + move_to_tail_of_line: Some(Char('$')), + extend_selection_by_one_cell_left: Some(Char('H')), + extend_selection_by_one_cell_right: Some(Char('L')), + extend_selection_by_one_cell_down: Some(Char('J')), + extend_selection_by_horizontal_line: Some(Char('V')), + extend_selection_by_one_cell_up: Some(Char('K')), + tab_records: Some(Char('1')), + tab_properties: Some(Char('2')), + tab_sql_editor: Some(Char('3')), + tab_columns: Some(Char('4')), + tab_constraints: Some(Char('5')), + tab_foreign_keys: Some(Char('6')), + tab_indexes: Some(Char('7')), + extend_or_shorten_widget_width_to_right: Some(Char('>')), + extend_or_shorten_widget_width_to_left: Some(Char('<')), +) diff --git a/src/app.rs b/src/app.rs index 6e5684f..e077ba4 100644 --- a/src/app.rs +++ b/src/app.rs @@ -12,6 +12,7 @@ use crate::components::{ use crate::config::Config; use crate::database::{MySqlPool, Pool, PostgresPool, SqlitePool, RECORDS_LIMIT_PER_PAGE}; use crate::event::Key; +use ratatui::layout::Flex; use ratatui::{ layout::{Constraint, Direction, Layout, Rect}, Frame, @@ -82,6 +83,7 @@ impl App { let right_chunks = Layout::default() .direction(Direction::Vertical) + .flex(Flex::Legacy) .constraints([Constraint::Length(3), Constraint::Length(5)].as_ref()) .split(main_chunks[1]); @@ -139,6 +141,7 @@ impl App { if let Some(pool) = self.pool.as_ref() { pool.close().await; } + self.pool = if conn.is_mysql() { Some(Box::new( MySqlPool::new(conn.database_url()?.as_str()).await?, @@ -162,7 +165,12 @@ impl App { Ok(()) } - async fn update_record_table(&mut self) -> anyhow::Result<()> { + async fn update_record_table( + &mut self, + orders: Option, + header_icons: Option>, + hold_cursor_position: bool, + ) -> anyhow::Result<()> { if let Some((database, table)) = self.databases.tree().selected_table() { let (headers, records) = self .pool @@ -177,10 +185,16 @@ impl App { } else { Some(self.record_table.filter.input_str()) }, + orders, ) .await?; - self.record_table - .update(records, headers, database.clone(), table.clone()); + self.record_table.update( + records, + self.concat_headers(headers, header_icons), + database.clone(), + table.clone(), + hold_cursor_position, + ); } Ok(()) } @@ -230,10 +244,15 @@ impl App { .pool .as_ref() .unwrap() - .get_records(&database, &table, 0, None) + .get_records(&database, &table, 0, None, None) .await?; - self.record_table - .update(records, headers, database.clone(), table.clone()); + self.record_table.update( + records, + headers, + database.clone(), + table.clone(), + false, + ); self.properties .update(database.clone(), table.clone(), self.pool.as_ref().unwrap()) .await?; @@ -249,6 +268,17 @@ impl App { return Ok(EventState::Consumed); }; + if key == self.config.key_config.sort_by_column + && !self.record_table.table.headers.is_empty() + { + self.record_table.table.add_order(); + let order_query = self.record_table.table.generate_order_query(); + let header_icons = self.record_table.table.generate_header_icons(); + self.update_record_table(order_query, Some(header_icons), true) + .await?; + return Ok(EventState::Consumed); + }; + if key == self.config.key_config.copy { if let Some(text) = self.record_table.table.selected_cells() { copy_to_clipboard(text.as_str())? @@ -258,7 +288,10 @@ impl App { if key == self.config.key_config.enter && self.record_table.filter_focused() { self.record_table.focus = crate::components::record_table::Focus::Table; - self.update_record_table().await?; + let order_query = self.record_table.table.generate_order_query(); + let header_icons = self.record_table.table.generate_header_icons(); + self.update_record_table(order_query, Some(header_icons), false) + .await?; } if self.record_table.table.eod { @@ -283,6 +316,7 @@ impl App { } else { Some(self.record_table.filter.input_str()) }, + None, ) .await?; if !records.is_empty() { @@ -321,6 +355,24 @@ impl App { Ok(EventState::NotConsumed) } + fn concat_headers( + &self, + headers: Vec, + header_icons: Option>, + ) -> Vec { + if let Some(header_icons) = &header_icons { + let mut new_headers = vec![String::new(); headers.len()]; + for (index, header) in headers.iter().enumerate() { + new_headers[index] = format!("{} {}", header, header_icons[index]) + .trim() + .to_string(); + } + return new_headers; + } + + headers + } + fn extend_or_shorten_widget_width(&mut self, key: Key) -> anyhow::Result { if key == self @@ -408,4 +460,25 @@ mod test { ); assert_eq!(app.left_main_chunk_percentage, 15); } + + #[test] + fn test_concat_headers() { + let app = App::new(Config::default()); + let headers = vec![ + "ID".to_string(), + "NAME".to_string(), + "TIMESTAMP".to_string(), + ]; + let header_icons = vec!["".to_string(), "↑1".to_string(), "↓2".to_string()]; + let concat_headers: Vec = app.concat_headers(headers, Some(header_icons)); + + assert_eq!( + concat_headers, + vec![ + "ID".to_string(), + "NAME ↑1".to_string(), + "TIMESTAMP ↓2".to_string() + ] + ) + } } diff --git a/src/components/command.rs b/src/components/command.rs index e569309..0fcd2d5 100644 --- a/src/components/command.rs +++ b/src/components/command.rs @@ -62,6 +62,16 @@ pub fn scroll_to_top_bottom(key: &KeyConfig) -> CommandText { ) } +pub fn move_to_head_tail_of_line(key: &KeyConfig) -> CommandText { + CommandText::new( + format!( + "Move to head/tail of line [{},{}]", + key.move_to_head_of_line, key.move_to_tail_of_line, + ), + CMD_GROUP_TABLE, + ) +} + pub fn expand_collapse(key: &KeyConfig) -> CommandText { CommandText::new( format!("Expand/Collapse [{},{}]", key.scroll_right, key.scroll_left,), @@ -83,6 +93,13 @@ pub fn move_focus(key: &KeyConfig) -> CommandText { ) } +pub fn sort_by_column(key: &KeyConfig) -> CommandText { + CommandText::new( + format!("Sort by column [{}]", key.sort_by_column), + CMD_GROUP_TABLE, + ) +} + pub fn extend_selection_by_one_cell(key: &KeyConfig) -> CommandText { CommandText::new( format!( @@ -96,6 +113,16 @@ pub fn extend_selection_by_one_cell(key: &KeyConfig) -> CommandText { ) } +pub fn extend_selection_by_line(key: &KeyConfig) -> CommandText { + CommandText::new( + format!( + "Extend selection by horizontal line [{}]", + key.extend_selection_by_horizontal_line, + ), + CMD_GROUP_TABLE, + ) +} + pub fn extend_or_shorten_widget_width(key: &KeyConfig) -> CommandText { CommandText::new( format!( diff --git a/src/components/completion.rs b/src/components/completion.rs index c65ba1a..eaf2b6f 100644 --- a/src/components/completion.rs +++ b/src/components/completion.rs @@ -48,7 +48,7 @@ impl CompletionComponent { fn next(&mut self) { let i = match self.state.selected() { Some(i) => { - if i >= self.filterd_candidates().count() - 1 { + if i + 1 >= self.filterd_candidates().count() { 0 } else { i + 1 @@ -63,7 +63,10 @@ impl CompletionComponent { let i = match self.state.selected() { Some(i) => { if i == 0 { - self.filterd_candidates().count() - 1 + self.filterd_candidates() + .count() + .checked_sub(1) + .unwrap_or(0) } else { i - 1 } diff --git a/src/components/properties.rs b/src/components/properties.rs index cdf7e8a..b67907a 100644 --- a/src/components/properties.rs +++ b/src/components/properties.rs @@ -76,6 +76,7 @@ impl PropertiesComponent { columns.first().unwrap().fields(), database.clone(), table.clone(), + false, ); } self.constraint_table.reset(); @@ -89,6 +90,7 @@ impl PropertiesComponent { constraints.first().unwrap().fields(), database.clone(), table.clone(), + false, ); } self.foreign_key_table.reset(); @@ -102,6 +104,7 @@ impl PropertiesComponent { foreign_keys.first().unwrap().fields(), database.clone(), table.clone(), + false, ); } self.index_table.reset(); @@ -115,6 +118,7 @@ impl PropertiesComponent { indexes.first().unwrap().fields(), database.clone(), table.clone(), + false, ); } Ok(()) diff --git a/src/components/record_table.rs b/src/components/record_table.rs index 7e645db..ae8108a 100644 --- a/src/components/record_table.rs +++ b/src/components/record_table.rs @@ -5,6 +5,7 @@ use crate::config::KeyConfig; use crate::event::Key; use crate::tree::{Database, Table as DTable}; use anyhow::Result; +use ratatui::layout::Flex; use ratatui::{ layout::{Constraint, Direction, Layout, Rect}, Frame, @@ -38,8 +39,10 @@ impl RecordTableComponent { headers: Vec, database: Database, table: DTable, + hold_cursor_position: bool, ) { - self.table.update(rows, headers, database, table.clone()); + self.table + .update(rows, headers, database, table.clone(), hold_cursor_position); self.filter.table = Some(table); } @@ -58,6 +61,7 @@ impl StatefulDrawableComponent for RecordTableComponent { let layout = Layout::default() .direction(Direction::Vertical) .constraints(vec![Constraint::Length(3), Constraint::Length(5)]) + .flex(Flex::Legacy) .split(area); self.table diff --git a/src/components/sql_editor.rs b/src/components/sql_editor.rs index 92a56db..f2205ea 100644 --- a/src/components/sql_editor.rs +++ b/src/components/sql_editor.rs @@ -268,7 +268,7 @@ impl Component for SqlEditorComponent { database, table, } => { - self.table.update(rows, headers, database, table); + self.table.update(rows, headers, database, table, false); self.focus = Focus::Table; self.query_result = None; } diff --git a/src/components/table.rs b/src/components/table.rs index fbefb2e..2b3775c 100644 --- a/src/components/table.rs +++ b/src/components/table.rs @@ -7,6 +7,7 @@ use crate::config::KeyConfig; use crate::event::Key; use crate::tree::{Database, Table as DTable}; use anyhow::Result; +use ratatui::layout::Flex; use ratatui::{ layout::{Constraint, Direction, Layout, Rect}, style::{Color, Modifier, Style}, @@ -16,11 +17,92 @@ use ratatui::{ use std::convert::From; use unicode_width::UnicodeWidthStr; +#[derive(Debug, PartialEq)] +struct Order { + pub column_number: usize, + pub is_asc: bool, +} + +impl Order { + pub fn new(column_number: usize, is_asc: bool) -> Self { + Self { + column_number, + is_asc, + } + } + + fn query(&self) -> String { + let order = if self.is_asc { "ASC" } else { "DESC" }; + + return format!( + "{column} {order}", + column = self.column_number, + order = order + ); + } +} + +#[derive(PartialEq)] +struct OrderManager { + orders: Vec, +} + +impl OrderManager { + fn new() -> Self { + Self { orders: vec![] } + } + + fn generate_order_query(&mut self) -> Option { + let order_query = self + .orders + .iter() + .map(|order| order.query()) + .collect::>(); + + if !order_query.is_empty() { + return Some("ORDER BY ".to_string() + &order_query.join(", ")); + } + + None + } + + fn generate_header_icons(&mut self, header_length: usize) -> Vec { + let mut header_icons = vec![String::new(); header_length]; + + for (index, order) in self.orders.iter().enumerate() { + let arrow = if order.is_asc { "↑" } else { "↓" }; + header_icons[order.column_number - 1] = + format!("{arrow}{number}", arrow = arrow, number = index + 1); + } + + header_icons + } + + fn add_order(&mut self, selected_column: usize) { + let selected_column_number = selected_column + 1; + if let Some(position) = self + .orders + .iter() + .position(|order| order.column_number == selected_column_number) + { + if self.orders[position].is_asc { + self.orders[position].is_asc = false; + } else { + self.orders.remove(position); + } + } else { + let order = Order::new(selected_column_number, true); + self.orders.push(order); + } + } +} + pub struct TableComponent { pub headers: Vec, pub rows: Vec>, pub eod: bool, pub selected_row: TableState, + orders: OrderManager, table: Option<(Database, DTable)>, selected_column: usize, selection_area_corner: Option<(usize, usize)>, @@ -35,6 +117,7 @@ impl TableComponent { selected_row: TableState::default(), headers: vec![], rows: vec![], + orders: OrderManager::new(), table: None, selected_column: 0, selection_area_corner: None, @@ -57,6 +140,7 @@ impl TableComponent { headers: Vec, database: Database, table: DTable, + hold_cusor_position: bool, ) { self.selected_row.select(None); if !rows.is_empty() { @@ -64,7 +148,11 @@ impl TableComponent { } self.headers = headers; self.rows = rows; - self.selected_column = 0; + self.selected_column = if hold_cusor_position { + self.selected_column + } else { + 0 + }; self.selection_area_corner = None; self.column_page_start = std::cell::Cell::new(0); self.scroll = VerticalScroll::new(false, false); @@ -76,6 +164,7 @@ impl TableComponent { self.selected_row.select(None); self.headers = Vec::new(); self.rows = Vec::new(); + self.orders = OrderManager::new(); self.selected_column = 0; self.selection_area_corner = None; self.column_page_start = std::cell::Cell::new(0); @@ -88,10 +177,31 @@ impl TableComponent { self.selection_area_corner = None; } + pub fn add_order(&mut self) { + self.orders.add_order(self.selected_column) + } + + pub fn generate_order_query(&mut self) -> Option { + self.orders.generate_order_query() + } + + pub fn generate_header_icons(&mut self) -> Vec { + self.orders.generate_header_icons(self.headers.len()) + } + pub fn end(&mut self) { self.eod = true; } + fn move_to_head_of_line(&mut self) { + self.selected_column = 0; + } + + fn move_to_tail_of_line(&mut self) { + let vertical_length = self.headers.len().saturating_sub(1); + self.selected_column = vertical_length; + } + fn next_row(&mut self, lines: usize) { let i = match self.selected_row.selected() { Some(i) => { @@ -199,6 +309,21 @@ impl TableComponent { } } + fn expand_selected_by_horizontal_line(&mut self) { + let horizontal_length = self.headers.len().saturating_sub(1); + let vertical_length = self.selected_row.selected().unwrap_or(0); + + if let Some((x, y)) = self.selection_area_corner { + if x == horizontal_length { + self.selection_area_corner = None; + } else { + self.selection_area_corner = Some((horizontal_length, y)); + } + } else { + self.selection_area_corner = Some((horizontal_length, vertical_length)); + } + } + pub fn selected_cells(&self) -> Option { if let Some((x, y)) = self.selection_area_corner { let selected_row_index = self.selected_row.selected()?; @@ -405,6 +530,7 @@ impl StatefulDrawableComponent for TableComponent { .vertical_margin(1) .horizontal_margin(1) .direction(Direction::Vertical) + .flex(Flex::Legacy) .constraints( [ Constraint::Length(2), @@ -471,15 +597,15 @@ impl StatefulDrawableComponent for TableComponent { }); Row::new(cells).height(height as u16).bottom_margin(1) }); - - let table = Table::new(rows, &constraints) + let table = Table::default() + .rows(rows) .header(header) - .block(block) .style(if focused { Style::default() } else { Style::default().fg(Color::DarkGray) - }); + }) + .widths(&constraints); let mut state = self.selected_row.clone(); f.render_stateful_widget( table, @@ -520,6 +646,13 @@ impl Component for TableComponent { out.push(CommandInfo::new(command::extend_selection_by_one_cell( &self.key_config, ))); + out.push(CommandInfo::new(command::extend_selection_by_line( + &self.key_config, + ))); + out.push(CommandInfo::new(command::move_to_head_tail_of_line( + &self.key_config, + ))); + out.push(CommandInfo::new(command::sort_by_column(&self.key_config))); } fn event(&mut self, key: Key) -> Result { @@ -544,12 +677,19 @@ impl Component for TableComponent { } else if key == self.key_config.scroll_to_bottom { self.scroll_to_bottom(); return Ok(EventState::Consumed); + } else if key == self.key_config.move_to_head_of_line { + self.move_to_head_of_line(); + } else if key == self.key_config.move_to_tail_of_line { + self.move_to_tail_of_line(); } else if key == self.key_config.scroll_right { self.next_column(); return Ok(EventState::Consumed); } else if key == self.key_config.extend_selection_by_one_cell_left { self.expand_selected_area_x(false); return Ok(EventState::Consumed); + } else if key == self.key_config.extend_selection_by_horizontal_line { + self.expand_selected_by_horizontal_line(); + return Ok(EventState::Consumed); } else if key == self.key_config.extend_selection_by_one_cell_up { self.expand_selected_area_y(false); return Ok(EventState::Consumed); @@ -566,7 +706,7 @@ impl Component for TableComponent { #[cfg(test)] mod test { - use super::{KeyConfig, TableComponent}; + use super::{KeyConfig, Order, OrderManager, TableComponent}; use ratatui::layout::Constraint; #[test] @@ -684,6 +824,32 @@ mod test { assert_eq!(component.selected_cells(), Some("b\ne".to_string())); } + #[test] + fn test_expand_selected_by_horizontal_line() { + let mut component = TableComponent::new(KeyConfig::default()); + component.headers = vec!["a", "b", "c"].iter().map(|h| h.to_string()).collect(); + component.rows = vec![ + vec!["d", "e", "f"].iter().map(|h| h.to_string()).collect(), + vec!["g", "h", "i"].iter().map(|h| h.to_string()).collect(), + ]; + + // select one line + component.selected_row.select(Some(0)); + component.expand_selected_by_horizontal_line(); + assert_eq!(component.selection_area_corner, Some((2, 0))); + assert_eq!(component.selected_cells(), Some("d,e,f".to_string())); + + // undo select horizontal line + component.expand_selected_by_horizontal_line(); + assert_eq!(component.selection_area_corner, None); + + // select two line + component.expand_selected_area_y(true); + component.expand_selected_by_horizontal_line(); + assert_eq!(component.selection_area_corner, Some((2, 1))); + assert_eq!(component.selected_cells(), Some("d,e,f\ng,h,i".to_string())); + } + #[test] fn test_is_number_column() { let mut component = TableComponent::new(KeyConfig::default()); @@ -777,6 +943,42 @@ mod test { assert!(!component.is_selected_cell(1, 3, 1)); } + #[test] + fn test_move_to_head_of_line() { + let mut component = TableComponent::new(KeyConfig::default()); + + component.headers = vec!["a", "b", "c"].iter().map(|h| h.to_string()).collect(); + component.rows = vec![ + vec!["d", "e", "f"].iter().map(|h| h.to_string()).collect(), + vec!["g", "h", "i"].iter().map(|h| h.to_string()).collect(), + ]; + + // cursor returns to the top. + component.expand_selected_area_y(true); + component.expand_selected_area_y(true); + component.move_to_head_of_line(); + assert_eq!(component.selected_column, 0); + } + + #[test] + fn test_move_to_tail_of_line() { + let mut component = TableComponent::new(KeyConfig::default()); + + // if component does not have a header, cursor is not moved. + component.move_to_head_of_line(); + assert_eq!(component.selected_column, 0); + + // if component has a header, cursor is moved to tail of line. + component.headers = vec!["a", "b", "c"].iter().map(|h| h.to_string()).collect(); + component.rows = vec![ + vec!["d", "e", "f"].iter().map(|h| h.to_string()).collect(), + vec!["g", "h", "i"].iter().map(|h| h.to_string()).collect(), + ]; + + component.move_to_tail_of_line(); + assert_eq!(component.selected_column, 2); + } + #[test] fn test_calculate_cell_widths_when_sum_of_cell_widths_is_greater_than_table_width() { let mut component = TableComponent::new(KeyConfig::default()); @@ -873,4 +1075,76 @@ mod test { ] ); } + + #[test] + fn test_query() { + let asc_order = Order::new(1, true); + let desc_order = Order::new(2, false); + + assert_eq!(asc_order.query(), "1 ASC".to_string()); + assert_eq!(desc_order.query(), "2 DESC".to_string()); + } + + #[test] + fn test_generate_order_query() { + let mut order_manager = OrderManager::new(); + + // If orders is empty, it should return None. + assert_eq!(order_manager.generate_order_query(), None); + + order_manager.add_order(1); + order_manager.add_order(1); + order_manager.add_order(2); + assert_eq!( + order_manager.generate_order_query(), + Some("ORDER BY 2 DESC, 3 ASC".to_string()) + ) + } + + #[test] + fn test_generate_header_icons() { + let mut order_manager = OrderManager::new(); + assert_eq!(order_manager.generate_header_icons(1), vec![String::new()]); + + order_manager.add_order(1); + order_manager.add_order(1); + order_manager.add_order(2); + assert_eq!( + order_manager.generate_header_icons(3), + vec![String::new(), "↓1".to_string(), "↑2".to_string()] + ); + assert_eq!( + order_manager.generate_header_icons(4), + vec![ + String::new(), + "↓1".to_string(), + "↑2".to_string(), + String::new() + ] + ); + } + + #[test] + fn test_add_order() { + let mut order_manager = OrderManager::new(); + + // press first time, condition is asc. + order_manager.add_order(1); + assert_eq!(order_manager.orders, vec![Order::new(2, true)]); + + // press twice times, condition is desc. + order_manager.add_order(1); + assert_eq!(order_manager.orders, vec![Order::new(2, false)]); + + // press another column, this column is second order. + order_manager.add_order(2); + assert_eq!( + order_manager.orders, + vec![Order::new(2, false), Order::new(3, true)] + ); + + // press three times, removed. + order_manager.add_order(1); + assert_eq!(order_manager.orders, vec![Order::new(3, true)]); + } } diff --git a/src/config.rs b/src/config.rs index e327836..b7ed538 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,4 @@ +use crate::key_bind::KeyBind; use crate::log::LogLevel; use crate::Key; use serde::Deserialize; @@ -15,6 +16,17 @@ pub struct CliConfig { /// Set the config file #[structopt(long, short, global = true)] config_path: Option, + + /// Set the key bind file + #[structopt(long, short, global = true)] + key_bind_path: Option, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct ReadConfig { + pub conn: Vec, + #[serde(default)] + pub log_level: LogLevel, } #[derive(Debug, Deserialize, Clone)] @@ -58,6 +70,7 @@ impl Default for Config { path: None, password: None, database: None, + unix_domain_socket: None, }], key_config: KeyConfig::default(), log_level: LogLevel::default(), @@ -74,16 +87,18 @@ pub struct Connection { port: Option, path: Option, password: Option, + unix_domain_socket: Option, pub database: Option, } #[derive(Debug, Deserialize, Clone)] -#[cfg_attr(test, derive(Serialize))] +#[cfg_attr(test, derive(Serialize, PartialEq))] pub struct KeyConfig { pub scroll_up: Key, pub scroll_down: Key, pub scroll_right: Key, pub scroll_left: Key, + pub sort_by_column: Key, pub move_up: Key, pub move_down: Key, pub copy: Key, @@ -101,10 +116,13 @@ pub struct KeyConfig { pub scroll_up_multiple_lines: Key, pub scroll_to_top: Key, pub scroll_to_bottom: Key, + pub move_to_head_of_line: Key, + pub move_to_tail_of_line: Key, pub extend_selection_by_one_cell_left: Key, pub extend_selection_by_one_cell_right: Key, pub extend_selection_by_one_cell_up: Key, pub extend_selection_by_one_cell_down: Key, + pub extend_selection_by_horizontal_line: Key, pub tab_records: Key, pub tab_columns: Key, pub tab_constraints: Key, @@ -123,6 +141,7 @@ impl Default for KeyConfig { scroll_down: Key::Char('j'), scroll_right: Key::Char('l'), scroll_left: Key::Char('h'), + sort_by_column: Key::Char('s'), move_up: Key::Up, move_down: Key::Down, copy: Key::Char('y'), @@ -140,9 +159,12 @@ impl Default for KeyConfig { scroll_up_multiple_lines: Key::Ctrl('u'), scroll_to_top: Key::Char('g'), scroll_to_bottom: Key::Char('G'), + move_to_head_of_line: Key::Char('^'), + move_to_tail_of_line: Key::Char('$'), extend_selection_by_one_cell_left: Key::Char('H'), extend_selection_by_one_cell_right: Key::Char('L'), extend_selection_by_one_cell_down: Key::Char('J'), + extend_selection_by_horizontal_line: Key::Char('V'), extend_selection_by_one_cell_up: Key::Char('K'), tab_records: Key::Char('1'), tab_properties: Key::Char('2'), @@ -164,23 +186,56 @@ impl Config { } else { get_app_config_path()?.join("config.toml") }; + + let key_bind_path = if let Some(key_bind_path) = &config.key_bind_path { + key_bind_path.clone() + } else { + get_app_config_path()?.join("key_bind.ron") + }; + if let Ok(file) = File::open(config_path) { let mut buf_reader = BufReader::new(file); let mut contents = String::new(); buf_reader.read_to_string(&mut contents)?; - - let config: Result = toml::from_str(&contents); + let config: Result = toml::from_str(&contents); match config { - Ok(config) => return Ok(config), + Ok(config) => return Ok(Config::build(config, key_bind_path)), Err(e) => panic!("fail to parse config file: {}", e), } } Ok(Config::default()) } + + fn build(read_config: ReadConfig, key_bind_path: PathBuf) -> Self { + let key_bind = KeyBind::load(key_bind_path).unwrap(); + Config { + conn: read_config.conn, + log_level: read_config.log_level, + key_config: KeyConfig::from(key_bind), + } + } } impl Connection { pub fn database_url(&self) -> anyhow::Result { + let password = self + .password + .as_ref() + .map_or(String::new(), |p| p.to_string()); + return self.build_database_url(password); + } + + fn masked_database_url(&self) -> anyhow::Result { + let password = self + .password + .as_ref() + .map_or(String::new(), |p| p.to_string()); + + let masked_password = "*".repeat(password.len()); + return self.build_database_url(masked_password); + } + + fn build_database_url(&self, password: String) -> anyhow::Result { match self.r#type { DatabaseType::MySql => { let user = self @@ -195,26 +250,27 @@ impl Connection { .port .as_ref() .ok_or_else(|| anyhow::anyhow!("type mysql needs the port field"))?; - let password = self - .password - .as_ref() - .map_or(String::new(), |p| p.to_string()); + let unix_domain_socket = self + .valid_unix_domain_socket() + .map_or(String::new(), |uds| format!("?socket={}", uds)); match self.database.as_ref() { Some(database) => Ok(format!( - "mysql://{user}:{password}@{host}:{port}/{database}", + "mysql://{user}:{password}@{host}:{port}/{database}{unix_domain_socket}", user = user, password = password, host = host, port = port, - database = database + database = database, + unix_domain_socket = unix_domain_socket )), None => Ok(format!( - "mysql://{user}:{password}@{host}:{port}", + "mysql://{user}:{password}@{host}:{port}{unix_domain_socket}", user = user, password = password, host = host, port = port, + unix_domain_socket = unix_domain_socket )), } } @@ -231,27 +287,41 @@ impl Connection { .port .as_ref() .ok_or_else(|| anyhow::anyhow!("type postgres needs the port field"))?; - let password = self - .password - .as_ref() - .map_or(String::new(), |p| p.to_string()); - match self.database.as_ref() { - Some(database) => Ok(format!( - "postgres://{user}:{password}@{host}:{port}/{database}", - user = user, - password = password, - host = host, - port = port, - database = database - )), - None => Ok(format!( - "postgres://{user}:{password}@{host}:{port}", - user = user, - password = password, - host = host, - port = port, - )), + if let Some(unix_domain_socket) = self.valid_unix_domain_socket() { + match self.database.as_ref() { + Some(database) => Ok(format!( + "postgres://?dbname={database}&host={unix_domain_socket}&user={user}&password={password}", + database = database, + unix_domain_socket = unix_domain_socket, + user = user, + password = password, + )), + None => Ok(format!( + "postgres://?host={unix_domain_socket}&user={user}&password={password}", + unix_domain_socket = unix_domain_socket, + user = user, + password = password, + )), + } + } else { + match self.database.as_ref() { + Some(database) => Ok(format!( + "postgres://{user}:{password}@{host}:{port}/{database}", + user = user, + password = password, + host = host, + port = port, + database = database, + )), + None => Ok(format!( + "postgres://{user}:{password}@{host}:{port}", + user = user, + password = password, + host = host, + port = port, + )), + } } } DatabaseType::Sqlite => { @@ -268,7 +338,7 @@ impl Connection { } pub fn database_url_with_name(&self) -> anyhow::Result { - let database_url = self.database_url()?; + let database_url = self.masked_database_url()?; Ok(match &self.name { Some(name) => format!( @@ -287,6 +357,23 @@ impl Connection { pub fn is_postgres(&self) -> bool { matches!(self.r#type, DatabaseType::Postgres) } + + fn valid_unix_domain_socket(&self) -> Option { + if cfg!(windows) { + // NOTE: + // windows also supports UDS, but `rust` does not support UDS in windows now. + // https://github.com/rust-lang/rust/issues/56533 + return None; + } + return self.unix_domain_socket.as_ref().and_then(|uds| { + let path = expand_path(uds)?; + let path_str = path.to_str()?; + if path_str.is_empty() { + return None; + } + Some(path_str.to_owned()) + }); + } } pub fn get_app_config_path() -> anyhow::Result { @@ -325,10 +412,77 @@ fn expand_path(path: &Path) -> Option { #[cfg(test)] mod test { - use super::{expand_path, KeyConfig, Path, PathBuf}; + use super::{ + expand_path, CliConfig, Config, Connection, DatabaseType, KeyConfig, Path, PathBuf, + }; use serde_json::Value; use std::env; + #[test] + fn test_load_config() { + let cli_config = CliConfig { + config_path: Some(Path::new("examples/config.toml").to_path_buf()), + key_bind_path: Some(Path::new("examples/key_bind.ron").to_path_buf()), + }; + + assert_eq!(Config::new(&cli_config).is_ok(), true); + } + + #[test] + #[cfg(unix)] + fn test_database_url() { + let mysql_conn = Connection { + r#type: DatabaseType::MySql, + name: None, + user: Some("root".to_owned()), + host: Some("localhost".to_owned()), + port: Some(3306), + path: None, + password: Some("password".to_owned()), + database: Some("city".to_owned()), + unix_domain_socket: None, + }; + + let mysql_result = mysql_conn.database_url().unwrap(); + assert_eq!( + mysql_result, + "mysql://root:password@localhost:3306/city".to_owned() + ); + + let postgres_conn = Connection { + r#type: DatabaseType::Postgres, + name: None, + user: Some("root".to_owned()), + host: Some("localhost".to_owned()), + port: Some(3306), + path: None, + password: Some("password".to_owned()), + database: Some("city".to_owned()), + unix_domain_socket: None, + }; + + let postgres_result = postgres_conn.database_url().unwrap(); + assert_eq!( + postgres_result, + "postgres://root:password@localhost:3306/city".to_owned() + ); + + let sqlite_conn = Connection { + r#type: DatabaseType::Sqlite, + name: None, + user: None, + host: None, + port: None, + path: Some(PathBuf::from("/home/user/sqlite3.db")), + password: None, + database: None, + unix_domain_socket: None, + }; + + let sqlite_result = sqlite_conn.database_url().unwrap(); + assert_eq!(sqlite_result, "sqlite:///home/user/sqlite3.db".to_owned()); + } + #[test] fn test_overlappted_key() { let value: Value = @@ -349,6 +503,137 @@ mod test { } } + #[test] + #[cfg(unix)] + fn test_dataset_url_in_unix() { + let mut mysql_conn = Connection { + r#type: DatabaseType::MySql, + name: None, + user: Some("root".to_owned()), + host: Some("localhost".to_owned()), + port: Some(3306), + path: None, + password: Some("password".to_owned()), + database: Some("city".to_owned()), + unix_domain_socket: None, + }; + + assert_eq!( + mysql_conn.database_url().unwrap(), + "mysql://root:password@localhost:3306/city".to_owned() + ); + + mysql_conn.unix_domain_socket = Some(Path::new("/tmp/mysql.sock").to_path_buf()); + assert_eq!( + mysql_conn.database_url().unwrap(), + "mysql://root:password@localhost:3306/city?socket=/tmp/mysql.sock".to_owned() + ); + + let mut postgres_conn = Connection { + r#type: DatabaseType::Postgres, + name: None, + user: Some("root".to_owned()), + host: Some("localhost".to_owned()), + port: Some(3306), + path: None, + password: Some("password".to_owned()), + database: Some("city".to_owned()), + unix_domain_socket: None, + }; + + assert_eq!( + postgres_conn.database_url().unwrap(), + "postgres://root:password@localhost:3306/city".to_owned() + ); + postgres_conn.unix_domain_socket = Some(Path::new("/tmp").to_path_buf()); + assert_eq!( + postgres_conn.database_url().unwrap(), + "postgres://?dbname=city&host=/tmp&user=root&password=password".to_owned() + ); + + let sqlite_conn = Connection { + r#type: DatabaseType::Sqlite, + name: None, + user: None, + host: None, + port: None, + path: Some(PathBuf::from("/home/user/sqlite3.db")), + password: None, + database: None, + unix_domain_socket: None, + }; + + let sqlite_result = sqlite_conn.database_url().unwrap(); + assert_eq!(sqlite_result, "sqlite:///home/user/sqlite3.db".to_owned()); + } + + #[test] + #[cfg(windows)] + fn test_database_url_in_windows() { + let mut mysql_conn = Connection { + r#type: DatabaseType::MySql, + name: None, + user: Some("root".to_owned()), + host: Some("localhost".to_owned()), + port: Some(3306), + path: None, + password: Some("password".to_owned()), + database: Some("city".to_owned()), + unix_domain_socket: None, + }; + + assert_eq!( + mysql_conn.database_url().unwrap(), + "mysql://root:password@localhost:3306/city".to_owned() + ); + + mysql_conn.unix_domain_socket = Some(Path::new("/tmp/mysql.sock").to_path_buf()); + assert_eq!( + mysql_conn.database_url().unwrap(), + "mysql://root:password@localhost:3306/city".to_owned() + ); + + let mut postgres_conn = Connection { + r#type: DatabaseType::Postgres, + name: None, + user: Some("root".to_owned()), + host: Some("localhost".to_owned()), + port: Some(3306), + path: None, + password: Some("password".to_owned()), + database: Some("city".to_owned()), + unix_domain_socket: None, + }; + + assert_eq!( + postgres_conn.database_url().unwrap(), + "postgres://root:password@localhost:3306/city".to_owned() + ); + postgres_conn.unix_domain_socket = Some(Path::new("/tmp").to_path_buf()); + assert_eq!( + postgres_conn.database_url().unwrap(), + "postgres://root:password@localhost:3306/city".to_owned() + ); + + let sqlite_conn = Connection { + r#type: DatabaseType::Sqlite, + name: None, + user: None, + host: None, + port: None, + path: Some(PathBuf::from("/home/user/sqlite3.db")), + password: None, + database: None, + unix_domain_socket: None, + }; + + let sqlite_result = sqlite_conn.database_url().unwrap(); + assert_eq!( + sqlite_result, + "sqlite://\\home\\user\\sqlite3.db".to_owned() + ); + } + #[test] #[cfg(unix)] fn test_expand_path() { diff --git a/src/database/mod.rs b/src/database/mod.rs index fbaa0e3..a13499c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -22,6 +22,7 @@ pub trait Pool: Send + Sync { table: &Table, page: u16, filter: Option, + orders: Option, ) -> anyhow::Result<(Vec, Vec>)>; async fn get_columns( &self, diff --git a/src/database/mysql.rs b/src/database/mysql.rs index 67b0184..c7fb663 100644 --- a/src/database/mysql.rs +++ b/src/database/mysql.rs @@ -228,21 +228,41 @@ impl Pool for MySqlPool { table: &Table, page: u16, filter: Option, + orders: Option, ) -> anyhow::Result<(Vec, Vec>)> { - let query = if let Some(filter) = filter { + let query = if let (Some(filter), Some(orders)) = (&filter, &orders) { + format!( + "SELECT * FROM `{database}`.`{table}` WHERE {filter} {orders} LIMIT {page}, {limit}", + database = database.name, + table = table.name, + filter = filter, + page = page, + limit = RECORDS_LIMIT_PER_PAGE, + orders = orders + ) + } else if let Some(filter) = filter { format!( "SELECT * FROM `{database}`.`{table}` WHERE {filter} LIMIT {page}, {limit}", database = database.name, table = table.name, filter = filter, page = page, - limit = RECORDS_LIMIT_PER_PAGE + limit = RECORDS_LIMIT_PER_PAGE, + ) + } else if let Some(orders) = orders { + format!( + "SELECT * FROM `{database}`.`{table}` {orders} LIMIT {page}, {limit}", + database = database.name, + table = table.name, + orders = orders, + page = page, + limit = RECORDS_LIMIT_PER_PAGE, ) } else { format!( - "SELECT * FROM `{}`.`{}` LIMIT {page}, {limit}", - database.name, - table.name, + "SELECT * FROM `{database}`.`{table}` LIMIT {page}, {limit}", + database = database.name, + table = table.name, page = page, limit = RECORDS_LIMIT_PER_PAGE ) diff --git a/src/database/postgres.rs b/src/database/postgres.rs index 6a332e8..526248c 100644 --- a/src/database/postgres.rs +++ b/src/database/postgres.rs @@ -245,8 +245,20 @@ impl Pool for PostgresPool { table: &Table, page: u16, filter: Option, + orders: Option, ) -> anyhow::Result<(Vec, Vec>)> { - let query = if let Some(filter) = filter.as_ref() { + let query = if let (Some(filter), Some(orders)) = (&filter, &orders) { + format!( + r#"SELECT * FROM "{database}"."{table_schema}"."{table}" WHERE {filter} {orders} LIMIT {limit} OFFSET {page}"#, + database = database.name, + table = table.name, + filter = filter, + orders = orders, + table_schema = table.schema.clone().unwrap_or_else(|| "public".to_string()), + page = page, + limit = RECORDS_LIMIT_PER_PAGE + ) + } else if let Some(filter) = &filter { format!( r#"SELECT * FROM "{database}"."{table_schema}"."{table}" WHERE {filter} LIMIT {limit} OFFSET {page}"#, database = database.name, @@ -256,6 +268,16 @@ impl Pool for PostgresPool { page = page, limit = RECORDS_LIMIT_PER_PAGE ) + } else if let Some(orders) = &orders { + format!( + r#"SELECT * FROM "{database}"."{table_schema}"."{table}" {orders} LIMIT {limit} OFFSET {page}"#, + database = database.name, + table = table.name, + orders = orders, + table_schema = table.schema.clone().unwrap_or_else(|| "public".to_string()), + page = page, + limit = RECORDS_LIMIT_PER_PAGE + ) } else { format!( r#"SELECT * FROM "{database}"."{table_schema}"."{table}" LIMIT {limit} OFFSET {page}"#, @@ -283,8 +305,14 @@ impl Pool for PostgresPool { Err(_) => { if json_records.is_none() { json_records = Some( - self.get_json_records(database, table, page, filter.clone()) - .await?, + self.get_json_records( + database, + table, + page, + filter.clone(), + orders.clone(), + ) + .await?, ); } if let Some(json_records) = &json_records { @@ -479,8 +507,20 @@ impl PostgresPool { table: &Table, page: u16, filter: Option, + orders: Option, ) -> anyhow::Result> { - let query = if let Some(filter) = filter { + let query = if let (Some(filter), Some(orders)) = (&filter, &orders) { + format!( + r#"SELECT to_json("{table}".*) FROM "{database}"."{table_schema}"."{table}" WHERE {filter} {orders} LIMIT {limit} OFFSET {page}"#, + database = database.name, + table = table.name, + filter = filter, + orders = orders, + table_schema = table.schema.clone().unwrap_or_else(|| "public".to_string()), + page = page, + limit = RECORDS_LIMIT_PER_PAGE + ) + } else if let Some(filter) = filter { format!( r#"SELECT to_json("{table}".*) FROM "{database}"."{table_schema}"."{table}" WHERE {filter} LIMIT {limit} OFFSET {page}"#, database = database.name, @@ -490,6 +530,16 @@ impl PostgresPool { page = page, limit = RECORDS_LIMIT_PER_PAGE ) + } else if let Some(orders) = orders { + format!( + r#"SELECT to_json("{table}".*) FROM "{database}"."{table_schema}"."{table}" {orders} LIMIT {limit} OFFSET {page}"#, + database = database.name, + table = table.name, + orders = orders, + table_schema = table.schema.clone().unwrap_or_else(|| "public".to_string()), + page = page, + limit = RECORDS_LIMIT_PER_PAGE + ) } else { format!( r#"SELECT to_json("{table}".*) FROM "{database}"."{table_schema}"."{table}" LIMIT {limit} OFFSET {page}"#, diff --git a/src/database/sqlite.rs b/src/database/sqlite.rs index cc3291b..8cf7ab3 100644 --- a/src/database/sqlite.rs +++ b/src/database/sqlite.rs @@ -230,8 +230,18 @@ impl Pool for SqlitePool { table: &Table, page: u16, filter: Option, + orders: Option, ) -> anyhow::Result<(Vec, Vec>)> { - let query = if let Some(filter) = filter { + let query = if let (Some(filter), Some(orders)) = (&filter, &orders) { + format!( + "SELECT * FROM `{table}` WHERE {filter} {orders} LIMIT {page}, {limit}", + table = table.name, + filter = filter, + page = page, + limit = RECORDS_LIMIT_PER_PAGE, + orders = orders + ) + } else if let Some(filter) = filter { format!( "SELECT * FROM `{table}` WHERE {filter} LIMIT {page}, {limit}", table = table.name, @@ -239,10 +249,18 @@ impl Pool for SqlitePool { page = page, limit = RECORDS_LIMIT_PER_PAGE ) + } else if let Some(orders) = orders { + format!( + "SELECT * FROM `{table}`{orders} LIMIT {page}, {limit}", + table = table.name, + orders = orders, + page = page, + limit = RECORDS_LIMIT_PER_PAGE + ) } else { format!( - "SELECT * FROM `{}` LIMIT {page}, {limit}", - table.name, + "SELECT * FROM `{table}` LIMIT {page}, {limit}", + table = table.name, page = page, limit = RECORDS_LIMIT_PER_PAGE ) diff --git a/src/event/events.rs b/src/event/events.rs index a01e93d..582664e 100644 --- a/src/event/events.rs +++ b/src/event/events.rs @@ -3,6 +3,7 @@ use crossterm::event; use std::{sync::mpsc, thread, time::Duration}; #[derive(Debug, Clone, Copy)] +#[allow(dead_code)] pub struct EventConfig { pub exit_key: Key, pub tick_rate: Duration, diff --git a/src/key_bind.rs b/src/key_bind.rs new file mode 100644 index 0000000..dcd861f --- /dev/null +++ b/src/key_bind.rs @@ -0,0 +1,179 @@ +use crate::config::KeyConfig; +use crate::event::Key; +use ron::de::SpannedError; +use serde::Deserialize; +use std::fs::File; +use std::io::{BufReader, Read}; +use std::path::PathBuf; + +#[derive(Debug, Deserialize, Clone, Default)] +pub struct KeyBind { + pub scroll_up: Option, + pub scroll_down: Option, + pub scroll_right: Option, + pub scroll_left: Option, + pub sort_by_column: Option, + pub move_up: Option, + pub move_down: Option, + pub copy: Option, + pub enter: Option, + pub exit: Option, + pub quit: Option, + pub exit_popup: Option, + pub focus_right: Option, + pub focus_left: Option, + pub focus_above: Option, + pub focus_connections: Option, + pub open_help: Option, + pub filter: Option, + pub scroll_down_multiple_lines: Option, + pub scroll_up_multiple_lines: Option, + pub scroll_to_top: Option, + pub scroll_to_bottom: Option, + pub move_to_head_of_line: Option, + pub move_to_tail_of_line: Option, + pub extend_selection_by_one_cell_left: Option, + pub extend_selection_by_one_cell_right: Option, + pub extend_selection_by_one_cell_up: Option, + pub extend_selection_by_one_cell_down: Option, + pub extend_selection_by_horizontal_line: Option, + pub tab_records: Option, + pub tab_columns: Option, + pub tab_constraints: Option, + pub tab_foreign_keys: Option, + pub tab_indexes: Option, + pub tab_sql_editor: Option, + pub tab_properties: Option, + pub extend_or_shorten_widget_width_to_right: Option, + pub extend_or_shorten_widget_width_to_left: Option, +} + +impl KeyBind { + pub fn load(config_path: PathBuf) -> anyhow::Result { + if let Ok(file) = File::open(config_path) { + let mut buf_reader = BufReader::new(file); + let mut contents = String::new(); + buf_reader.read_to_string(&mut contents)?; + let key_bind: Result<_, SpannedError> = ron::from_str(&contents); + match key_bind { + Ok(key_bind) => return Ok(key_bind), + Err(e) => { + eprintln!("fail to parse key bind file: {}", e); + return Ok(Self::default()); + } + } + } + + Ok(Self::default()) + } +} + +macro_rules! merge { + ($kc:expr, $kt:expr) => { + $kc = $kt.unwrap_or_else(|| $kc) + }; +} + +impl From for KeyConfig { + #[allow(clippy::field_reassign_with_default)] + fn from(kb: KeyBind) -> Self { + let mut kc = KeyConfig::default(); + merge!(kc.scroll_up, kb.scroll_up); + merge!(kc.scroll_down, kb.scroll_down); + merge!(kc.scroll_right, kb.scroll_right); + merge!(kc.scroll_left, kb.scroll_left); + merge!(kc.scroll_down, kb.scroll_down); + merge!(kc.move_up, kb.move_up); + merge!(kc.move_down, kb.move_down); + merge!(kc.copy, kb.copy); + merge!(kc.enter, kb.enter); + merge!(kc.exit, kb.exit); + merge!(kc.quit, kb.quit); + merge!(kc.exit_popup, kb.exit_popup); + merge!(kc.focus_right, kb.focus_right); + merge!(kc.focus_left, kb.focus_left); + merge!(kc.focus_above, kb.focus_above); + merge!(kc.focus_connections, kb.focus_connections); + merge!(kc.open_help, kb.open_help); + merge!(kc.filter, kb.filter); + merge!(kc.scroll_down_multiple_lines, kb.scroll_down_multiple_lines); + merge!(kc.scroll_up_multiple_lines, kb.scroll_up_multiple_lines); + merge!(kc.scroll_to_top, kb.scroll_to_top); + merge!(kc.scroll_to_bottom, kb.scroll_to_bottom); + merge!(kc.move_to_head_of_line, kb.move_to_head_of_line); + merge!(kc.move_to_tail_of_line, kb.move_to_tail_of_line); + merge!(kc.sort_by_column, kb.sort_by_column); + merge!( + kc.extend_selection_by_one_cell_left, + kb.extend_selection_by_one_cell_left + ); + merge!( + kc.extend_selection_by_one_cell_right, + kb.extend_selection_by_one_cell_right + ); + merge!( + kc.extend_selection_by_one_cell_down, + kb.extend_selection_by_one_cell_down + ); + merge!( + kc.extend_selection_by_horizontal_line, + kb.extend_selection_by_horizontal_line + ); + merge!( + kc.extend_selection_by_one_cell_up, + kb.extend_selection_by_one_cell_up + ); + merge!(kc.tab_records, kb.tab_records); + merge!(kc.tab_properties, kb.tab_properties); + merge!(kc.tab_sql_editor, kb.tab_sql_editor); + merge!(kc.tab_columns, kb.tab_columns); + merge!(kc.tab_constraints, kb.tab_constraints); + merge!(kc.tab_foreign_keys, kb.tab_foreign_keys); + merge!(kc.tab_indexes, kb.tab_indexes); + merge!( + kc.extend_or_shorten_widget_width_to_right, + kb.extend_or_shorten_widget_width_to_right + ); + merge!( + kc.extend_or_shorten_widget_width_to_left, + kb.extend_or_shorten_widget_width_to_left + ); + kc + } +} + +#[cfg(test)] +mod test { + use super::KeyBind; + use crate::config::KeyConfig; + use crate::event::Key; + use std::path::Path; + + #[test] + fn test_exist_file() { + let config_path = Path::new("examples/key_bind.ron").to_path_buf(); + assert_eq!(config_path.exists(), true); + assert_eq!(KeyBind::load(config_path).is_ok(), true); + } + + #[test] + fn test_not_exist_file() { + let config_path = Path::new("examples/not_exist.ron").to_path_buf(); + assert_eq!(config_path.exists(), false); + assert_eq!(KeyBind::load(config_path).is_ok(), true); + } + + #[test] + fn test_key_config_from_key_bind() { + // Default Config + let empty_kb = KeyBind::default(); + let kc = KeyConfig::default(); + assert_eq!(KeyConfig::from(empty_kb), kc); + + // Merged Config + let mut kb = KeyBind::default(); + kb.scroll_up = Some(Key::Char('M')); + let build_kc = KeyConfig::from(kb); + assert_eq!(build_kc.scroll_up, Key::Char('M')); + } +} diff --git a/src/main.rs b/src/main.rs index 3997a64..4dc62af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ mod components; mod config; mod database; mod event; +mod key_bind; mod tree; mod ui; mod version;