From 22e876a0ed93e00d3f7d0b25b329c29a9feacc98 Mon Sep 17 00:00:00 2001 From: "Jip J. Dekker" Date: Fri, 2 Aug 2024 16:06:55 +1000 Subject: [PATCH] Implement recursive sorting for mod and impl items --- src/main.rs | 322 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 205 insertions(+), 117 deletions(-) diff --git a/src/main.rs b/src/main.rs index ab92051..13e3817 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ -use std::{cmp::Ordering, convert::Infallible, ops::Range, path::PathBuf, process::ExitCode}; +use std::{ + borrow::Cow, cmp::Ordering, convert::Infallible, fmt::Display, path::PathBuf, process::ExitCode, +}; use pico_args::Arguments; use tree_sitter::{Node, Parser}; @@ -14,8 +16,6 @@ FLAGS -w, --write Overwrite the file with the reorganized contents. "#; -type ByteRange = Range; - #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct Cli { check: bool, @@ -25,30 +25,45 @@ struct Cli { #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Item<'a> { - InnerDoc(ByteRange), - Mod { + InnerDoc(Cow<'a, str>), + ModDecl { name: &'a str, - is_declaration: bool, - content: ByteRange, + content: Cow<'a, str>, }, - Use(ByteRange), + Use(Cow<'a, str>), Const { name: &'a str, - content: ByteRange, + content: Cow<'a, str>, }, Type { name: &'a str, - content: ByteRange, + content: Cow<'a, str>, }, Func { name: &'a str, - content: ByteRange, + content: Cow<'a, str>, }, Impl { name: TypeIdent<'a>, trt: Option<&'a str>, - content: ByteRange, + content: SortableContent<'a>, }, + Mod { + name: &'a str, + content: SortableContent<'a>, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct Module<'a> { + items: Vec<(bool, Item<'a>)>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct SortableContent<'a> { + before: Cow<'a, str>, + inner: Module<'a>, + after: Cow<'a, str>, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -95,67 +110,29 @@ impl Cli { return Err("unable to parse file".to_owned()); }; - let mut items = Vec::new(); + let root = tree.root_node(); + assert_eq!(root.kind(), "source_file"); + let mut root = Module::from_node(&text, root); - let mut cursor = tree.walk(); - assert_eq!(cursor.node().kind(), "source_file"); - assert!(!cursor.goto_next_sibling()); - cursor.goto_first_child(); - let mut start = None; - loop { - let node = cursor.node(); - // println!("{} : {}\n\n", node.kind(), node.to_sexp()); - if let Some(item) = Item::maybe_item(&text, node, start) { - let last = items.last(); - let inbetween = - last.map(|(_, i): &(_, Item)| i.end_byte()).unwrap_or(0)..item.start_byte(); - debug_assert!(text[inbetween.clone()].trim().is_empty()); - let newline_before = text[inbetween].contains("\n\n"); - items.push((newline_before, item)); - start = None; - } else if start.is_none() { - start = Some(node.start_byte()); - } - if !cursor.goto_next_sibling() { - break; - } + let is_sorted = root.is_sorted(self.check); + if self.check { + return if is_sorted { + Ok(ExitCode::SUCCESS) + } else { + Ok(ExitCode::FAILURE) + }; } - - let mut is_sorted = true; - for window in items.windows(2) { - if window[0] > window[1] { - if self.check { - eprintln!( - "Expected \n\"\"\"\n{}\n\"\"\"\n before \n\"\"\"\n{}\n\"\"\"", - window[1].1.content(&text), - window[0].1.content(&text) - ); - return Ok(ExitCode::FAILURE); - } - is_sorted = false; - break; - } - } - if self.check || (self.overwrite && is_sorted) { + if self.overwrite && is_sorted { return Ok(ExitCode::SUCCESS); } - // Sort items by their order in the file - items.sort_by(|a, b| a.1.cmp(&b.1)); - - println!("{:?}", items); + root.sort(); if self.overwrite { - todo!() - } - - let mut last = None; - for (newline, item) in items { - if newline || last != Some(item.item_order()) { - println!(); - } - println!("{}", item.content(&text)); - last = Some(item.item_order()); + std::fs::write(&self.path, root.to_string()) + .map_err(|e| format!("unable to write file: {}", e))?; + } else { + println!("{}", root); } Ok(ExitCode::SUCCESS) @@ -195,50 +172,16 @@ impl TryFrom for Cli { } impl<'a> Item<'a> { - fn byte_range(&self) -> ByteRange { - match self { - Item::InnerDoc(content) - | Item::Mod { content, .. } - | Item::Use(content) - | Item::Const { content, .. } - | Item::Type { content, .. } - | Item::Func { content, .. } - | Item::Impl { content, .. } => content.clone(), - } - } - - fn content(&self, text: &'a str) -> &'a str { - match self { - Item::InnerDoc(content) - | Item::Mod { content, .. } - | Item::Use(content) - | Item::Const { content, .. } - | Item::Type { content, .. } - | Item::Func { content, .. } - | Item::Impl { content, .. } => &text[content.clone()], - } - } - - fn end_byte(&self) -> usize { - self.byte_range().end - } - fn item_order(&self) -> u8 { match self { Item::InnerDoc(_) => 0, - Item::Mod { - is_declaration: true, - .. - } => 1, + Item::ModDecl { .. } => 1, Item::Use(_) => 2, Item::Const { .. } => 3, Item::Type { .. } => 4, Item::Func { .. } => 5, Item::Impl { .. } => 6, - Item::Mod { - is_declaration: false, - .. - } => 7, + Item::Mod { .. } => 7, } } @@ -249,6 +192,7 @@ impl<'a> Item<'a> { }; let start = start.unwrap_or(node.start_byte()); + let content: Cow<'a, str> = Cow::Borrowed(&text[start..node.end_byte()]); match node.kind() { "attribute_item" => { // Ignore and add to the next item @@ -258,49 +202,57 @@ impl<'a> Item<'a> { let comment = node.utf8_text(text.as_bytes()).unwrap(); if comment.starts_with("//!") || comment.starts_with("/*!") { // Doc comment for the file (ensure that it's at the top of the file). - Some(Self::InnerDoc(start..node.end_byte())) + Some(Self::InnerDoc(content)) } else { None // Move comment with the next item } } "const_item" => { let name = get_field_str("name").unwrap(); - let content = start..node.end_byte(); Some(Self::Const { name, content }) } "enum_item" | "struct_item" => { let name = get_field_str("name").unwrap(); - let content = start..node.end_byte(); Some(Self::Type { name, content }) } "function_item" => { let name = get_field_str("name").unwrap(); - let content = start..node.end_byte(); Some(Self::Func { name, content }) } "impl_item" => { let name = TypeIdent::from_node(text, node.child_by_field_name("type").unwrap()); let trt = get_field_str("trait"); - let content = start..node.end_byte(); + let content = SortableContent::within_node(text, node, Some(start), "body"); Some(Self::Impl { name, trt, content }) } "mod_item" => { let name = get_field_str("name").unwrap(); - let is_declaration = node.child_by_field_name("body").is_none(); - let content = start..node.end_byte(); - Some(Self::Mod { - name, - is_declaration, - content, - }) + if node.child_by_field_name("body").is_some() { + let content = SortableContent::within_node(text, node, Some(start), "body"); + Some(Self::Mod { name, content }) + } else { + Some(Self::ModDecl { name, content }) + } } - "use_declaration" => Some(Self::Use(start..node.end_byte())), + "use_declaration" => Some(Self::Use(content)), _ => panic!("unexpected node kind: {}", node.kind()), } } +} - fn start_byte(&self) -> usize { - self.byte_range().start +impl Display for Item<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Item::InnerDoc(content) + | Item::ModDecl { content, .. } + | Item::Use(content) + | Item::Const { content, .. } + | Item::Type { content, .. } + | Item::Func { content, .. } => write!(f, "{content}"), + Item::Mod { content, .. } | Item::Impl { content, .. } => { + write!(f, "{content}") + } + } } } @@ -349,6 +301,142 @@ impl PartialOrd for Item<'_> { } } +impl<'a> Module<'a> { + pub fn from_node(text: &'a str, root: Node<'a>) -> Self { + assert!(matches!(root.kind(), "source_file" | "declaration_list")); + let mut cursor = root.walk(); + cursor.goto_first_child(); + + let mut items = Vec::new(); + let mut start = None; + let mut last = None; + if cursor.node().kind() == "{" { + last = Some(cursor.node().end_byte()); + cursor.goto_next_sibling(); + } + loop { + let node = cursor.node(); + // println!("{} : {}\n\n", node.kind(), node.to_sexp()); + if let Some(item) = Item::maybe_item(&text, node, start) { + let inbetween = + &text[last.unwrap_or(root.start_byte())..start.unwrap_or(node.start_byte())]; + debug_assert!( + inbetween.trim().is_empty(), + "unexpected skipped content: {:?}", + inbetween + ); + let newline_before = inbetween.contains("\n\n"); + items.push((newline_before, item)); + start = None; + last = Some(node.end_byte()); + } else if start.is_none() { + start = Some(node.start_byte()); + } + if !cursor.goto_next_sibling() { + break; + } + if cursor.node().kind() == "}" { + assert!(!cursor.goto_next_sibling()); + break; + } + } + + Self { items } + } + + pub fn is_sorted(&self, print_diff: bool) -> bool { + for it in &self.items { + match &it.1 { + Item::Mod { content, .. } | Item::Impl { content, .. } => { + if !content.is_sorted(print_diff) { + return false; + } + } + _ => {} + } + } + for window in self.items.windows(2) { + if window[0] > window[1] { + if print_diff { + eprintln!( + "Expected \n\"\"\"\n{}\n\"\"\"\n before \n\"\"\"\n{}\n\"\"\"", + window[1].1, window[0].1 + ); + return false; + } + } + } + true + } + + pub fn sort(&mut self) { + for it in self.items.iter_mut() { + match &mut it.1 { + Item::Mod { content, .. } | Item::Impl { content, .. } => content.sort(), + _ => {} + } + } + self.items.sort_unstable_by(|a, b| a.1.cmp(&b.1)); + } +} + +impl Display for Module<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut last = None; + for (newline, item) in &self.items { + if *newline || last != Some(item.item_order()) { + writeln!(f)?; + } + writeln!(f, "{}", item)?; + last = Some(item.item_order()); + } + Ok(()) + } +} + +impl<'a> SortableContent<'a> { + fn is_sorted(&self, print_diff: bool) -> bool { + self.inner.is_sorted(print_diff) + } + + fn sort(&mut self) { + self.inner.sort(); + } + + fn within_node( + text: &'a str, + node: Node<'a>, + start: Option, + child: &'static str, + ) -> Self { + let start = start.unwrap_or(node.start_byte()); + let body = node.child_by_field_name(child).unwrap(); + + let mut cursor = body.walk(); + cursor.goto_first_child(); + assert_eq!(cursor.node().kind(), "{"); + let before = Cow::Borrowed(&text[start..cursor.node().end_byte()]); + + cursor.goto_parent(); + cursor.goto_last_child(); + assert_eq!(cursor.node().kind(), "}"); + let after = Cow::Borrowed(&text[cursor.node().start_byte()..node.end_byte()]); + + let inner = Module::from_node(text, body); + Self { + before, + inner, + after, + } + } +} + +impl Display for SortableContent<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}{}{}", self.before, self.inner, self.after) + } +} + impl<'a> TypeIdent<'a> { fn from_node(text: &'a str, node: Node<'a>) -> Self { let get_field_str = |field_name| {