docs_gen/
main.rs

1use quote::ToTokens;
2use regex::Regex;
3use std::error::Error;
4use std::fs;
5use std::path::Path;
6use std::process::Command;
7use std::{collections::HashSet, env};
8use syn::{Attribute, File, Item, ItemImpl, ItemStruct, ItemTrait, ItemUse, Type, TypePath};
9
/// One documentation example together with the location of the code block
/// that holds its generated output.
///
/// Both offsets are byte positions into the scanned document text and are
/// used as the half-open range `output_start..output_end`.
#[derive(Debug)]
struct CodeExample {
    /// Source of the example, with leading `# ` hidden-line markers removed.
    input_code: String,
    /// Byte offset of the first byte inside the code block that follows the example.
    output_start: usize,
    /// Byte offset just past the last byte inside the code block that follows the example.
    output_end: usize,
}
16
17#[derive(Debug)]
18struct ExtractedCode {
19    use_statements: Vec<ItemUse>,
20    trait_definitions: Vec<ItemTrait>,
21    struct_definitions: Vec<ItemStruct>,
22    fieldwork_impls: Vec<ItemImpl>,
23}
24
25fn main() -> Result<(), Box<dyn Error>> {
26    // Check if we're in verbose mode
27    let verbose = env::args().any(|arg| arg == "--verbose" || arg == "-v");
28    let verify = env::args().any(|arg| arg == "--verify");
29
30    let docs_path = env::args()
31        .skip(1)
32        .find(|arg| !arg.starts_with("--"))
33        .unwrap_or_else(|| "docs.md".to_string());
34    let content = fs::read_to_string(&docs_path)?;
35
36    println!("Looking for examples in {docs_path}...");
37    let examples = find_expandable_examples(&content)?;
38    println!("Found {} examples", examples.len());
39
40    if verbose {
41        for (i, example) in examples.iter().enumerate() {
42            println!(
43                "Example {}: {} chars of input code",
44                i + 1,
45                example.input_code.len()
46            );
47            println!(
48                "  First line: {}",
49                example.input_code.lines().next().unwrap_or("")
50            );
51        }
52    }
53
54    let mut new_content = content.clone();
55    let mut updated_count = 0;
56
57    let example_file = env::current_dir()?.join("examples/docs-expansion.rs");
58
59    // Process examples in reverse order to avoid position shifts
60    for (i, example) in examples.iter().rev().enumerate() {
61        let example_num = examples.len() - i;
62
63        println!(
64            "🔄 Processing example {} of {}...",
65            example_num,
66            examples.len()
67        );
68
69        match process_example(&example.input_code, &example_file) {
70            Ok(formatted) => {
71                if verbose {
72                    println!("Generated output ({} chars):", formatted.len());
73                    let lines: Vec<&str> = formatted.lines().collect();
74                    for (i, line) in lines.iter().take(5).enumerate() {
75                        println!("  {}: {}", i + 1, line);
76                    }
77                    if lines.len() > 5 {
78                        println!("  ... ({} more lines)", lines.len() - 5);
79                    }
80                }
81
82                // Since we're processing in reverse, original positions should still be valid
83                let start = example.output_start;
84                let end = example.output_end;
85
86                // Ensure we're on character boundaries and validate range
87                let safe_start = if start >= new_content.len() {
88                    new_content.len()
89                } else if new_content.is_char_boundary(start) {
90                    start
91                } else {
92                    // Find previous char boundary
93                    (0..=start)
94                        .rev()
95                        .find(|&i| new_content.is_char_boundary(i))
96                        .unwrap_or(0)
97                };
98
99                let safe_end = if end > new_content.len() {
100                    new_content.len()
101                } else if new_content.is_char_boundary(end) {
102                    end
103                } else {
104                    // Find next char boundary
105                    (end..new_content.len())
106                        .find(|&i| new_content.is_char_boundary(i))
107                        .unwrap_or(new_content.len())
108                };
109
110                // Ensure we have a valid range
111                if safe_start <= safe_end {
112                    new_content.replace_range(safe_start..safe_end, &formatted);
113                } else {
114                    eprintln!("⚠️  Invalid range for example {example_num}, skipping replacement");
115                }
116
117                updated_count += 1;
118                println!("✅ Example {example_num} updated successfully");
119            }
120            Err(e) => {
121                eprintln!("❌ Failed to process example {example_num}: {e}");
122                if verbose {
123                    eprintln!("Input code was:\n{}", example.input_code);
124                }
125                continue;
126            }
127        }
128    }
129
130    if verify {
131        if new_content != content {
132            eprintln!("❌ Documentation is out of date! Run `cargo run --bin docs-gen` to update.");
133            std::process::exit(1);
134        } else {
135            println!("✅ Documentation is up to date.");
136        }
137    } else {
138        // Normal mode: write changes
139        if updated_count > 0 {
140            fs::write(docs_path, new_content)?;
141            println!("📝 Updated {updated_count} examples");
142        }
143    }
144    Ok(())
145}
146
147fn find_expandable_examples(content: &str) -> Result<Vec<CodeExample>, Box<dyn Error>> {
148    let mut examples = Vec::new();
149    let block_pattern = Regex::new(r"(?s)```rust\n(.*?)\n```")?;
150    let blocks: Vec<_> = block_pattern.captures_iter(content).collect();
151
152    for (i, block_match) in blocks.iter().enumerate() {
153        let block_content = block_match.get(1).unwrap().as_str();
154
155        if block_content.contains("#[derive(") && !block_content.contains("// docgen-skip") {
156            let input_code = block_content
157                .lines()
158                .map(|line| {
159                    if let Some(stripped) = line.strip_prefix("# ") {
160                        stripped
161                    } else if line == "#" {
162                        ""
163                    } else {
164                        line
165                    }
166                })
167                .collect::<Vec<_>>()
168                .join("\n");
169
170            if let Some(next_block) = blocks.get(i + 1) {
171                let next_full = next_block.get(0).unwrap();
172
173                // Calculate byte positions for the content inside the next code block
174                let output_start = next_full.start() + 8; // +8 for "```rust\n"
175                let output_end = next_full.end() - 4; // -4 for "\n```"
176
177                examples.push(CodeExample {
178                    input_code,
179                    output_start,
180                    output_end,
181                });
182            }
183        }
184    }
185
186    Ok(examples)
187}
188
189// ... rest of your functions remain the same ...
190fn process_example(input: &str, example_file: &Path) -> Result<String, Box<dyn Error>> {
191    // First, find the struct names in the input to know what we're looking for
192    let target_structs = extract_struct_names_from_input(input)?;
193
194    // Expand the code
195    let expanded = expand_single_example(input, example_file)?;
196
197    // Parse with syn and extract what we need
198    let extracted = extract_fieldwork_code(&expanded, &target_structs)?;
199
200    // Format the output
201    format_extracted_code(&extracted)
202}
203
204fn extract_struct_names_from_input(input: &str) -> Result<HashSet<String>, Box<dyn Error>> {
205    let parsed: File = syn::parse_str(input)?;
206    let mut struct_names = HashSet::new();
207
208    for item in parsed.items {
209        if let Item::Struct(item_struct) = item {
210            struct_names.insert(item_struct.ident.to_string());
211        }
212    }
213
214    Ok(struct_names)
215}
216
217fn extract_fieldwork_code(
218    expanded: &str,
219    target_structs: &HashSet<String>,
220) -> Result<ExtractedCode, Box<dyn Error>> {
221    let parsed: File = syn::parse_str(expanded)?;
222
223    let mut use_statements = vec![];
224    let mut trait_definitions = vec![];
225    let mut struct_definitions = vec![];
226    let mut fieldwork_impls = vec![];
227
228    for item in parsed.items {
229        match item {
230            Item::Use(use_item) => {
231                use_statements.push(use_item);
232            }
233            Item::Trait(item_trait) => {
234                // Include all trait definitions found in the expanded code
235                trait_definitions.push(item_trait);
236            }
237            Item::Struct(item_struct) => {
238                let struct_name = item_struct.ident.to_string();
239                if target_structs.contains(&struct_name) {
240                    struct_definitions.push(item_struct);
241                }
242            }
243            Item::Impl(item_impl) => {
244                if is_fieldwork_impl(&item_impl, target_structs) {
245                    fieldwork_impls.push(item_impl);
246                }
247            }
248            _ => {} // Skip other items (use statements, other impls, etc.)
249        }
250    }
251
252    Ok(ExtractedCode {
253        trait_definitions,
254        struct_definitions,
255        fieldwork_impls,
256        use_statements,
257    })
258}
259
260fn is_fieldwork_impl(item_impl: &ItemImpl, target_structs: &HashSet<String>) -> bool {
261    // Must be an inherent impl (not a trait impl)
262    if item_impl.trait_.is_some() {
263        return false;
264    }
265
266    // Check if this impl is for one of our target structs
267    if let Type::Path(TypePath { path, .. }) = &*item_impl.self_ty {
268        if let Some(segment) = path.segments.last() {
269            let type_name = segment.ident.to_string();
270            return target_structs.contains(&type_name);
271        }
272    }
273
274    false
275}
276
277fn format_extracted_code(extracted: &ExtractedCode) -> Result<String, Box<dyn Error>> {
278    let mut result = vec!["// GENERATED".to_string()];
279
280    // Add commented trait definitions
281    for use_statement in &extracted.use_statements {
282        let formatted_use = concise_format(&use_statement.to_token_stream().to_string());
283        for line in formatted_use.lines() {
284            if !line.trim().is_empty()
285                && !line.starts_with("#[prelude_import]")
286                && line != "use fieldwork::Fieldwork;"
287            {
288                result.push(format!("# {line}"));
289            }
290        }
291    }
292
293    // Add commented trait definitions
294    for trait_def in &extracted.trait_definitions {
295        let formatted_trait = concise_format(&trait_def.to_token_stream().to_string());
296        for line in formatted_trait.lines() {
297            if !line.trim().is_empty() {
298                result.push(format!("# {line}"));
299            }
300        }
301    }
302
303    // Add commented struct definitions (strip fieldwork attributes)
304    for struct_def in &extracted.struct_definitions {
305        let mut cleaned_struct = struct_def.clone();
306        // Remove fieldwork attributes from the struct itself
307        cleaned_struct
308            .attrs
309            .retain(|attr| !is_fieldwork_attr(attr) && !attr.path().is_ident("doc"));
310
311        // Remove fieldwork attributes from all fields
312        for field in &mut cleaned_struct.fields {
313            field
314                .attrs
315                .retain(|attr| !is_fieldwork_attr(attr) && !attr.path().is_ident("doc"));
316        }
317
318        let formatted_struct = concise_format(&cleaned_struct.into_token_stream().to_string());
319        for line in formatted_struct.lines() {
320            if !line.trim().is_empty() {
321                result.push(format!("# {line}"));
322            }
323        }
324    }
325
326    // Add fieldwork impl blocks using prettyplease
327    for impl_block in &extracted.fieldwork_impls {
328        let formatted_impl = prettyplease::unparse(&syn::parse_quote! { #impl_block });
329        result.push(formatted_impl);
330    }
331
332    Ok(result.join("\n"))
333}
334
335fn is_fieldwork_attr(attr: &Attribute) -> bool {
336    let path = attr.path();
337    path.is_ident("fieldwork") || path.is_ident("field")
338}
339
/// Tightens the spacing of token-stream text so it reads more like source.
///
/// A fixed list of textual replacements is applied in order, each one to the
/// result of the previous. The order is significant because later patterns
/// can overlap with the output of earlier ones.
fn concise_format(s: &str) -> String {
    // (pattern, replacement), applied top to bottom.
    const REWRITES: [(&str, &str); 8] = [
        (" : ", ": "),
        (" < ", "<"),
        (" > ", ">"),
        (" , ", ", "),
        (" ; ", "; "),
        (" :: ", "::"),
        ("# ", "#"),
        (" ;", ";"),
    ];

    REWRITES
        .iter()
        .fold(s.to_owned(), |text, &(pattern, replacement)| {
            text.replace(pattern, replacement)
        })
}
350
/// Expands `input` with `cargo expand` and returns the expanded source.
///
/// `input`, preceded by `use fieldwork::Fieldwork;`, is written to
/// `example_file`, and `cargo expand --example docs-expansion` is run from
/// the current directory. The scratch file is removed again whether or not
/// the command succeeds, so a failed run does not leave it behind.
///
/// # Errors
///
/// Returns an error when the current directory cannot be determined, the
/// file cannot be written, cargo cannot be started, `cargo expand` exits
/// unsuccessfully (its stderr is included in the message), the scratch file
/// cannot be removed after a successful run, or the output is empty or not
/// valid UTF-8.
fn expand_single_example(input: &str, example_file: &Path) -> Result<String, Box<dyn Error>> {
    // Resolve everything that can fail before the scratch file exists.
    let working_dir = env::current_dir()?;

    // Write the example code to the .rs file in examples/
    let file_content = format!("use fieldwork::Fieldwork;\n\n{input}");
    fs::write(example_file, file_content)?;

    // Run cargo expand on the example file
    let run_result = Command::new("cargo")
        .current_dir(working_dir)
        .args(["expand", "--example", "docs-expansion"])
        .output();

    // Clean up before looking at the outcome so every exit path below leaves
    // no scratch file behind.
    let cleanup_result = fs::remove_file(example_file);

    let output = run_result?;
    if !output.status.success() {
        // The command failure is the more useful error; a cleanup failure on
        // this path is intentionally not reported.
        return Err(format!(
            "cargo expand failed: {}",
            String::from_utf8_lossy(&output.stderr)
        )
        .into());
    }
    cleanup_result?;

    if output.stdout.is_empty() {
        return Err("cargo expand was empty, that's probably not right".into());
    }

    Ok(String::from_utf8(output.stdout)?)
}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// The offsets reported for an output block must be valid byte positions
    /// even when multi-byte characters appear before and inside that block.
    #[test]
    fn test_output_offsets_with_multibyte_text() {
        let content = "Hello 🦀 World\n\n```rust\n#[derive(fieldwork::Fieldwork)]\nstruct User { name: String }\n```\n\n```rust\n// GENERATED 🦀\n```\n";

        let examples = find_expandable_examples(content).unwrap();
        assert_eq!(examples.len(), 1);

        let example = &examples[0];
        assert_eq!(
            example.input_code,
            "#[derive(fieldwork::Fieldwork)]\nstruct User { name: String }"
        );
        assert!(content.is_char_boundary(example.output_start));
        assert!(content.is_char_boundary(example.output_end));
        assert_eq!(
            &content[example.output_start..example.output_end],
            "// GENERATED 🦀"
        );
    }

    /// Every struct declared in the input is reported, and nothing else.
    #[test]
    fn test_extract_struct_names() {
        let input = r#"
        #[derive(fieldwork::Fieldwork)]
        struct User { name: String }
        
        #[derive(fieldwork::Fieldwork)]  
        struct Post { title: String }
        "#;

        let names = extract_struct_names_from_input(input).unwrap();
        assert!(names.contains("User"));
        assert!(names.contains("Post"));
        assert_eq!(names.len(), 2);
    }
}