diff --git a/rpc-toolkit-macro-internals/src/command/build.rs b/rpc-toolkit-macro-internals/src/command/build.rs index 3f5c5ea..477c790 100644 --- a/rpc-toolkit-macro-internals/src/command/build.rs +++ b/rpc-toolkit-macro-internals/src/command/build.rs @@ -796,10 +796,6 @@ fn cli_handler( .unwrap_or_else(|| LitStr::new(&name.to_string(), name.span())); let field_name = Ident::new(&format!("arg_{}", name), name.span()); let ty = arg.ty.clone(); - let mut ty = quote! { #ty }; - if arg.default.is_some() && !arg.optional { - ty = quote! { Option<#ty> }; - } arg_def.push(quote! { #[serde(rename = #rename)] #field_name: #ty, @@ -849,7 +845,7 @@ fn cli_handler( ::rpc_toolkit::command_helpers::prelude::default_arg_parser(arg_val, matches) } }; - if arg.optional || arg.default.is_some() { + if arg.optional { quote! { #field_name: if let Some(arg_val) = matches.value_of(#arg_name) { Some(#parse_val?) @@ -857,6 +853,28 @@ fn cli_handler( None }, } + } else if let Some(default) = &arg.default { + if let Some(default) = default { + let path: Path = match syn::parse_str(&default.value()) { + Ok(a) => a, + Err(e) => return e.into_compile_error(), + }; + quote! { + #field_name: if let Some(arg_val) = matches.value_of(#arg_name) { + #parse_val? + } else { + #path() + }, + } + } else { + quote! { + #field_name: if let Some(arg_val) = matches.value_of(#arg_name) { + #parse_val? + } else { + Default::default() + }, + } + } } else if arg.multiple.is_some() { quote! { #field_name: matches.values_of(#arg_name).iter().flatten().map(|arg_val| #parse_val).collect::>()?, diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index 9991c0b..c3d59b1 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -42,6 +42,10 @@ impl Context for AppState { } } +fn test_string() -> String { + "test".to_owned() +} + #[command( about = "Does the thing", subcommands("dothething2::", self(dothething_impl(async))) @@ -52,7 +56,7 @@ async fn dothething< >( #[context] _ctx: AppState, #[arg(short = 'a')] arg1: Option, - #[arg(short = 'b')] val: String, + #[arg(short = 'b', default = "test_string")] val: String, #[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool, #[arg(stdin)] structured: U, ) -> Result<(Option, String, bool, U), RpcError> { @@ -147,8 +151,8 @@ async fn test_rpc() { .arg("--exact") .arg("--nocapture") .arg("--") - .arg("-b") - .arg("test") + // .arg("-b") + // .arg("test") .arg("dothething2") .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped())