From 44ae505042a39b8fbc2be5427b8cbb22262f2c13 Mon Sep 17 00:00:00 2001 From: Morgante Pell Date: Thu, 9 May 2024 03:22:16 -0700 Subject: [PATCH] feat: apply transforms on stdin (#320) --- Cargo.lock | 1 - crates/cli/Cargo.toml | 1 - crates/cli/src/analyze.rs | 144 +++++++++++-------- crates/cli/src/commands/apply.rs | 2 +- crates/cli/src/commands/apply_pattern.rs | 172 ++++++++++++++++++----- crates/cli/src/commands/check.rs | 1 + crates/cli/src/commands/mod.rs | 1 + crates/cli/src/commands/plumbing.rs | 2 +- crates/cli/src/flags.rs | 49 ++++--- crates/cli/src/messenger_variant.rs | 18 ++- crates/cli/src/result_formatting.rs | 78 ++++++++++ crates/cli_bin/tests/apply.rs | 143 ++++++++++++++++--- crates/core/src/problem.rs | 10 ++ 13 files changed, 485 insertions(+), 137 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 581995cd5..049bd7d59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1998,7 +1998,6 @@ dependencies = [ "grit-util", "grit_cache", "grit_cloud_client", - "ignore", "indicatif", "indicatif-log-bridge", "lazy_static", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 823ddaefe..5384485f1 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -19,7 +19,6 @@ path = "src/lib.rs" anyhow = { version = "1.0.70" } clap = { version = "4.1.13", features = ["derive"] } indicatif = { version = "0.17.5" } -ignore = { version = "0.4.20" } # Do *NOT* upgrade beyond 1.0.171 until https://github.com/serde-rs/serde/issues/2538 is fixed serde = { version = "1.0.164", features = ["derive"] } serde_json = { version = "1.0.96" } diff --git a/crates/cli/src/analyze.rs b/crates/cli/src/analyze.rs index 6f83c96c9..d31abcdbe 100644 --- a/crates/cli/src/analyze.rs +++ b/crates/cli/src/analyze.rs @@ -11,7 +11,8 @@ use tracing_opentelemetry::OpenTelemetrySpanExt as _; use grit_cache::paths::cache_for_cwd; use grit_util::{FileRange, Position}; -use ignore::Walk; +use marzano_language::target_language::expand_paths; + use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; #[allow(unused_imports)] use marzano_core::built_in_functions::BuiltIns; @@ -140,10 +141,9 @@ macro_rules! emit_error { #[allow(clippy::too_many_arguments)] pub async fn par_apply_pattern( - file_walker: Walk, multi: MultiProgress, compiled: Problem, - my_input: &ApplyInput, + my_input: ApplyInput, mut owned_emitter: M, processed: &AtomicI32, details: &mut ApplyDetails, @@ -190,65 +190,77 @@ where let mut interactive = arg.interactive; let min_level = &arg.visibility; - let (file_paths_tx, file_paths_rx) = channel(); + let (found_count, disk_paths) = match my_input { + ApplyInput::Disk(ref my_input) => { + let (file_paths_tx, file_paths_rx) = channel(); - for file in file_walker { - let file = emit_error!(owned_emitter, &arg.visibility, file); - if file.file_type().unwrap().is_dir() { - continue; - } - if !&compiled.language.match_extension( - file.path() - .extension() - .unwrap_or_default() - .to_str() - .unwrap_or_default(), - ) { - processed.fetch_add(1, Ordering::SeqCst); - let path_string = file.path().to_string_lossy().to_string(); - if my_input.paths.contains(&file.path().to_path_buf()) { - let log = MatchResult::AnalysisLog(AnalysisLog { - level: 410, - message: format!( - "Skipped {} since it is not a {} file", - path_string, - &compiled.language.to_string() - ), - position: Position::first(), - file: path_string.to_string(), - engine_id: "marzano".to_string(), - range: None, - syntax_tree: None, - source: None, - }); - let done_file = MatchResult::DoneFile(DoneFile { - relative_file_path: path_string, - has_results: Some(false), - file_hash: None, - from_cache: false, - }); - emitter.handle_results( - vec![log, done_file], - details, - arg.dry_run, - min_level, - arg.format, - &mut interactive, - None, - Some(processed), - None, - &compiled.language, - ); + let file_walker = emit_error!( + owned_emitter, + &arg.visibility, + expand_paths(&my_input.paths, Some(&[(&compiled.language).into()])) + ); + + for file in file_walker { + let file = emit_error!(owned_emitter, &arg.visibility, file); + if file.file_type().unwrap().is_dir() { + continue; + } + if !&compiled.language.match_extension( + file.path() + .extension() + .unwrap_or_default() + .to_str() + .unwrap_or_default(), + ) { + processed.fetch_add(1, Ordering::SeqCst); + let path_string = file.path().to_string_lossy().to_string(); + if my_input.paths.contains(&file.path().to_path_buf()) { + let log = MatchResult::AnalysisLog(AnalysisLog { + level: 410, + message: format!( + "Skipped {} since it is not a {} file", + path_string, + &compiled.language.to_string() + ), + position: Position::first(), + file: path_string.to_string(), + engine_id: "marzano".to_string(), + range: None, + syntax_tree: None, + source: None, + }); + let done_file = MatchResult::DoneFile(DoneFile { + relative_file_path: path_string, + has_results: Some(false), + file_hash: None, + from_cache: false, + }); + emitter.handle_results( + vec![log, done_file], + details, + arg.dry_run, + min_level, + arg.format, + &mut interactive, + None, + Some(processed), + None, + &compiled.language, + ); + } + continue; + } + file_paths_tx.send(file.path().to_path_buf()).unwrap(); } - continue; - } - file_paths_tx.send(file.path().to_path_buf()).unwrap(); - } - drop(file_paths_tx); + drop(file_paths_tx); + + let found_paths = file_paths_rx.iter().collect::>(); + (found_paths.len(), Some(found_paths)) + } + ApplyInput::Virtual(ref virtual_info) => (virtual_info.files.len(), None), + }; - let found_paths = file_paths_rx.iter().collect::>(); - let found_count = found_paths.len(); if let Some(pg) = pg { pg.set_length(found_count.try_into().unwrap()); } @@ -257,7 +269,7 @@ where #[cfg(feature = "grit_timing")] debug!( "Walked {} files in {}ms", - found_paths.len(), + found_count, current_timer.elapsed().as_millis() ); @@ -323,7 +335,19 @@ where #[cfg(feature = "grit_tracing")] task_span.set_parent(grouped_ctx); task_span.in_scope(|| { - compiled.execute_paths_streaming(found_paths, context, tx, cache_ref); + match disk_paths { + Some(found_paths) => { + compiled.execute_paths_streaming(found_paths, context, tx, cache_ref); + } + None => { + if let ApplyInput::Virtual(my_input) = my_input { + compiled.execute_files_streaming(my_input.files, context, tx, cache_ref); + } else { + unreachable!(); + } + } + } + loop { if processed.load(Ordering::SeqCst) >= found_count.try_into().unwrap() || !should_continue.load(Ordering::SeqCst) diff --git a/crates/cli/src/commands/apply.rs b/crates/cli/src/commands/apply.rs index 3c3bbbeba..6aa490e4a 100644 --- a/crates/cli/src/commands/apply.rs +++ b/crates/cli/src/commands/apply.rs @@ -91,7 +91,7 @@ pub(crate) async fn run_apply( details, None, None, - flags.into(), + flags, None, ) .await diff --git a/crates/cli/src/commands/apply_pattern.rs b/crates/cli/src/commands/apply_pattern.rs index f6fc20aaa..66438d4e0 100644 --- a/crates/cli/src/commands/apply_pattern.rs +++ b/crates/cli/src/commands/apply_pattern.rs @@ -3,6 +3,7 @@ use clap::Args; use dialoguer::Confirm; +use marzano_util::rich_path::RichFile; use tracing::instrument; #[cfg(feature = "grit_tracing")] use tracing::span; @@ -12,14 +13,13 @@ use tracing_opentelemetry::OpenTelemetrySpanExt as _; use grit_util::Position; use indicatif::MultiProgress; -use log::debug; use marzano_core::api::{AllDone, AllDoneReason, AnalysisLog, MatchResult}; use marzano_core::pattern_compiler::CompilationResult; use marzano_gritmodule::fetcher::KeepFetcherKind; use marzano_gritmodule::markdown::get_body_from_md_content; use marzano_gritmodule::searcher::find_grit_modules_dir; use marzano_gritmodule::utils::is_pattern_name; -use marzano_language::target_language::{expand_paths, PatternLanguage}; +use marzano_language::target_language::PatternLanguage; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::env; @@ -30,6 +30,7 @@ use tokio::fs; use crate::commands::filters::extract_filter_ranges; +use crate::flags::GlobalFormatFlags; use crate::{ analyze::par_apply_pattern, error::GoodError, flags::OutputFormat, messenger_variant::create_emitter, result_formatting::get_human_error, updater::Updater, @@ -46,13 +47,50 @@ use crate::utils::has_uncommitted_changes; use super::filters::SharedFilterArgs; use super::init::init_config_from_cwd; +/// Apply a pattern to a set of paths on disk which will be rewritten in place #[derive(Deserialize)] -pub struct ApplyInput { +pub struct ApplyInputDisk { pub pattern_body: String, pub pattern_libs: BTreeMap, pub paths: Vec, } +#[derive(Deserialize)] +pub struct ApplyInputVirtual { + pub pattern_body: String, + pub pattern_libs: BTreeMap, + pub files: Vec, +} + +#[derive(Deserialize)] +pub enum ApplyInput { + Disk(ApplyInputDisk), + Virtual(ApplyInputVirtual), +} + +impl ApplyInput { + pub fn pattern_body(&self) -> &str { + match self { + ApplyInput::Disk(d) => &d.pattern_body, + ApplyInput::Virtual(v) => &v.pattern_body, + } + } + + pub fn pattern_libs(&self) -> &BTreeMap { + match self { + ApplyInput::Disk(d) => &d.pattern_libs, + ApplyInput::Virtual(v) => &v.pattern_libs, + } + } + + pub fn is_empty(&self) -> bool { + match self { + ApplyInput::Disk(d) => d.paths.is_empty(), + ApplyInput::Virtual(v) => v.files.is_empty(), + } + } +} + #[derive(Args, Clone, Debug, Serialize)] pub struct ApplyPatternArgs { // Level of detail to show for results @@ -103,6 +141,14 @@ pub struct ApplyPatternArgs { help = "Path to a file to write the results to, defaults to stdout" )] output_file: Option, + /// Use this option when you want to transform code piped from `stdin`, and print the output to `stdout`. + /// + /// If you use this option, you *must* specify a file path, to allow Grit to determine the language of the code. + /// + /// Example: `echo 'console.log(hello)' | grit apply '`hello` => `goodbye`' file.js --stdin + /// This will print `console.log(goodbye)` to stdout + #[clap(long = "stdin")] + pub stdin: bool, /// Use cache #[clap(long = "cache", conflicts_with = "refresh_cache")] pub cache: bool, @@ -133,6 +179,7 @@ impl Default for ApplyPatternArgs { refresh_cache: Default::default(), ai: Default::default(), language: Default::default(), + stdin: Default::default(), } } } @@ -159,8 +206,8 @@ pub(crate) async fn run_apply_pattern( multi: MultiProgress, details: &mut ApplyDetails, pattern_libs: Option>, - lang: Option, - format: OutputFormat, + default_lang: Option, + format: &GlobalFormatFlags, root_path: Option, ) -> Result<()> { let mut context = Updater::from_current_bin() @@ -169,6 +216,36 @@ pub(crate) async fn run_apply_pattern( .get_context() .unwrap(); + let format = OutputFormat::from_flags( + format, + if arg.stdin { + OutputFormat::Transformed + } else { + OutputFormat::Standard + }, + ); + + let default_lang = default_lang.or(arg.language); + + let default_lang = if !arg.stdin { + default_lang + } else if default_lang.is_none() { + // Look at the first path and get the language from the extension + let first_path = paths.first().ok_or(anyhow::anyhow!( + "A path must be provided as the virtual file name for stdin" + ))?; + let ext = first_path.extension().ok_or(anyhow::anyhow!( + "A path must have an extension to determine the language for stdin" + ))?; + if let Some(ext) = ext.to_str() { + PatternLanguage::from_extension(ext) + } else { + default_lang + } + } else { + default_lang + }; + if arg.ignore_limit { context.ignore_limit_pattern = true; } @@ -176,6 +253,16 @@ pub(crate) async fn run_apply_pattern( let interactive = arg.interactive; let min_level = &arg.visibility; + let mut emitter = create_emitter( + &format, + arg.output.clone(), + arg.output_file.as_ref(), + interactive, + Some(&pattern), + root_path.as_ref(), + ) + .await?; + #[cfg(feature = "ai_querygen")] if arg.ai { log::info!("{}", style("Computing query...").bold()); @@ -204,16 +291,6 @@ pub(crate) async fn run_apply_pattern( #[cfg(feature = "grit_tracing")] module_resolution.exit(); - let mut emitter = create_emitter( - &format, - arg.output.clone(), - arg.output_file.as_ref(), - interactive, - Some(&pattern), - root_path.as_ref(), - ) - .await?; - let filter_range = flushable_unwrap!( emitter, extract_filter_ranges(&shared, current_repo_root.as_ref()) @@ -224,12 +301,12 @@ pub(crate) async fn run_apply_pattern( let (my_input, lang) = if let Some(pattern_libs) = pattern_libs { ( - ApplyInput { + ApplyInputDisk { pattern_body: pattern.clone(), paths, pattern_libs, }, - lang, + default_lang, ) } else { #[cfg(feature = "grit_tracing")] @@ -333,7 +410,7 @@ pub(crate) async fn run_apply_pattern( } } }; - if let Some(lang_option) = &arg.language { + if let Some(lang_option) = &default_lang { if let Some(lang) = lang { if lang != *lang_option { return Err(anyhow::anyhow!( @@ -351,8 +428,9 @@ pub(crate) async fn run_apply_pattern( ); #[cfg(feature = "grit_tracing")] grit_file_discovery.exit(); + ( - ApplyInput { + ApplyInputDisk { pattern_body, pattern_libs, paths: paths.to_owned(), @@ -361,7 +439,38 @@ pub(crate) async fn run_apply_pattern( ) }; - if my_input.paths.is_empty() { + let final_input = if arg.stdin { + let mut content = String::new(); + use std::io::Read; + std::io::stdin().read_to_string(&mut content)?; + + let ApplyInputDisk { + pattern_body, + pattern_libs, + paths, + } = my_input; + + if paths.len() != 1 { + bail!("Only one path can be provided as the virtual file name for --stdin"); + } + + let first_path = paths.first().ok_or(anyhow::anyhow!( + "A path must be provided as the virtual file name for stdin" + ))?; + + ApplyInput::Virtual(ApplyInputVirtual { + pattern_body, + pattern_libs, + files: vec![RichFile { + path: first_path.to_string_lossy().into(), + content, + }], + }) + } else { + ApplyInput::Disk(my_input) + }; + + if final_input.is_empty() { let all_done = MatchResult::AllDone(AllDone { processed: 0, found: 0, @@ -378,8 +487,8 @@ pub(crate) async fn run_apply_pattern( let current_name = if is_pattern_name(&pattern) { Some(pattern.trim_end_matches("()").to_string()) } else { - my_input - .pattern_libs + final_input + .pattern_libs() .iter() .find(|(_, body)| body.trim() == pattern.trim()) .map(|(name, _)| name.clone()) @@ -389,7 +498,7 @@ pub(crate) async fn run_apply_pattern( let pattern: crate::resolver::RichPattern<'_> = flushable_unwrap!( emitter, - resolver.make_pattern(&my_input.pattern_body, current_name) + resolver.make_pattern(final_input.pattern_body(), current_name) ); #[cfg(feature = "grit_tracing")] @@ -398,7 +507,7 @@ pub(crate) async fn run_apply_pattern( let CompilationResult { problem: compiled, compilation_warnings, - } = match pattern.compile(&my_input.pattern_libs, lang, filter_range, arg.limit) { + } = match pattern.compile(final_input.pattern_libs(), lang, filter_range, arg.limit) { Ok(c) => c, Err(e) => { let log = match e.downcast::() { @@ -423,7 +532,7 @@ pub(crate) async fn run_apply_pattern( (false, false) => bail!(GoodError::new()), (false, true) => bail!(GoodError::new_with_message(get_human_error( log, - &my_input.pattern_body + final_input.pattern_body(), ))), } } @@ -434,23 +543,12 @@ pub(crate) async fn run_apply_pattern( .unwrap(); } - debug!( - "Applying pattern: {:?}, {:?}", - my_input.paths, compiled.language - ); - - let file_walker = flushable_unwrap!( - emitter, - expand_paths(&my_input.paths, Some(&[(&compiled.language).into()])) - ); - let processed = AtomicI32::new(0); let mut emitter = par_apply_pattern( - file_walker, multi, compiled, - &my_input, + final_input, emitter, &processed, details, diff --git a/crates/cli/src/commands/check.rs b/crates/cli/src/commands/check.rs index 03d55047f..bd156c245 100644 --- a/crates/cli/src/commands/check.rs +++ b/crates/cli/src/commands/check.rs @@ -259,6 +259,7 @@ pub(crate) async fn run_check( match emitter { crate::messenger_variant::MessengerVariant::Formatted(_) + | crate::messenger_variant::MessengerVariant::Transformed(_) | crate::messenger_variant::MessengerVariant::JsonLine(_) => { info!("Local only, skipping check registration."); } diff --git a/crates/cli/src/commands/mod.rs b/crates/cli/src/commands/mod.rs index 0cae7a6d6..85b63e474 100644 --- a/crates/cli/src/commands/mod.rs +++ b/crates/cli/src/commands/mod.rs @@ -323,6 +323,7 @@ async fn run_command() -> Result<()> { .filter_level(log_level) .target(match format { OutputFormat::Standard => env_logger::Target::Stdout, + OutputFormat::Transformed => env_logger::Target::Stderr, OutputFormat::Json | OutputFormat::Jsonl => env_logger::Target::Stderr, #[cfg(feature = "remote_redis")] OutputFormat::Redis => env_logger::Target::Stderr, diff --git a/crates/cli/src/commands/plumbing.rs b/crates/cli/src/commands/plumbing.rs index fbca62173..97cab0e8b 100644 --- a/crates/cli/src/commands/plumbing.rs +++ b/crates/cli/src/commands/plumbing.rs @@ -155,7 +155,7 @@ pub(crate) async fn run_plumbing( details, Some(pattern_libs.library()), Some(pattern_libs.language()), - parent.into(), + &parent, input.root_path.map(|p| ensure_trailing_slash(&p)), ) .await diff --git a/crates/cli/src/flags.rs b/crates/cli/src/flags.rs index 5e993246d..d94b93316 100644 --- a/crates/cli/src/flags.rs +++ b/crates/cli/src/flags.rs @@ -22,6 +22,8 @@ pub struct GlobalFormatFlags { #[derive(Debug, PartialEq, Clone)] pub enum OutputFormat { Standard, + /// Print every transformed file back out in full, with no other output + Transformed, Json, Jsonl, #[cfg(feature = "remote_redis")] @@ -33,24 +35,9 @@ pub enum OutputFormat { } impl OutputFormat { - /// Should the command always succeed, and should we show an error message? - /// Returns (always_succeed, show_error) - pub fn is_always_ok(&self) -> (bool, bool) { - match self { - OutputFormat::Standard => (false, false), - OutputFormat::Json | OutputFormat::Jsonl => (true, true), - #[cfg(feature = "remote_redis")] - OutputFormat::Redis => (false, true), - #[cfg(feature = "remote_pubsub")] - OutputFormat::PubSub => (false, true), - #[cfg(feature = "server")] - OutputFormat::Combined => (false, true), - } - } -} - -impl From<&GlobalFormatFlags> for OutputFormat { - fn from(flags: &GlobalFormatFlags) -> Self { + /// Gets the OutputFormat from the GlobalFormatFlags + /// A default should be provided based on other CLI flags + pub fn from_flags(flags: &GlobalFormatFlags, default: OutputFormat) -> Self { #[cfg(feature = "server")] if flags.pubsub && flags.redis { return OutputFormat::Combined; @@ -68,11 +55,35 @@ impl From<&GlobalFormatFlags> for OutputFormat { } else if flags.jsonl { OutputFormat::Jsonl } else { - OutputFormat::Standard + default + } + } +} + +impl OutputFormat { + /// Should the command always succeed, and should we show an error message? + /// Returns (always_succeed, show_error) + pub fn is_always_ok(&self) -> (bool, bool) { + match self { + OutputFormat::Standard => (false, false), + OutputFormat::Transformed => (false, false), + OutputFormat::Json | OutputFormat::Jsonl => (true, true), + #[cfg(feature = "remote_redis")] + OutputFormat::Redis => (false, true), + #[cfg(feature = "remote_pubsub")] + OutputFormat::PubSub => (false, true), + #[cfg(feature = "server")] + OutputFormat::Combined => (false, true), } } } +impl From<&GlobalFormatFlags> for OutputFormat { + fn from(flags: &GlobalFormatFlags) -> Self { + OutputFormat::from_flags(flags, OutputFormat::Standard) + } +} + impl From for OutputFormat { fn from(flags: GlobalFormatFlags) -> Self { (&flags).into() diff --git a/crates/cli/src/messenger_variant.rs b/crates/cli/src/messenger_variant.rs index 2d5f51834..49dce8dde 100644 --- a/crates/cli/src/messenger_variant.rs +++ b/crates/cli/src/messenger_variant.rs @@ -15,12 +15,17 @@ use cli_server::pubsub::GooglePubSubMessenger; #[cfg(feature = "remote_redis")] use cli_server::redis::RedisMessenger; -use crate::{flags::OutputFormat, jsonl::JSONLineMessenger, result_formatting::FormattedMessager}; +use crate::{ + flags::OutputFormat, + jsonl::JSONLineMessenger, + result_formatting::{FormattedMessager, TransformedMessenger}, +}; #[allow(clippy::large_enum_variant)] pub enum MessengerVariant<'a> { Formatted(FormattedMessager<'a>), JsonLine(JSONLineMessenger<'a>), + Transformed(TransformedMessenger<'a>), #[cfg(feature = "remote_redis")] Redis(RedisMessenger), #[cfg(feature = "remote_pubsub")] @@ -33,6 +38,7 @@ impl<'a> Messager for MessengerVariant<'a> { fn raw_emit(&mut self, message: &marzano_core::api::MatchResult) -> anyhow::Result<()> { match self { MessengerVariant::Formatted(m) => m.raw_emit(message), + MessengerVariant::Transformed(m) => m.raw_emit(message), MessengerVariant::JsonLine(m) => m.raw_emit(message), #[cfg(feature = "remote_redis")] MessengerVariant::Redis(m) => m.raw_emit(message), @@ -46,6 +52,7 @@ impl<'a> Messager for MessengerVariant<'a> { fn emit_estimate(&mut self, count: usize) -> anyhow::Result<()> { match self { MessengerVariant::Formatted(m) => m.emit_estimate(count), + MessengerVariant::Transformed(m) => m.emit_estimate(count), MessengerVariant::JsonLine(m) => m.emit_estimate(count), #[cfg(feature = "remote_redis")] MessengerVariant::Redis(m) => m.emit_estimate(count), @@ -59,6 +66,7 @@ impl<'a> Messager for MessengerVariant<'a> { fn start_workflow(&mut self) -> anyhow::Result<()> { match self { MessengerVariant::Formatted(m) => m.start_workflow(), + MessengerVariant::Transformed(m) => m.start_workflow(), MessengerVariant::JsonLine(m) => m.start_workflow(), #[cfg(feature = "remote_redis")] MessengerVariant::Redis(m) => m.start_workflow(), @@ -72,6 +80,7 @@ impl<'a> Messager for MessengerVariant<'a> { fn finish_workflow(&mut self, outcome: &PackagedWorkflowOutcome) -> anyhow::Result<()> { match self { MessengerVariant::Formatted(m) => m.finish_workflow(outcome), + MessengerVariant::Transformed(m) => m.finish_workflow(outcome), MessengerVariant::JsonLine(m) => m.finish_workflow(outcome), #[cfg(feature = "remote_redis")] MessengerVariant::Redis(m) => m.finish_workflow(outcome), @@ -95,6 +104,12 @@ impl<'a> From> for MessengerVariant<'a> { } } +impl<'a> From> for MessengerVariant<'a> { + fn from(value: TransformedMessenger<'a>) -> Self { + Self::Transformed(value) + } +} + #[cfg(feature = "remote_redis")] impl<'a> From for MessengerVariant<'a> { fn from(value: cli_server::redis::RedisMessenger) -> Self { @@ -170,6 +185,7 @@ pub async fn create_emitter<'a>( OutputFormat::Json => { bail!("JSON output is not supported for apply_pattern"); } + OutputFormat::Transformed => TransformedMessenger::new(writer).into(), OutputFormat::Jsonl => { let jsonl = JSONLineMessenger::new(writer.unwrap_or_else(|| Box::new(io::stdout())), mode); diff --git a/crates/cli/src/result_formatting.rs b/crates/cli/src/result_formatting.rs index 8d2b6d511..359b9c693 100644 --- a/crates/cli/src/result_formatting.rs +++ b/crates/cli/src/result_formatting.rs @@ -342,3 +342,81 @@ impl Messager for FormattedMessager<'_> { Ok(()) } } + +/// Prints the transformed files themselves, with no metadata +pub struct TransformedMessenger<'a> { + writer: Option>>>, + total_accepted: usize, + total_rejected: usize, + total_supressed: usize, +} + +impl<'a> TransformedMessenger<'_> { + pub fn new(writer: Option>) -> TransformedMessenger<'a> { + TransformedMessenger { + writer: writer.map(|w| Arc::new(Mutex::new(w))), + total_accepted: 0, + total_rejected: 0, + total_supressed: 0, + } + } +} + +impl Messager for TransformedMessenger<'_> { + fn raw_emit(&mut self, message: &MatchResult) -> anyhow::Result<()> { + match message { + MatchResult::PatternInfo(_) + | MatchResult::AllDone(_) + | MatchResult::InputFile(_) + | MatchResult::DoneFile(_) => { + // ignore these + } + MatchResult::Match(message) => { + info!("Matched file {}", message.file_name()); + } + MatchResult::Rewrite(file) => { + // Write the file contents to the output + if let Some(writer) = &mut self.writer { + let mut writer = writer.lock().map_err(|_| anyhow!("Output lock poisoned"))?; + writeln!(writer, "{}", file.rewritten.content)?; + } else { + info!("{}", file.rewritten.content); + } + } + MatchResult::CreateFile(file) => { + // Write the file contents to the output + if let Some(writer) = &mut self.writer { + let mut writer = writer.lock().map_err(|_| anyhow!("Output lock poisoned"))?; + writeln!(writer, "{}", file.rewritten.content)?; + } else { + info!("{}", file.rewritten.content); + } + } + MatchResult::RemoveFile(file) => { + info!("File {} should be removed", file.original.source_file); + } + MatchResult::AnalysisLog(_) => { + // TODO: should this go somewhere else + let formatted = FormattedResult::new(message.clone(), false); + if let Some(formatted) = formatted { + info!("{}", formatted); + } + } + } + + Ok(()) + } + + fn track_accept(&mut self, _accepted: &MatchResult) -> anyhow::Result<()> { + self.total_accepted += 1; + Ok(()) + } + fn track_reject(&mut self, _rejected: &MatchResult) -> anyhow::Result<()> { + self.total_rejected += 1; + Ok(()) + } + fn track_supress(&mut self, _supressed: &MatchResult) -> anyhow::Result<()> { + self.total_supressed += 1; + Ok(()) + } +} diff --git a/crates/cli_bin/tests/apply.rs b/crates/cli_bin/tests/apply.rs index 81a1172bd..2e7429209 100644 --- a/crates/cli_bin/tests/apply.rs +++ b/crates/cli_bin/tests/apply.rs @@ -47,21 +47,6 @@ fn pattern_file_does_not_exist() -> Result<()> { Ok(()) } -#[test] -fn malformed_stdin_input() -> Result<()> { - let mut cmd: Command = get_test_cmd()?; - - let input = r#"{ "pattern_body" : "empty paths" }"#; - - cmd.arg("plumbing").arg("apply"); - cmd.write_stdin(input); - cmd.assert() - .failure() - .stderr(predicate::str::contains("Failed to parse input JSON")); - - Ok(()) -} - #[test] fn empty_paths_array() -> Result<()> { let mut cmd = get_test_cmd()?; @@ -72,12 +57,16 @@ fn empty_paths_array() -> Result<()> { cmd.write_stdin(input); let output = cmd.output()?; + let stdout = String::from_utf8(output.stdout)?; + let stderr = String::from_utf8(output.stderr)?; + println!("stdout: {:?}", stdout); + println!("stderr: {:?}", stderr); + assert!( output.status.success(), "Command didn't finish successfully" ); - let stdout = String::from_utf8(output.stdout)?; let line = stdout.lines().next().ok_or_else(|| anyhow!("No output"))?; let v: serde_json::Value = serde_json::from_str(line)?; @@ -2463,3 +2452,125 @@ fn tty_behavior() -> Result<()> { Ok(()) } + +#[test] +fn apply_stdin() -> Result<()> { + let (_temp_dir, fixture_dir) = get_fixture("limit_files", false)?; + + let input_file = r#" +const foo = bar; +const x = 6; +console.error("nice"); +const w = 6; +console.log("king"); +console.error(w); +"#; + let expected_output = r#" +const foo = bar; +const x = 6; +console.error(foobar); +const w = 6; +console.log("king"); +console.error(foobar); + +"#; + + let mut cmd = get_test_cmd()?; + cmd.arg("apply") + .arg("`console.error($x)` where $x => `foobar`") + .arg("--stdin") + .arg("sample.js") + .current_dir(&fixture_dir); + + cmd.write_stdin(String::from_utf8(input_file.into())?); + + let result = cmd.output()?; + + let stderr = String::from_utf8(result.stderr)?; + println!("stderr: {:?}", stderr); + let stdout = String::from_utf8(result.stdout)?; + println!("stdout: {:?}", stdout); + + // assert + assert!(result.status.success(), "Command failed"); + + // Expect the output to be the same as the expected output + assert_eq!(stdout, expected_output); + + Ok(()) +} + +/// Ensure that we assume the --lang option from the file extension if using stdin +#[test] +fn apply_stdin_autocode() -> Result<()> { + let (_temp_dir, fixture_dir) = get_fixture("limit_files", false)?; + + let input_file = r#" +def cool(name): + print(name) +"#; + let expected_output = r#" +def renamed(name): + print(name) + +"#; + + let mut cmd = get_test_cmd()?; + cmd.arg("apply") + .arg("`def $x($_): $_` where $x => `renamed`") + .arg("--stdin") + .arg("sample.py") + .current_dir(&fixture_dir); + + cmd.write_stdin(String::from_utf8(input_file.into())?); + + let result = cmd.output()?; + + let stderr = String::from_utf8(result.stderr)?; + println!("stderr: {:?}", stderr); + let stdout = String::from_utf8(result.stdout)?; + println!("stdout: {:?}", stdout); + + // assert + assert!(result.status.success(), "Command failed"); + + // Expect the output to be the same as the expected output + assert_eq!(stdout, expected_output); + + Ok(()) +} + +/// Ban multiple stdin paths +#[test] +fn apply_stdin_two_paths() -> Result<()> { + let (_temp_dir, fixture_dir) = get_fixture("limit_files", false)?; + + let input_file = r#" +def cool(name): + print(name) +"#; + + let mut cmd = get_test_cmd()?; + cmd.arg("apply") + .arg("`def $x($_): $_` where $x => `renamed`") + .arg("--stdin") + .arg("sample.py") + .arg("sample2.py") + .current_dir(&fixture_dir); + + cmd.write_stdin(String::from_utf8(input_file.into())?); + + let result = cmd.output()?; + + let stderr = String::from_utf8(result.stderr)?; + println!("stderr: {:?}", stderr); + let stdout = String::from_utf8(result.stdout)?; + println!("stdout: {:?}", stdout); + + // assert + assert!(!result.status.success(), "Command should have failed"); + + assert!(stderr.contains("--stdin")); + + Ok(()) +} diff --git a/crates/core/src/problem.rs b/crates/core/src/problem.rs index 1e860c7b5..4a2f65bda 100644 --- a/crates/core/src/problem.rs +++ b/crates/core/src/problem.rs @@ -282,6 +282,16 @@ impl Problem { results } + pub fn execute_files_streaming( + &self, + files: Vec, + context: &ExecutionContext, + tx: Sender>, + cache: &impl GritCache, + ) { + self.execute_shared(files, context, tx, cache) + } + pub fn execute_paths<'a>( &self, files: Vec<&'a RichPath>,