allow passing in parameters to functions that will be used as macros
This commit is contained in:
		
							parent
							
								
									7e96a43e5f
								
							
						
					
					
						commit
						03973bbac1
					
				|  | @ -292,20 +292,38 @@ impl MacroStringLiteral { | ||||||
|     /// Returns the string content without escapement characters, leading and trailing double quotes.
 |     /// Returns the string content without escapement characters, leading and trailing double quotes.
 | ||||||
|     #[must_use] |     #[must_use] | ||||||
|     pub fn str_content(&self) -> String { |     pub fn str_content(&self) -> String { | ||||||
|         let span = self.span(); |         use std::fmt::Write; | ||||||
|  | 
 | ||||||
|  |         let mut content = String::new(); | ||||||
|  | 
 | ||||||
|  |         for part in &self.parts { | ||||||
|  |             match part { | ||||||
|  |                 MacroStringLiteralPart::Text(span) => { | ||||||
|                     let string = span.str(); |                     let string = span.str(); | ||||||
|         let string = &string[1..string.len() - 1]; |  | ||||||
|                     if string.contains('\\') { |                     if string.contains('\\') { | ||||||
|             string |                         content += &string | ||||||
|                             .replace("\\n", "\n") |                             .replace("\\n", "\n") | ||||||
|                             .replace("\\r", "\r") |                             .replace("\\r", "\r") | ||||||
|                             .replace("\\t", "\t") |                             .replace("\\t", "\t") | ||||||
|                             .replace("\\\"", "\"") |                             .replace("\\\"", "\"") | ||||||
|                 .replace("\\\\", "\\") |                             .replace("\\\\", "\\"); | ||||||
|                     } else { |                     } else { | ||||||
|             string.to_string() |                         content += string; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  |                 MacroStringLiteralPart::MacroUsage { identifier, .. } => { | ||||||
|  |                     write!( | ||||||
|  |                         content, | ||||||
|  |                         "$({})", | ||||||
|  |                         crate::transpile::util::identifier_to_macro(identifier.span.str()) | ||||||
|  |                     ) | ||||||
|  |                     .expect("can always write to string"); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         content | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns the parts that make up the macro string literal.
 |     /// Returns the parts that make up the macro string literal.
 | ||||||
|     #[must_use] |     #[must_use] | ||||||
|  | @ -756,8 +774,7 @@ impl Token { | ||||||
|                     is_inside_macro = false; |                     is_inside_macro = false; | ||||||
|                 } else if !encountered_open_parenthesis && character == '(' { |                 } else if !encountered_open_parenthesis && character == '(' { | ||||||
|                     encountered_open_parenthesis = true; |                     encountered_open_parenthesis = true; | ||||||
|                 } else if encountered_open_parenthesis |                 } else if encountered_open_parenthesis && !Self::is_identifier_character(character) | ||||||
|                     && !Self::is_valid_macro_name_character(character) |  | ||||||
|                 { |                 { | ||||||
|                     if character == '`' { |                     if character == '`' { | ||||||
|                         return Err(UnclosedMacroUsage { |                         return Err(UnclosedMacroUsage { | ||||||
|  | @ -766,9 +783,7 @@ impl Token { | ||||||
|                         .into()); |                         .into()); | ||||||
|                     } |                     } | ||||||
| 
 | 
 | ||||||
|                     Self::walk_iter(iter, |c| { |                     Self::walk_iter(iter, |c| c != ')' && !Self::is_identifier_character(c)); | ||||||
|                         c != ')' && !Self::is_valid_macro_name_character(c) |  | ||||||
|                     }); |  | ||||||
|                     return Err(InvalidMacroNameCharacter { |                     return Err(InvalidMacroNameCharacter { | ||||||
|                         span: Self::create_span(index, iter), |                         span: Self::create_span(index, iter), | ||||||
|                     } |                     } | ||||||
|  | @ -816,10 +831,6 @@ impl Token { | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn is_valid_macro_name_character(character: char) -> bool { |  | ||||||
|         character.is_ascii_alphanumeric() || character == '_' |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /// Handles a command that is preceeded by a slash
 |     /// Handles a command that is preceeded by a slash
 | ||||||
|     fn handle_command_literal(iter: &mut SourceIterator, start: usize) -> Self { |     fn handle_command_literal(iter: &mut SourceIterator, start: usize) -> Self { | ||||||
|         Self::walk_iter(iter, |c| !(c.is_whitespace() && c.is_ascii_control())); |         Self::walk_iter(iter, |c| !(c.is_whitespace() && c.is_ascii_control())); | ||||||
|  |  | ||||||
|  | @ -241,3 +241,46 @@ impl Display for IncompatibleFunctionAnnotation { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl std::error::Error for IncompatibleFunctionAnnotation {} | impl std::error::Error for IncompatibleFunctionAnnotation {} | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone, PartialEq, Eq, Hash)] | ||||||
|  | pub struct InvalidFunctionArguments { | ||||||
|  |     pub span: Span, | ||||||
|  |     pub expected: usize, | ||||||
|  |     pub actual: usize, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Display for InvalidFunctionArguments { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         write!( | ||||||
|  |             f, | ||||||
|  |             "{}", | ||||||
|  |             Message::new( | ||||||
|  |                 Severity::Error, | ||||||
|  |                 format!( | ||||||
|  |                     "Expected {} arguments, but got {}.", | ||||||
|  |                     self.expected, self.actual | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         )?; | ||||||
|  | 
 | ||||||
|  |         let help_message = if self.expected > self.actual { | ||||||
|  |             format!( | ||||||
|  |                 "You might want to add {} more arguments.", | ||||||
|  |                 self.expected - self.actual | ||||||
|  |             ) | ||||||
|  |         } else { | ||||||
|  |             format!( | ||||||
|  |                 "You might want to remove {} arguments.", | ||||||
|  |                 self.actual - self.expected | ||||||
|  |             ) | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         write!( | ||||||
|  |             f, | ||||||
|  |             "\n{}", | ||||||
|  |             SourceCodeDisplay::new(&self.span, Some(help_message)) | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::error::Error for InvalidFunctionArguments {} | ||||||
|  |  | ||||||
|  | @ -70,7 +70,10 @@ impl From<&MacroStringLiteral> for MacroString { | ||||||
|                             MacroStringPart::String(span.str().to_string()) |                             MacroStringPart::String(span.str().to_string()) | ||||||
|                         } |                         } | ||||||
|                         MacroStringLiteralPart::MacroUsage { identifier, .. } => { |                         MacroStringLiteralPart::MacroUsage { identifier, .. } => { | ||||||
|                             MacroStringPart::MacroUsage(identifier.span.str().to_string()) |                             MacroStringPart::MacroUsage( | ||||||
|  |                                 crate::transpile::util::identifier_to_macro(identifier.span.str()) | ||||||
|  |                                     .to_string(), | ||||||
|  |                             ) | ||||||
|                         } |                         } | ||||||
|                     }) |                     }) | ||||||
|                     .collect(), |                     .collect(), | ||||||
|  |  | ||||||
|  | @ -10,7 +10,7 @@ use crate::{ | ||||||
|         log::{Message, Severity, SourceCodeDisplay}, |         log::{Message, Severity, SourceCodeDisplay}, | ||||||
|         source_file::Span, |         source_file::Span, | ||||||
|     }, |     }, | ||||||
|     semantic::error::UnexpectedExpression, |     semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression}, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| use super::transpiler::FunctionData; | use super::transpiler::FunctionData; | ||||||
|  | @ -29,6 +29,8 @@ pub enum TranspileError { | ||||||
|     LuaRuntimeError(#[from] LuaRuntimeError), |     LuaRuntimeError(#[from] LuaRuntimeError), | ||||||
|     #[error(transparent)] |     #[error(transparent)] | ||||||
|     ConflictingFunctionNames(#[from] ConflictingFunctionNames), |     ConflictingFunctionNames(#[from] ConflictingFunctionNames), | ||||||
|  |     #[error(transparent)] | ||||||
|  |     InvalidFunctionArguments(#[from] InvalidFunctionArguments), | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// The result of a transpilation operation.
 | /// The result of a transpilation operation.
 | ||||||
|  | @ -143,30 +145,3 @@ impl LuaRuntimeError { | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 |  | ||||||
| #[derive(Debug, Clone, PartialEq, Eq)] |  | ||||||
| pub struct ConflictingFunctionNames { |  | ||||||
|     pub definition: Span, |  | ||||||
|     pub name: String, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Display for ConflictingFunctionNames { |  | ||||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |  | ||||||
|         write!( |  | ||||||
|             f, |  | ||||||
|             "{}", |  | ||||||
|             Message::new( |  | ||||||
|                 Severity::Error, |  | ||||||
|                 format!("the following function declaration conflicts with an existing function with name `{}`", self.name) |  | ||||||
|             ) |  | ||||||
|         )?; |  | ||||||
| 
 |  | ||||||
|         write!( |  | ||||||
|             f, |  | ||||||
|             "\n{}", |  | ||||||
|             SourceCodeDisplay::new(&self.definition, Option::<u8>::None) |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl std::error::Error for ConflictingFunctionNames {} |  | ||||||
|  |  | ||||||
|  | @ -14,4 +14,4 @@ mod transpiler; | ||||||
| #[doc(inline)] | #[doc(inline)] | ||||||
| pub use transpiler::Transpiler; | pub use transpiler::Transpiler; | ||||||
| 
 | 
 | ||||||
| mod util; | pub mod util; | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ use crate::{ | ||||||
|         source_file::{SourceElement, Span}, |         source_file::{SourceElement, Span}, | ||||||
|         Handler, |         Handler, | ||||||
|     }, |     }, | ||||||
|     semantic::error::UnexpectedExpression, |     semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression}, | ||||||
|     syntax::syntax_tree::{ |     syntax::syntax_tree::{ | ||||||
|         declaration::{Declaration, ImportItems}, |         declaration::{Declaration, ImportItems}, | ||||||
|         expression::{Expression, FunctionCall, Primary}, |         expression::{Expression, FunctionCall, Primary}, | ||||||
|  | @ -24,7 +24,7 @@ use crate::{ | ||||||
|             Statement, |             Statement, | ||||||
|         }, |         }, | ||||||
|     }, |     }, | ||||||
|     transpile::error::{ConflictingFunctionNames, TranspileMissingFunctionDeclaration}, |     transpile::error::TranspileMissingFunctionDeclaration, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| use super::error::{TranspileError, TranspileResult}; | use super::error::{TranspileError, TranspileResult}; | ||||||
|  | @ -103,7 +103,7 @@ impl Transpiler { | ||||||
|         ); |         ); | ||||||
| 
 | 
 | ||||||
|         for identifier_span in always_transpile_functions { |         for identifier_span in always_transpile_functions { | ||||||
|             self.get_or_transpile_function(&identifier_span, handler)?; |             self.get_or_transpile_function(&identifier_span, None, handler)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -221,8 +221,9 @@ impl Transpiler { | ||||||
|     fn get_or_transpile_function( |     fn get_or_transpile_function( | ||||||
|         &mut self, |         &mut self, | ||||||
|         identifier_span: &Span, |         identifier_span: &Span, | ||||||
|  |         arguments: Option<&[&Expression]>, | ||||||
|         handler: &impl Handler<base::Error>, |         handler: &impl Handler<base::Error>, | ||||||
|     ) -> TranspileResult<String> { |     ) -> TranspileResult<(String, Option<BTreeMap<String, String>>)> { | ||||||
|         let program_identifier = identifier_span.source_file().identifier(); |         let program_identifier = identifier_span.source_file().identifier(); | ||||||
|         let program_query = ( |         let program_query = ( | ||||||
|             program_identifier.to_string(), |             program_identifier.to_string(), | ||||||
|  | @ -337,6 +338,30 @@ impl Transpiler { | ||||||
|             ); |             ); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         let parameters = { | ||||||
|  |             let functions = self.functions.read().unwrap(); | ||||||
|  |             let function_data = functions | ||||||
|  |                 .get(&program_query) | ||||||
|  |                 .or_else(|| { | ||||||
|  |                     alias_query | ||||||
|  |                         .clone() | ||||||
|  |                         .and_then(|q| functions.get(&q).filter(|f| f.public)) | ||||||
|  |                 }) | ||||||
|  |                 .ok_or_else(|| { | ||||||
|  |                     let error = TranspileError::MissingFunctionDeclaration( | ||||||
|  |                         TranspileMissingFunctionDeclaration::from_context( | ||||||
|  |                             identifier_span.clone(), | ||||||
|  |                             &functions, | ||||||
|  |                         ), | ||||||
|  |                     ); | ||||||
|  |                     handler.receive(error.clone()); | ||||||
|  |                     error | ||||||
|  |                 })?; | ||||||
|  | 
 | ||||||
|  |             function_data.parameters.clone() | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let function_location = { | ||||||
|             let locations = self.function_locations.read().unwrap(); |             let locations = self.function_locations.read().unwrap(); | ||||||
|             locations |             locations | ||||||
|                 .get(&program_query) |                 .get(&program_query) | ||||||
|  | @ -351,7 +376,58 @@ impl Transpiler { | ||||||
|                     handler.receive(error.clone()); |                     handler.receive(error.clone()); | ||||||
|                     error |                     error | ||||||
|                 }) |                 }) | ||||||
|             .map(|(s, _)| s.to_owned()) |                 .map(|(s, _)| s.to_owned())? | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let arg_count = arguments.iter().flat_map(|x| x.iter()).count(); | ||||||
|  |         if arg_count != parameters.len() { | ||||||
|  |             let err = TranspileError::InvalidFunctionArguments(InvalidFunctionArguments { | ||||||
|  |                 expected: parameters.len(), | ||||||
|  |                 actual: arg_count, | ||||||
|  |                 span: identifier_span.clone(), | ||||||
|  |             }); | ||||||
|  |             handler.receive(err.clone()); | ||||||
|  |             Err(err) | ||||||
|  |         } else if arg_count > 0 { | ||||||
|  |             let mut compiled_args = Vec::new(); | ||||||
|  |             let mut errs = Vec::new(); | ||||||
|  |             for expression in arguments.iter().flat_map(|x| x.iter()) { | ||||||
|  |                 let value = match expression { | ||||||
|  |                     Expression::Primary(Primary::FunctionCall(func)) => self | ||||||
|  |                         .transpile_function_call(func, handler) | ||||||
|  |                         .map(|cmd| match cmd { | ||||||
|  |                             Command::Raw(s) => s, | ||||||
|  |                             _ => unreachable!("Function call should always return a raw command"), | ||||||
|  |                         }), | ||||||
|  |                     Expression::Primary(Primary::Lua(lua)) => { | ||||||
|  |                         lua.eval_string(handler).map(Option::unwrap_or_default) | ||||||
|  |                     } | ||||||
|  |                     Expression::Primary(Primary::StringLiteral(string)) => { | ||||||
|  |                         Ok(string.str_content().to_string()) | ||||||
|  |                     } | ||||||
|  |                     Expression::Primary(Primary::MacroStringLiteral(literal)) => { | ||||||
|  |                         Ok(literal.str_content()) | ||||||
|  |                     } | ||||||
|  |                 }; | ||||||
|  | 
 | ||||||
|  |                 match value { | ||||||
|  |                     Ok(value) => { | ||||||
|  |                         compiled_args.push(value); | ||||||
|  |                     } | ||||||
|  |                     Err(err) => { | ||||||
|  |                         compiled_args.push(String::new()); | ||||||
|  |                         errs.push(err.clone()); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             if let Some(err) = errs.first() { | ||||||
|  |                 return Err(err.clone()); | ||||||
|  |             } | ||||||
|  |             let function_args = parameters.into_iter().zip(compiled_args).collect(); | ||||||
|  |             Ok((function_location, Some(function_args))) | ||||||
|  |         } else { | ||||||
|  |             Ok((function_location, None)) | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn transpile_function( |     fn transpile_function( | ||||||
|  | @ -456,8 +532,30 @@ impl Transpiler { | ||||||
|         func: &FunctionCall, |         func: &FunctionCall, | ||||||
|         handler: &impl Handler<base::Error>, |         handler: &impl Handler<base::Error>, | ||||||
|     ) -> TranspileResult<Command> { |     ) -> TranspileResult<Command> { | ||||||
|         let location = self.get_or_transpile_function(&func.identifier().span, handler)?; |         let arguments = func.arguments().as_ref().map(|l| { | ||||||
|         Ok(Command::Raw(format!("function {location}"))) |             l.elements() | ||||||
|  |                 .map(derive_more::Deref::deref) | ||||||
|  |                 .collect::<Vec<_>>() | ||||||
|  |         }); | ||||||
|  |         let (location, arguments) = | ||||||
|  |             self.get_or_transpile_function(&func.identifier().span, arguments.as_deref(), handler)?; | ||||||
|  |         let mut function_call = format!("function {location}"); | ||||||
|  |         if let Some(arguments) = arguments { | ||||||
|  |             use std::fmt::Write; | ||||||
|  |             let arguments = arguments | ||||||
|  |                 .iter() | ||||||
|  |                 .map(|(ident, v)| { | ||||||
|  |                     format!( | ||||||
|  |                         r#"{macro_name}:"{escaped}""#, | ||||||
|  |                         macro_name = super::util::identifier_to_macro(ident), | ||||||
|  |                         escaped = super::util::escape_str(v) | ||||||
|  |                     ) | ||||||
|  |                 }) | ||||||
|  |                 .collect::<Vec<_>>() | ||||||
|  |                 .join(","); | ||||||
|  |             write!(function_call, " {{{arguments}}}").unwrap(); | ||||||
|  |         } | ||||||
|  |         Ok(Command::Raw(function_call)) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn transpile_execute_block( |     fn transpile_execute_block( | ||||||
|  |  | ||||||
|  | @ -1,3 +1,8 @@ | ||||||
|  | //! Utility methods for transpiling
 | ||||||
|  | 
 | ||||||
|  | use chksum_md5 as md5; | ||||||
|  | use std::borrow::Cow; | ||||||
|  | 
 | ||||||
| fn normalize_program_identifier<S>(identifier: S) -> String | fn normalize_program_identifier<S>(identifier: S) -> String | ||||||
| where | where | ||||||
|     S: AsRef<str>, |     S: AsRef<str>, | ||||||
|  | @ -19,6 +24,8 @@ where | ||||||
|         .join("/") |         .join("/") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /// Calculate the identifier to import the function based on the current identifier and the import path
 | ||||||
|  | #[must_use] | ||||||
| pub fn calculate_import_identifier<S, T>(current_identifier: S, import_path: T) -> String | pub fn calculate_import_identifier<S, T>(current_identifier: S, import_path: T) -> String | ||||||
| where | where | ||||||
|     S: AsRef<str>, |     S: AsRef<str>, | ||||||
|  | @ -32,3 +39,53 @@ where | ||||||
|         normalize_program_identifier(identifier_elements.join("/") + "/" + import_path.as_ref()) |         normalize_program_identifier(identifier_elements.join("/") + "/" + import_path.as_ref()) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | /// Escapes `"` and `\` in a string.
 | ||||||
|  | #[must_use] | ||||||
|  | pub fn escape_str(s: &str) -> Cow<str> { | ||||||
|  |     if s.contains('"') || s.contains('\\') { | ||||||
|  |         let mut escaped = String::with_capacity(s.len()); | ||||||
|  |         for c in s.chars() { | ||||||
|  |             match c { | ||||||
|  |                 '"' => escaped.push_str("\\\""), | ||||||
|  |                 '\\' => escaped.push_str("\\\\"), | ||||||
|  |                 _ => escaped.push(c), | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Cow::Owned(escaped) | ||||||
|  |     } else { | ||||||
|  |         Cow::Borrowed(s) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Transforms an identifier to a macro name that only contains `a-zA-Z0-9_`.
 | ||||||
|  | #[must_use] | ||||||
|  | pub fn identifier_to_macro(ident: &str) -> Cow<str> { | ||||||
|  |     if ident | ||||||
|  |         .chars() | ||||||
|  |         .any(|c| !(c == '_' && c.is_ascii_alphanumeric())) | ||||||
|  |     { | ||||||
|  |         let new_ident = ident | ||||||
|  |             .chars() | ||||||
|  |             .filter(|c| *c == '_' || c.is_ascii_alphanumeric()) | ||||||
|  |             .collect::<String>(); | ||||||
|  | 
 | ||||||
|  |         let chksum = md5::hash(ident).to_hex_lowercase(); | ||||||
|  | 
 | ||||||
|  |         Cow::Owned(new_ident + "_" + &chksum[..8]) | ||||||
|  |     } else { | ||||||
|  |         Cow::Borrowed(ident) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(test)] | ||||||
|  | mod tests { | ||||||
|  |     use super::*; | ||||||
|  | 
 | ||||||
|  |     #[test] | ||||||
|  |     fn test_escape_str() { | ||||||
|  |         assert_eq!(escape_str("Hello, world!"), "Hello, world!"); | ||||||
|  |         assert_eq!(escape_str(r#"Hello, "world"!"#), r#"Hello, \"world\"!"#); | ||||||
|  |         assert_eq!(escape_str(r"Hello, \world\!"), r"Hello, \\world\\!"); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue