use std::{ borrow::Cow, cmp::Ordering, convert::Infallible, fmt::Display, path::PathBuf, process::ExitCode, }; use pico_args::Arguments; use tree_sitter::{Node, Parser}; const CLI_HELP: &str = r#"USAGE $ rust-organizer [-c] [-w] FILE ARGUMENTS FILE File name of the Rust source file to reorganize. FLAGS -c, --check Check whether reorganizing the file would change the file contents. -w, --write Overwrite the file with the reorganized contents. "#; #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct Cli { check: bool, overwrite: bool, path: PathBuf, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Item<'a> { InnerDoc(Cow<'a, str>), Macro { name: &'a str, content: Cow<'a, str>, }, ModDecl { name: &'a str, content: Cow<'a, str>, }, Use(Cow<'a, str>), Const { name: &'a str, content: Cow<'a, str>, }, Type { name: &'a str, content: Cow<'a, str>, }, Func { name: &'a str, content: Cow<'a, str>, }, Impl { name: TypeIdent<'a>, trt: Option<&'a str>, content: SortableContent<'a>, }, MacroInvocation(Cow<'a, str>), Mod { name: &'a str, content: SortableContent<'a>, }, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct Module<'a> { items: Vec<(bool, Item<'a>)>, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SortableContent<'a> { before: Cow<'a, str>, inner: Module<'a>, after: Cow<'a, str>, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct TypeIdent<'a> { name: &'a str, generics: Option<&'a str>, reference_type: Option<&'a str>, } fn main() -> ExitCode { // Parse commandline arguments let mut args = Arguments::from_env(); if args.contains(["-h", "--help"]) { print!("{}", CLI_HELP); return ExitCode::SUCCESS; } let cli: Cli = match args.try_into() { Ok(cli) => cli, Err(e) => { eprintln!("Error: {}", e); return ExitCode::FAILURE; } }; // Run the main program match cli.run() { Ok(code) => code, Err(e) => { eprintln!("Error: {}", e); ExitCode::FAILURE } } } impl Cli { fn run(&self) -> Result { let mut parser = Parser::new(); parser .set_language(&tree_sitter_rust::LANGUAGE.into()) .expect("Error loading Rust grammar"); let text = std::fs::read_to_string(&self.path) .map_err(|e| format!("unable to read file: {}", e))?; let Some(tree) = parser.parse(&text, None) else { return Err("unable to parse file".to_owned()); }; let root = tree.root_node(); assert_eq!(root.kind(), "source_file"); let mut root = Module::from_node(&text, root); let is_sorted = root.is_sorted(self.check); if self.check { return if is_sorted { Ok(ExitCode::SUCCESS) } else { Ok(ExitCode::FAILURE) }; } if self.overwrite && is_sorted { return Ok(ExitCode::SUCCESS); } root.sort(); if self.overwrite { std::fs::write(&self.path, root.to_string()) .map_err(|e| format!("unable to write file: {}", e))?; } else { println!("{}", root); } Ok(ExitCode::SUCCESS) } } impl TryFrom for Cli { type Error = String; fn try_from(mut args: Arguments) -> Result { let cli = Cli { check: args.contains(["-c", "--check"]), overwrite: args.contains(["-w", "--write"]), path: args .free_from_os_str::<_, Infallible>(|s| Ok(PathBuf::from(s))) .unwrap(), }; let remaining = args.finish(); match remaining.len() { 0 => Ok(()), 1 => Err(format!( "unexpected argument: '{}'", remaining[0].to_string_lossy() )), _ => Err(format!( "unexpected arguments: {}", remaining .into_iter() .map(|s| format!("'{}'", s.to_string_lossy())) .collect::>() .join(", ") )), }?; Ok(cli) } } impl<'a> Item<'a> { fn append_content(&mut self, text: &str) { match self { Item::Macro { content, .. } | Item::ModDecl { content, .. } | Item::Const { content, .. } | Item::Type { content, .. } | Item::Func { content, .. } | Item::InnerDoc(content) | Item::Use(content) | Item::MacroInvocation(content) => { *content = Cow::Owned(format!("{}{}", content, text)); } Item::Impl { .. } | Item::Mod { .. } => { // Cannot add content to these items } } } fn item_order(&self) -> u8 { match self { Item::InnerDoc(_) => 0, Item::Macro { .. } => 1, Item::ModDecl { .. } => 2, Item::Use(_) => 3, Item::Const { .. } => 4, Item::Type { .. } => 5, Item::Func { .. } => 6, Item::Impl { .. } => 7, Item::MacroInvocation(_) => 8, Item::Mod { .. } => 9, } } fn maybe_item(text: &'a str, node: Node<'a>, start: Option) -> Option { let get_field_str = |field_name| { node.child_by_field_name(field_name) .map(|n| n.utf8_text(text.as_bytes()).unwrap()) }; let start = start.unwrap_or(node.start_byte()); let content: Cow<'a, str> = Cow::Borrowed(&text[start..node.end_byte()]); match node.kind() { "attribute_item" => { // Ignore and add to the next item None } "block_comment" | "line_comment" => { let comment = node.utf8_text(text.as_bytes()).unwrap(); if comment.starts_with("//!") || comment.starts_with("/*!") { // Doc comment for the file (ensure that it's at the top of the file). Some(Self::InnerDoc(content)) } else { None // Move comment with the next item } } "const_item" => { let name = get_field_str("name").unwrap(); Some(Self::Const { name, content }) } "enum_item" | "struct_item" | "trait_item" | "type_item" => { let name = get_field_str("name").unwrap(); Some(Self::Type { name, content }) } "function_item" => { let name = get_field_str("name").unwrap(); Some(Self::Func { name, content }) } "impl_item" => { let name = TypeIdent::from_node(text, node.child_by_field_name("type").unwrap()); let trt = get_field_str("trait"); let content = SortableContent::within_node(text, node, Some(start), "body"); Some(Self::Impl { name, trt, content }) } "macro_definition" => { let name = get_field_str("name").unwrap(); Some(Self::Macro { name, content }) } "macro_invocation" => Some(Self::MacroInvocation(content)), "mod_item" => { let name = get_field_str("name").unwrap(); if node.child_by_field_name("body").is_some() { let content = SortableContent::within_node(text, node, Some(start), "body"); Some(Self::Mod { name, content }) } else { Some(Self::ModDecl { name, content }) } } "use_declaration" => Some(Self::Use(content)), _ => panic!( "unexpected node kind: {}\ncontent: {}", node.kind(), content ), } } } impl Display for Item<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Item::InnerDoc(content) | Item::Macro { content, .. } | Item::MacroInvocation(content) | Item::ModDecl { content, .. } | Item::Use(content) | Item::Const { content, .. } | Item::Type { content, .. } | Item::Func { content, .. } => write!(f, "{content}"), Item::Mod { content, .. } | Item::Impl { content, .. } => { write!(f, "{content}") } } } } impl Ord for Item<'_> { fn cmp(&self, other: &Self) -> Ordering { let self_order = self.item_order(); let other_order = other.item_order(); if self_order != other_order { return self_order.cmp(&other_order); } match (self, other) { (Item::InnerDoc(_), Item::InnerDoc(_)) => Ordering::Equal, (Item::Const { name: a, .. }, Item::Const { name: b, .. }) | (Item::Macro { name: a, .. }, Item::Macro { name: b, .. }) | (Item::Mod { name: a, .. }, Item::Mod { name: b, .. }) | (Item::ModDecl { name: a, .. }, Item::ModDecl { name: b, .. }) | (Item::Type { name: a, .. }, Item::Type { name: b, .. }) | (Item::Func { name: a, .. }, Item::Func { name: b, .. }) => a.cmp(b), (Item::Use(_), Item::Use(_)) | (Item::MacroInvocation(_), Item::MacroInvocation(_)) => { Ordering::Equal } ( Item::Impl { name: a, trt: t_a, .. }, Item::Impl { name: b, trt: t_b, .. }, ) => { let name_order = a.name.cmp(b.name); if name_order == Ordering::Equal { let trt_order = t_a.unwrap_or("").cmp(t_b.unwrap_or("")); if trt_order == Ordering::Equal { let a_parts = (a.generics.unwrap_or(""), a.reference_type.unwrap_or("")); let b_parts = (b.generics.unwrap_or(""), b.reference_type.unwrap_or("")); a_parts.cmp(&b_parts) } else { trt_order } } else { name_order } } _ => { // eprintln!("{} -- {}", self, other); unreachable!(); } } } } impl PartialOrd for Item<'_> { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl<'a> Module<'a> { pub fn from_node(text: &'a str, root: Node<'a>) -> Self { assert!(matches!(root.kind(), "source_file" | "declaration_list")); let mut cursor = root.walk(); cursor.goto_first_child(); let mut items: Vec<(bool, Item)> = Vec::new(); let mut start = None; let mut last = None; if cursor.node().kind() == "{" { last = Some(cursor.node().end_byte()); cursor.goto_next_sibling(); } loop { if cursor.node().kind() == "}" { assert!(!cursor.goto_next_sibling()); break; } let node = cursor.node(); // eprintln!("{} : {}\n\n", node.kind(), node.to_sexp()); let inbetween = &text[last.unwrap_or(root.start_byte())..start.unwrap_or(node.start_byte())]; if node.kind() == "empty_statement" { if let Some((_, it)) = items.last_mut() { it.append_content(";"); } debug_assert!( inbetween.trim().is_empty(), "unexpected skipped content: {:?}", inbetween ); start = None; last = Some(node.end_byte()); } else if let Some(item) = Item::maybe_item(&text, node, start) { debug_assert!( inbetween.trim().is_empty(), "unexpected skipped content: {:?}", inbetween ); let newline_before = inbetween.contains("\n\n"); items.push((newline_before, item)); start = None; last = Some(node.end_byte()); } else if start.is_none() { start = Some(node.start_byte()); } if !cursor.goto_next_sibling() { break; } } Self { items } } pub fn is_sorted(&self, print_diff: bool) -> bool { for it in &self.items { match &it.1 { Item::Mod { content, .. } | Item::Impl { content, .. } => { if !content.is_sorted(print_diff) { return false; } } _ => {} } } for window in self.items.windows(2) { if window[0].1 > window[1].1 { if print_diff { eprintln!( "Expected \n\"\"\"\n{}\n\"\"\"\n before \n\"\"\"\n{}\n\"\"\"", window[1].1, window[0].1 ); } return false; } } true } pub fn sort(&mut self) { for it in self.items.iter_mut() { match &mut it.1 { Item::Mod { content, .. } | Item::Impl { content, .. } => content.sort(), _ => {} } } self.items.sort_unstable_by(|a, b| a.1.cmp(&b.1)); } } impl Display for Module<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut last = None; for (newline, item) in &self.items { if *newline || last != Some(item.item_order()) { writeln!(f)?; } writeln!(f, "{}", item)?; last = Some(item.item_order()); } Ok(()) } } impl<'a> SortableContent<'a> { fn is_sorted(&self, print_diff: bool) -> bool { self.inner.is_sorted(print_diff) } fn sort(&mut self) { self.inner.sort(); } fn within_node( text: &'a str, node: Node<'a>, start: Option, child: &'static str, ) -> Self { let start = start.unwrap_or(node.start_byte()); let body = node.child_by_field_name(child).unwrap(); let mut cursor = body.walk(); cursor.goto_first_child(); assert_eq!(cursor.node().kind(), "{"); let before = Cow::Borrowed(&text[start..cursor.node().end_byte()]); cursor.goto_parent(); cursor.goto_last_child(); assert_eq!(cursor.node().kind(), "}"); let after = Cow::Borrowed(&text[cursor.node().start_byte()..node.end_byte()]); let inner = Module::from_node(text, body); Self { before, inner, after, } } } impl Display for SortableContent<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}{}{}", self.before, self.inner, self.after) } } impl<'a> TypeIdent<'a> { fn from_node(text: &'a str, node: Node<'a>) -> Self { let get_field_str = |field_name| { node.child_by_field_name(field_name) .map(|n| n.utf8_text(text.as_bytes()).unwrap()) }; match node.kind() { "type_identifier" => Self { name: node.utf8_text(text.as_bytes()).unwrap(), generics: None, reference_type: None, }, "generic_type" => { let name = get_field_str("type").unwrap(); let generics = get_field_str("type_arguments"); debug_assert!(generics.is_some()); Self { name, generics, reference_type: None, } } "reference_type" => { let inner = node.child_by_field_name("type").unwrap(); let mut ty = TypeIdent::from_node(text, inner); let reference_str = std::str::from_utf8(&text.as_bytes()[node.start_byte()..inner.start_byte()]) .unwrap(); ty.reference_type = Some(reference_str); ty } _ => panic!("invalid type identifier node: {}", node.kind()), } } }