From 03973bbac1fd3441f5be92d6061f657d0c63f599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20H=C3=B6lting?= <87192362+moritz-hoelting@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:17:08 +0100 Subject: [PATCH] allow passing in parameters to functions that will be used as macros --- src/lexical/token.rs | 53 +++++++------ src/semantic/error.rs | 43 +++++++++++ src/transpile/conversions.rs | 5 +- src/transpile/error.rs | 31 +------- src/transpile/mod.rs | 2 +- src/transpile/transpiler.rs | 140 +++++++++++++++++++++++++++++------ src/transpile/util.rs | 57 ++++++++++++++ 7 files changed, 259 insertions(+), 72 deletions(-) diff --git a/src/lexical/token.rs b/src/lexical/token.rs index 67a187c..9a0d9e1 100644 --- a/src/lexical/token.rs +++ b/src/lexical/token.rs @@ -292,19 +292,37 @@ impl MacroStringLiteral { /// Returns the string content without escapement characters, leading and trailing double quotes. #[must_use] pub fn str_content(&self) -> String { - let span = self.span(); - let string = span.str(); - let string = &string[1..string.len() - 1]; - if string.contains('\\') { - string - .replace("\\n", "\n") - .replace("\\r", "\r") - .replace("\\t", "\t") - .replace("\\\"", "\"") - .replace("\\\\", "\\") - } else { - string.to_string() + use std::fmt::Write; + + let mut content = String::new(); + + for part in &self.parts { + match part { + MacroStringLiteralPart::Text(span) => { + let string = span.str(); + if string.contains('\\') { + content += &string + .replace("\\n", "\n") + .replace("\\r", "\r") + .replace("\\t", "\t") + .replace("\\\"", "\"") + .replace("\\\\", "\\"); + } else { + content += string; + } + } + MacroStringLiteralPart::MacroUsage { identifier, .. } => { + write!( + content, + "$({})", + crate::transpile::util::identifier_to_macro(identifier.span.str()) + ) + .expect("can always write to string"); + } + } } + + content } /// Returns the parts that make up the macro string literal. 
@@ -756,8 +774,7 @@ impl Token { is_inside_macro = false; } else if !encountered_open_parenthesis && character == '(' { encountered_open_parenthesis = true; - } else if encountered_open_parenthesis - && !Self::is_valid_macro_name_character(character) + } else if encountered_open_parenthesis && !Self::is_identifier_character(character) { if character == '`' { return Err(UnclosedMacroUsage { @@ -766,9 +783,7 @@ impl Token { .into()); } - Self::walk_iter(iter, |c| { - c != ')' && !Self::is_valid_macro_name_character(c) - }); + Self::walk_iter(iter, |c| c != ')' && !Self::is_identifier_character(c)); return Err(InvalidMacroNameCharacter { span: Self::create_span(index, iter), } @@ -816,10 +831,6 @@ impl Token { } } - fn is_valid_macro_name_character(character: char) -> bool { - character.is_ascii_alphanumeric() || character == '_' - } - /// Handles a command that is preceeded by a slash fn handle_command_literal(iter: &mut SourceIterator, start: usize) -> Self { Self::walk_iter(iter, |c| !(c.is_whitespace() && c.is_ascii_control())); diff --git a/src/semantic/error.rs b/src/semantic/error.rs index 20cbe6e..9be7db1 100644 --- a/src/semantic/error.rs +++ b/src/semantic/error.rs @@ -241,3 +241,46 @@ impl Display for IncompatibleFunctionAnnotation { } impl std::error::Error for IncompatibleFunctionAnnotation {} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct InvalidFunctionArguments { + pub span: Span, + pub expected: usize, + pub actual: usize, +} + +impl Display for InvalidFunctionArguments { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + Message::new( + Severity::Error, + format!( + "Expected {} arguments, but got {}.", + self.expected, self.actual + ) + ) + )?; + + let help_message = if self.expected > self.actual { + format!( + "You might want to add {} more arguments.", + self.expected - self.actual + ) + } else { + format!( + "You might want to remove {} arguments.", + self.actual - self.expected + ) + }; 
+ + write!( + f, + "\n{}", + SourceCodeDisplay::new(&self.span, Some(help_message)) + ) + } +} + +impl std::error::Error for InvalidFunctionArguments {} diff --git a/src/transpile/conversions.rs b/src/transpile/conversions.rs index 4b35420..f7295e5 100644 --- a/src/transpile/conversions.rs +++ b/src/transpile/conversions.rs @@ -70,7 +70,10 @@ impl From<&MacroStringLiteral> for MacroString { MacroStringPart::String(span.str().to_string()) } MacroStringLiteralPart::MacroUsage { identifier, .. } => { - MacroStringPart::MacroUsage(identifier.span.str().to_string()) + MacroStringPart::MacroUsage( + crate::transpile::util::identifier_to_macro(identifier.span.str()) + .to_string(), + ) } }) .collect(), diff --git a/src/transpile/error.rs b/src/transpile/error.rs index 7fb9d16..4b2c7f7 100644 --- a/src/transpile/error.rs +++ b/src/transpile/error.rs @@ -10,7 +10,7 @@ use crate::{ log::{Message, Severity, SourceCodeDisplay}, source_file::Span, }, - semantic::error::UnexpectedExpression, + semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression}, }; use super::transpiler::FunctionData; @@ -29,6 +29,8 @@ pub enum TranspileError { LuaRuntimeError(#[from] LuaRuntimeError), #[error(transparent)] ConflictingFunctionNames(#[from] ConflictingFunctionNames), + #[error(transparent)] + InvalidFunctionArguments(#[from] InvalidFunctionArguments), } /// The result of a transpilation operation. 
@@ -143,30 +145,3 @@ impl LuaRuntimeError { } } } - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ConflictingFunctionNames { - pub definition: Span, - pub name: String, -} - -impl Display for ConflictingFunctionNames { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - Message::new( - Severity::Error, - format!("the following function declaration conflicts with an existing function with name `{}`", self.name) - ) - )?; - - write!( - f, - "\n{}", - SourceCodeDisplay::new(&self.definition, Option::::None) - ) - } -} - -impl std::error::Error for ConflictingFunctionNames {} diff --git a/src/transpile/mod.rs b/src/transpile/mod.rs index 0ef8314..6f06dd0 100644 --- a/src/transpile/mod.rs +++ b/src/transpile/mod.rs @@ -14,4 +14,4 @@ mod transpiler; #[doc(inline)] pub use transpiler::Transpiler; -mod util; +pub mod util; diff --git a/src/transpile/transpiler.rs b/src/transpile/transpiler.rs index c6e9502..c4905d5 100644 --- a/src/transpile/transpiler.rs +++ b/src/transpile/transpiler.rs @@ -14,7 +14,7 @@ use crate::{ source_file::{SourceElement, Span}, Handler, }, - semantic::error::UnexpectedExpression, + semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression}, syntax::syntax_tree::{ declaration::{Declaration, ImportItems}, expression::{Expression, FunctionCall, Primary}, @@ -24,7 +24,7 @@ use crate::{ Statement, }, }, - transpile::error::{ConflictingFunctionNames, TranspileMissingFunctionDeclaration}, + transpile::error::TranspileMissingFunctionDeclaration, }; use super::error::{TranspileError, TranspileResult}; @@ -103,7 +103,7 @@ impl Transpiler { ); for identifier_span in always_transpile_functions { - self.get_or_transpile_function(&identifier_span, handler)?; + self.get_or_transpile_function(&identifier_span, None, handler)?; } Ok(()) @@ -221,8 +221,9 @@ impl Transpiler { fn get_or_transpile_function( &mut self, identifier_span: &Span, + arguments: Option<&[&Expression]>, 
handler: &impl Handler, - ) -> TranspileResult { + ) -> TranspileResult<(String, Option>)> { let program_identifier = identifier_span.source_file().identifier(); let program_query = ( program_identifier.to_string(), @@ -337,21 +338,96 @@ impl Transpiler { ); } - let locations = self.function_locations.read().unwrap(); - locations - .get(&program_query) - .or_else(|| alias_query.and_then(|q| locations.get(&q).filter(|(_, p)| *p))) - .ok_or_else(|| { - let error = TranspileError::MissingFunctionDeclaration( - TranspileMissingFunctionDeclaration::from_context( - identifier_span.clone(), - &self.functions.read().unwrap(), - ), - ); - handler.receive(error.clone()); - error - }) - .map(|(s, _)| s.to_owned()) + let parameters = { + let functions = self.functions.read().unwrap(); + let function_data = functions + .get(&program_query) + .or_else(|| { + alias_query + .clone() + .and_then(|q| functions.get(&q).filter(|f| f.public)) + }) + .ok_or_else(|| { + let error = TranspileError::MissingFunctionDeclaration( + TranspileMissingFunctionDeclaration::from_context( + identifier_span.clone(), + &functions, + ), + ); + handler.receive(error.clone()); + error + })?; + + function_data.parameters.clone() + }; + + let function_location = { + let locations = self.function_locations.read().unwrap(); + locations + .get(&program_query) + .or_else(|| alias_query.and_then(|q| locations.get(&q).filter(|(_, p)| *p))) + .ok_or_else(|| { + let error = TranspileError::MissingFunctionDeclaration( + TranspileMissingFunctionDeclaration::from_context( + identifier_span.clone(), + &self.functions.read().unwrap(), + ), + ); + handler.receive(error.clone()); + error + }) + .map(|(s, _)| s.to_owned())? 
+ }; + + let arg_count = arguments.iter().flat_map(|x| x.iter()).count(); + if arg_count != parameters.len() { + let err = TranspileError::InvalidFunctionArguments(InvalidFunctionArguments { + expected: parameters.len(), + actual: arg_count, + span: identifier_span.clone(), + }); + handler.receive(err.clone()); + Err(err) + } else if arg_count > 0 { + let mut compiled_args = Vec::new(); + let mut errs = Vec::new(); + for expression in arguments.iter().flat_map(|x| x.iter()) { + let value = match expression { + Expression::Primary(Primary::FunctionCall(func)) => self + .transpile_function_call(func, handler) + .map(|cmd| match cmd { + Command::Raw(s) => s, + _ => unreachable!("Function call should always return a raw command"), + }), + Expression::Primary(Primary::Lua(lua)) => { + lua.eval_string(handler).map(Option::unwrap_or_default) + } + Expression::Primary(Primary::StringLiteral(string)) => { + Ok(string.str_content().to_string()) + } + Expression::Primary(Primary::MacroStringLiteral(literal)) => { + Ok(literal.str_content()) + } + }; + + match value { + Ok(value) => { + compiled_args.push(value); + } + Err(err) => { + compiled_args.push(String::new()); + errs.push(err.clone()); + } + } + } + if let Some(err) = errs.first() { + return Err(err.clone()); + } + let function_args = parameters.into_iter().zip(compiled_args).collect(); + Ok((function_location, Some(function_args))) + } else { + Ok((function_location, None)) + } } fn transpile_function( @@ -456,8 +532,30 @@ impl Transpiler { func: &FunctionCall, handler: &impl Handler, ) -> TranspileResult { - let location = self.get_or_transpile_function(&func.identifier().span, handler)?; - Ok(Command::Raw(format!("function {location}"))) + let arguments = func.arguments().as_ref().map(|l| { + l.elements() + .map(derive_more::Deref::deref) + .collect::>() + }); + let (location, arguments) = + self.get_or_transpile_function(&func.identifier().span, arguments.as_deref(), handler)?; + let mut function_call = 
format!("function {location}"); + if let Some(arguments) = arguments { + use std::fmt::Write; + let arguments = arguments + .iter() + .map(|(ident, v)| { + format!( + r#"{macro_name}:"{escaped}""#, + macro_name = super::util::identifier_to_macro(ident), + escaped = super::util::escape_str(v) + ) + }) + .collect::>() + .join(","); + write!(function_call, " {{{arguments}}}").unwrap(); + } + Ok(Command::Raw(function_call)) } fn transpile_execute_block( diff --git a/src/transpile/util.rs b/src/transpile/util.rs index d170c91..8da2518 100644 --- a/src/transpile/util.rs +++ b/src/transpile/util.rs @@ -1,3 +1,8 @@ +//! Utility methods for transpiling + +use chksum_md5 as md5; +use std::borrow::Cow; + fn normalize_program_identifier(identifier: S) -> String where S: AsRef, @@ -19,6 +24,8 @@ where .join("/") } +/// Calculate the identifier to import the function based on the current identifier and the import path +#[must_use] pub fn calculate_import_identifier(current_identifier: S, import_path: T) -> String where S: AsRef, @@ -32,3 +39,53 @@ where normalize_program_identifier(identifier_elements.join("/") + "/" + import_path.as_ref()) } } + +/// Escapes `"` and `\` in a string. +#[must_use] +pub fn escape_str(s: &str) -> Cow { + if s.contains('"') || s.contains('\\') { + let mut escaped = String::with_capacity(s.len()); + for c in s.chars() { + match c { + '"' => escaped.push_str("\\\""), + '\\' => escaped.push_str("\\\\"), + _ => escaped.push(c), + } + } + Cow::Owned(escaped) + } else { + Cow::Borrowed(s) + } +} + +/// Transforms an identifier to a macro name that only contains `a-zA-Z0-9_`. 
+#[must_use]
+pub fn identifier_to_macro(ident: &str) -> Cow<'_, str> {
+    // Rewrite only if some character is neither `_` nor ASCII alphanumeric.
+    if ident
+        .chars()
+        .any(|c| !(c == '_' || c.is_ascii_alphanumeric()))
+    {
+        let new_ident = ident
+            .chars()
+            .filter(|c| *c == '_' || c.is_ascii_alphanumeric())
+            .collect::<String>();
+        // Append the first 8 hex digits of the original identifier's MD5 hash.
+        let chksum = md5::hash(ident).to_hex_lowercase();
+        Cow::Owned(new_ident + "_" + &chksum[..8])
+    } else {
+        Cow::Borrowed(ident)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_escape_str() {
+        assert_eq!(escape_str("Hello, world!"), "Hello, world!");
+        assert_eq!(escape_str(r#"Hello, "world"!"#), r#"Hello, \"world\"!"#);
+        assert_eq!(escape_str(r"Hello, \world\!"), r"Hello, \\world\\!");
+    }
+}