allow passing in parameters to functions that will be used as macros

This commit is contained in:
Moritz Hölting 2024-11-12 14:17:08 +01:00
parent 7e96a43e5f
commit 03973bbac1
7 changed files with 259 additions and 72 deletions

View File

@ -292,19 +292,37 @@ impl MacroStringLiteral {
/// Returns the string content without escapement characters, leading and trailing double quotes. /// Returns the string content without escapement characters, leading and trailing double quotes.
#[must_use] #[must_use]
pub fn str_content(&self) -> String { pub fn str_content(&self) -> String {
let span = self.span(); use std::fmt::Write;
let string = span.str();
let string = &string[1..string.len() - 1]; let mut content = String::new();
if string.contains('\\') {
string for part in &self.parts {
.replace("\\n", "\n") match part {
.replace("\\r", "\r") MacroStringLiteralPart::Text(span) => {
.replace("\\t", "\t") let string = span.str();
.replace("\\\"", "\"") if string.contains('\\') {
.replace("\\\\", "\\") content += &string
} else { .replace("\\n", "\n")
string.to_string() .replace("\\r", "\r")
.replace("\\t", "\t")
.replace("\\\"", "\"")
.replace("\\\\", "\\");
} else {
content += string;
}
}
MacroStringLiteralPart::MacroUsage { identifier, .. } => {
write!(
content,
"$({})",
crate::transpile::util::identifier_to_macro(identifier.span.str())
)
.expect("can always write to string");
}
}
} }
content
} }
/// Returns the parts that make up the macro string literal. /// Returns the parts that make up the macro string literal.
@ -756,8 +774,7 @@ impl Token {
is_inside_macro = false; is_inside_macro = false;
} else if !encountered_open_parenthesis && character == '(' { } else if !encountered_open_parenthesis && character == '(' {
encountered_open_parenthesis = true; encountered_open_parenthesis = true;
} else if encountered_open_parenthesis } else if encountered_open_parenthesis && !Self::is_identifier_character(character)
&& !Self::is_valid_macro_name_character(character)
{ {
if character == '`' { if character == '`' {
return Err(UnclosedMacroUsage { return Err(UnclosedMacroUsage {
@ -766,9 +783,7 @@ impl Token {
.into()); .into());
} }
Self::walk_iter(iter, |c| { Self::walk_iter(iter, |c| c != ')' && !Self::is_identifier_character(c));
c != ')' && !Self::is_valid_macro_name_character(c)
});
return Err(InvalidMacroNameCharacter { return Err(InvalidMacroNameCharacter {
span: Self::create_span(index, iter), span: Self::create_span(index, iter),
} }
@ -816,10 +831,6 @@ impl Token {
} }
} }
fn is_valid_macro_name_character(character: char) -> bool {
character.is_ascii_alphanumeric() || character == '_'
}
/// Handles a command that is preceeded by a slash /// Handles a command that is preceeded by a slash
fn handle_command_literal(iter: &mut SourceIterator, start: usize) -> Self { fn handle_command_literal(iter: &mut SourceIterator, start: usize) -> Self {
Self::walk_iter(iter, |c| !(c.is_whitespace() && c.is_ascii_control())); Self::walk_iter(iter, |c| !(c.is_whitespace() && c.is_ascii_control()));

View File

@ -241,3 +241,46 @@ impl Display for IncompatibleFunctionAnnotation {
} }
impl std::error::Error for IncompatibleFunctionAnnotation {} impl std::error::Error for IncompatibleFunctionAnnotation {}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InvalidFunctionArguments {
pub span: Span,
pub expected: usize,
pub actual: usize,
}
impl Display for InvalidFunctionArguments {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
Message::new(
Severity::Error,
format!(
"Expected {} arguments, but got {}.",
self.expected, self.actual
)
)
)?;
let help_message = if self.expected > self.actual {
format!(
"You might want to add {} more arguments.",
self.expected - self.actual
)
} else {
format!(
"You might want to remove {} arguments.",
self.actual - self.expected
)
};
write!(
f,
"\n{}",
SourceCodeDisplay::new(&self.span, Some(help_message))
)
}
}
impl std::error::Error for InvalidFunctionArguments {}

View File

@ -70,7 +70,10 @@ impl From<&MacroStringLiteral> for MacroString {
MacroStringPart::String(span.str().to_string()) MacroStringPart::String(span.str().to_string())
} }
MacroStringLiteralPart::MacroUsage { identifier, .. } => { MacroStringLiteralPart::MacroUsage { identifier, .. } => {
MacroStringPart::MacroUsage(identifier.span.str().to_string()) MacroStringPart::MacroUsage(
crate::transpile::util::identifier_to_macro(identifier.span.str())
.to_string(),
)
} }
}) })
.collect(), .collect(),

View File

@ -10,7 +10,7 @@ use crate::{
log::{Message, Severity, SourceCodeDisplay}, log::{Message, Severity, SourceCodeDisplay},
source_file::Span, source_file::Span,
}, },
semantic::error::UnexpectedExpression, semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression},
}; };
use super::transpiler::FunctionData; use super::transpiler::FunctionData;
@ -29,6 +29,8 @@ pub enum TranspileError {
LuaRuntimeError(#[from] LuaRuntimeError), LuaRuntimeError(#[from] LuaRuntimeError),
#[error(transparent)] #[error(transparent)]
ConflictingFunctionNames(#[from] ConflictingFunctionNames), ConflictingFunctionNames(#[from] ConflictingFunctionNames),
#[error(transparent)]
InvalidFunctionArguments(#[from] InvalidFunctionArguments),
} }
/// The result of a transpilation operation. /// The result of a transpilation operation.
@ -143,30 +145,3 @@ impl LuaRuntimeError {
} }
} }
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictingFunctionNames {
pub definition: Span,
pub name: String,
}
impl Display for ConflictingFunctionNames {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
Message::new(
Severity::Error,
format!("the following function declaration conflicts with an existing function with name `{}`", self.name)
)
)?;
write!(
f,
"\n{}",
SourceCodeDisplay::new(&self.definition, Option::<u8>::None)
)
}
}
impl std::error::Error for ConflictingFunctionNames {}

View File

@ -14,4 +14,4 @@ mod transpiler;
#[doc(inline)] #[doc(inline)]
pub use transpiler::Transpiler; pub use transpiler::Transpiler;
mod util; pub mod util;

View File

@ -14,7 +14,7 @@ use crate::{
source_file::{SourceElement, Span}, source_file::{SourceElement, Span},
Handler, Handler,
}, },
semantic::error::UnexpectedExpression, semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression},
syntax::syntax_tree::{ syntax::syntax_tree::{
declaration::{Declaration, ImportItems}, declaration::{Declaration, ImportItems},
expression::{Expression, FunctionCall, Primary}, expression::{Expression, FunctionCall, Primary},
@ -24,7 +24,7 @@ use crate::{
Statement, Statement,
}, },
}, },
transpile::error::{ConflictingFunctionNames, TranspileMissingFunctionDeclaration}, transpile::error::TranspileMissingFunctionDeclaration,
}; };
use super::error::{TranspileError, TranspileResult}; use super::error::{TranspileError, TranspileResult};
@ -103,7 +103,7 @@ impl Transpiler {
); );
for identifier_span in always_transpile_functions { for identifier_span in always_transpile_functions {
self.get_or_transpile_function(&identifier_span, handler)?; self.get_or_transpile_function(&identifier_span, None, handler)?;
} }
Ok(()) Ok(())
@ -221,8 +221,9 @@ impl Transpiler {
fn get_or_transpile_function( fn get_or_transpile_function(
&mut self, &mut self,
identifier_span: &Span, identifier_span: &Span,
arguments: Option<&[&Expression]>,
handler: &impl Handler<base::Error>, handler: &impl Handler<base::Error>,
) -> TranspileResult<String> { ) -> TranspileResult<(String, Option<BTreeMap<String, String>>)> {
let program_identifier = identifier_span.source_file().identifier(); let program_identifier = identifier_span.source_file().identifier();
let program_query = ( let program_query = (
program_identifier.to_string(), program_identifier.to_string(),
@ -337,21 +338,96 @@ impl Transpiler {
); );
} }
let locations = self.function_locations.read().unwrap(); let parameters = {
locations let functions = self.functions.read().unwrap();
.get(&program_query) let function_data = functions
.or_else(|| alias_query.and_then(|q| locations.get(&q).filter(|(_, p)| *p))) .get(&program_query)
.ok_or_else(|| { .or_else(|| {
let error = TranspileError::MissingFunctionDeclaration( alias_query
TranspileMissingFunctionDeclaration::from_context( .clone()
identifier_span.clone(), .and_then(|q| functions.get(&q).filter(|f| f.public))
&self.functions.read().unwrap(), })
), .ok_or_else(|| {
); let error = TranspileError::MissingFunctionDeclaration(
handler.receive(error.clone()); TranspileMissingFunctionDeclaration::from_context(
error identifier_span.clone(),
}) &functions,
.map(|(s, _)| s.to_owned()) ),
);
handler.receive(error.clone());
error
})?;
function_data.parameters.clone()
};
let function_location = {
let locations = self.function_locations.read().unwrap();
locations
.get(&program_query)
.or_else(|| alias_query.and_then(|q| locations.get(&q).filter(|(_, p)| *p)))
.ok_or_else(|| {
let error = TranspileError::MissingFunctionDeclaration(
TranspileMissingFunctionDeclaration::from_context(
identifier_span.clone(),
&self.functions.read().unwrap(),
),
);
handler.receive(error.clone());
error
})
.map(|(s, _)| s.to_owned())?
};
let arg_count = arguments.iter().flat_map(|x| x.iter()).count();
if arg_count != parameters.len() {
let err = TranspileError::InvalidFunctionArguments(InvalidFunctionArguments {
expected: parameters.len(),
actual: arg_count,
span: identifier_span.clone(),
});
handler.receive(err.clone());
Err(err)
} else if arg_count > 0 {
let mut compiled_args = Vec::new();
let mut errs = Vec::new();
for expression in arguments.iter().flat_map(|x| x.iter()) {
let value = match expression {
Expression::Primary(Primary::FunctionCall(func)) => self
.transpile_function_call(func, handler)
.map(|cmd| match cmd {
Command::Raw(s) => s,
_ => unreachable!("Function call should always return a raw command"),
}),
Expression::Primary(Primary::Lua(lua)) => {
lua.eval_string(handler).map(Option::unwrap_or_default)
}
Expression::Primary(Primary::StringLiteral(string)) => {
Ok(string.str_content().to_string())
}
Expression::Primary(Primary::MacroStringLiteral(literal)) => {
Ok(literal.str_content())
}
};
match value {
Ok(value) => {
compiled_args.push(value);
}
Err(err) => {
compiled_args.push(String::new());
errs.push(err.clone());
}
}
}
if let Some(err) = errs.first() {
return Err(err.clone());
}
let function_args = parameters.into_iter().zip(compiled_args).collect();
Ok((function_location, Some(function_args)))
} else {
Ok((function_location, None))
}
} }
fn transpile_function( fn transpile_function(
@ -456,8 +532,30 @@ impl Transpiler {
func: &FunctionCall, func: &FunctionCall,
handler: &impl Handler<base::Error>, handler: &impl Handler<base::Error>,
) -> TranspileResult<Command> { ) -> TranspileResult<Command> {
let location = self.get_or_transpile_function(&func.identifier().span, handler)?; let arguments = func.arguments().as_ref().map(|l| {
Ok(Command::Raw(format!("function {location}"))) l.elements()
.map(derive_more::Deref::deref)
.collect::<Vec<_>>()
});
let (location, arguments) =
self.get_or_transpile_function(&func.identifier().span, arguments.as_deref(), handler)?;
let mut function_call = format!("function {location}");
if let Some(arguments) = arguments {
use std::fmt::Write;
let arguments = arguments
.iter()
.map(|(ident, v)| {
format!(
r#"{macro_name}:"{escaped}""#,
macro_name = super::util::identifier_to_macro(ident),
escaped = super::util::escape_str(v)
)
})
.collect::<Vec<_>>()
.join(",");
write!(function_call, " {{{arguments}}}").unwrap();
}
Ok(Command::Raw(function_call))
} }
fn transpile_execute_block( fn transpile_execute_block(

View File

@ -1,3 +1,8 @@
//! Utility methods for transpiling
use chksum_md5 as md5;
use std::borrow::Cow;
fn normalize_program_identifier<S>(identifier: S) -> String fn normalize_program_identifier<S>(identifier: S) -> String
where where
S: AsRef<str>, S: AsRef<str>,
@ -19,6 +24,8 @@ where
.join("/") .join("/")
} }
/// Calculate the identifier to import the function based on the current identifier and the import path
#[must_use]
pub fn calculate_import_identifier<S, T>(current_identifier: S, import_path: T) -> String pub fn calculate_import_identifier<S, T>(current_identifier: S, import_path: T) -> String
where where
S: AsRef<str>, S: AsRef<str>,
@ -32,3 +39,53 @@ where
normalize_program_identifier(identifier_elements.join("/") + "/" + import_path.as_ref()) normalize_program_identifier(identifier_elements.join("/") + "/" + import_path.as_ref())
} }
} }
/// Escapes `"` and `\` in a string.
#[must_use]
pub fn escape_str(s: &str) -> Cow<str> {
if s.contains('"') || s.contains('\\') {
let mut escaped = String::with_capacity(s.len());
for c in s.chars() {
match c {
'"' => escaped.push_str("\\\""),
'\\' => escaped.push_str("\\\\"),
_ => escaped.push(c),
}
}
Cow::Owned(escaped)
} else {
Cow::Borrowed(s)
}
}
/// Transforms an identifier to a macro name that only contains `a-zA-Z0-9_`.
#[must_use]
pub fn identifier_to_macro(ident: &str) -> Cow<str> {
if ident
.chars()
.any(|c| !(c == '_' && c.is_ascii_alphanumeric()))
{
let new_ident = ident
.chars()
.filter(|c| *c == '_' || c.is_ascii_alphanumeric())
.collect::<String>();
let chksum = md5::hash(ident).to_hex_lowercase();
Cow::Owned(new_ident + "_" + &chksum[..8])
} else {
Cow::Borrowed(ident)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_escape_str() {
assert_eq!(escape_str("Hello, world!"), "Hello, world!");
assert_eq!(escape_str(r#"Hello, "world"!"#), r#"Hello, \"world\"!"#);
assert_eq!(escape_str(r"Hello, \world\!"), r"Hello, \\world\\!");
}
}