allow passing in parameters to functions that will be used as macros

This commit is contained in:
Moritz Hölting 2024-11-12 14:17:08 +01:00
parent 7e96a43e5f
commit 03973bbac1
7 changed files with 259 additions and 72 deletions

View File

@ -292,20 +292,38 @@ impl MacroStringLiteral {
/// Returns the string content without escapement characters, leading and trailing double quotes.
#[must_use]
pub fn str_content(&self) -> String {
let span = self.span();
use std::fmt::Write;
let mut content = String::new();
for part in &self.parts {
match part {
MacroStringLiteralPart::Text(span) => {
let string = span.str();
let string = &string[1..string.len() - 1];
if string.contains('\\') {
string
content += &string
.replace("\\n", "\n")
.replace("\\r", "\r")
.replace("\\t", "\t")
.replace("\\\"", "\"")
.replace("\\\\", "\\")
.replace("\\\\", "\\");
} else {
string.to_string()
content += string;
}
}
MacroStringLiteralPart::MacroUsage { identifier, .. } => {
write!(
content,
"$({})",
crate::transpile::util::identifier_to_macro(identifier.span.str())
)
.expect("can always write to string");
}
}
}
content
}
/// Returns the parts that make up the macro string literal.
#[must_use]
@ -756,8 +774,7 @@ impl Token {
is_inside_macro = false;
} else if !encountered_open_parenthesis && character == '(' {
encountered_open_parenthesis = true;
} else if encountered_open_parenthesis
&& !Self::is_valid_macro_name_character(character)
} else if encountered_open_parenthesis && !Self::is_identifier_character(character)
{
if character == '`' {
return Err(UnclosedMacroUsage {
@ -766,9 +783,7 @@ impl Token {
.into());
}
Self::walk_iter(iter, |c| {
c != ')' && !Self::is_valid_macro_name_character(c)
});
Self::walk_iter(iter, |c| c != ')' && !Self::is_identifier_character(c));
return Err(InvalidMacroNameCharacter {
span: Self::create_span(index, iter),
}
@ -816,10 +831,6 @@ impl Token {
}
}
fn is_valid_macro_name_character(character: char) -> bool {
character.is_ascii_alphanumeric() || character == '_'
}
/// Handles a command that is preceeded by a slash
fn handle_command_literal(iter: &mut SourceIterator, start: usize) -> Self {
Self::walk_iter(iter, |c| !(c.is_whitespace() && c.is_ascii_control()));

View File

@ -241,3 +241,46 @@ impl Display for IncompatibleFunctionAnnotation {
}
impl std::error::Error for IncompatibleFunctionAnnotation {}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InvalidFunctionArguments {
pub span: Span,
pub expected: usize,
pub actual: usize,
}
impl Display for InvalidFunctionArguments {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
Message::new(
Severity::Error,
format!(
"Expected {} arguments, but got {}.",
self.expected, self.actual
)
)
)?;
let help_message = if self.expected > self.actual {
format!(
"You might want to add {} more arguments.",
self.expected - self.actual
)
} else {
format!(
"You might want to remove {} arguments.",
self.actual - self.expected
)
};
write!(
f,
"\n{}",
SourceCodeDisplay::new(&self.span, Some(help_message))
)
}
}
impl std::error::Error for InvalidFunctionArguments {}

View File

@ -70,7 +70,10 @@ impl From<&MacroStringLiteral> for MacroString {
MacroStringPart::String(span.str().to_string())
}
MacroStringLiteralPart::MacroUsage { identifier, .. } => {
MacroStringPart::MacroUsage(identifier.span.str().to_string())
MacroStringPart::MacroUsage(
crate::transpile::util::identifier_to_macro(identifier.span.str())
.to_string(),
)
}
})
.collect(),

View File

@ -10,7 +10,7 @@ use crate::{
log::{Message, Severity, SourceCodeDisplay},
source_file::Span,
},
semantic::error::UnexpectedExpression,
semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression},
};
use super::transpiler::FunctionData;
@ -29,6 +29,8 @@ pub enum TranspileError {
LuaRuntimeError(#[from] LuaRuntimeError),
#[error(transparent)]
ConflictingFunctionNames(#[from] ConflictingFunctionNames),
#[error(transparent)]
InvalidFunctionArguments(#[from] InvalidFunctionArguments),
}
/// The result of a transpilation operation.
@ -143,30 +145,3 @@ impl LuaRuntimeError {
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictingFunctionNames {
pub definition: Span,
pub name: String,
}
impl Display for ConflictingFunctionNames {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
Message::new(
Severity::Error,
format!("the following function declaration conflicts with an existing function with name `{}`", self.name)
)
)?;
write!(
f,
"\n{}",
SourceCodeDisplay::new(&self.definition, Option::<u8>::None)
)
}
}
impl std::error::Error for ConflictingFunctionNames {}

View File

@ -14,4 +14,4 @@ mod transpiler;
#[doc(inline)]
pub use transpiler::Transpiler;
mod util;
pub mod util;

View File

@ -14,7 +14,7 @@ use crate::{
source_file::{SourceElement, Span},
Handler,
},
semantic::error::UnexpectedExpression,
semantic::error::{ConflictingFunctionNames, InvalidFunctionArguments, UnexpectedExpression},
syntax::syntax_tree::{
declaration::{Declaration, ImportItems},
expression::{Expression, FunctionCall, Primary},
@ -24,7 +24,7 @@ use crate::{
Statement,
},
},
transpile::error::{ConflictingFunctionNames, TranspileMissingFunctionDeclaration},
transpile::error::TranspileMissingFunctionDeclaration,
};
use super::error::{TranspileError, TranspileResult};
@ -103,7 +103,7 @@ impl Transpiler {
);
for identifier_span in always_transpile_functions {
self.get_or_transpile_function(&identifier_span, handler)?;
self.get_or_transpile_function(&identifier_span, None, handler)?;
}
Ok(())
@ -221,8 +221,9 @@ impl Transpiler {
fn get_or_transpile_function(
&mut self,
identifier_span: &Span,
arguments: Option<&[&Expression]>,
handler: &impl Handler<base::Error>,
) -> TranspileResult<String> {
) -> TranspileResult<(String, Option<BTreeMap<String, String>>)> {
let program_identifier = identifier_span.source_file().identifier();
let program_query = (
program_identifier.to_string(),
@ -337,6 +338,30 @@ impl Transpiler {
);
}
let parameters = {
let functions = self.functions.read().unwrap();
let function_data = functions
.get(&program_query)
.or_else(|| {
alias_query
.clone()
.and_then(|q| functions.get(&q).filter(|f| f.public))
})
.ok_or_else(|| {
let error = TranspileError::MissingFunctionDeclaration(
TranspileMissingFunctionDeclaration::from_context(
identifier_span.clone(),
&functions,
),
);
handler.receive(error.clone());
error
})?;
function_data.parameters.clone()
};
let function_location = {
let locations = self.function_locations.read().unwrap();
locations
.get(&program_query)
@ -351,7 +376,58 @@ impl Transpiler {
handler.receive(error.clone());
error
})
.map(|(s, _)| s.to_owned())
.map(|(s, _)| s.to_owned())?
};
let arg_count = arguments.iter().flat_map(|x| x.iter()).count();
if arg_count != parameters.len() {
let err = TranspileError::InvalidFunctionArguments(InvalidFunctionArguments {
expected: parameters.len(),
actual: arg_count,
span: identifier_span.clone(),
});
handler.receive(err.clone());
Err(err)
} else if arg_count > 0 {
let mut compiled_args = Vec::new();
let mut errs = Vec::new();
for expression in arguments.iter().flat_map(|x| x.iter()) {
let value = match expression {
Expression::Primary(Primary::FunctionCall(func)) => self
.transpile_function_call(func, handler)
.map(|cmd| match cmd {
Command::Raw(s) => s,
_ => unreachable!("Function call should always return a raw command"),
}),
Expression::Primary(Primary::Lua(lua)) => {
lua.eval_string(handler).map(Option::unwrap_or_default)
}
Expression::Primary(Primary::StringLiteral(string)) => {
Ok(string.str_content().to_string())
}
Expression::Primary(Primary::MacroStringLiteral(literal)) => {
Ok(literal.str_content())
}
};
match value {
Ok(value) => {
compiled_args.push(value);
}
Err(err) => {
compiled_args.push(String::new());
errs.push(err.clone());
}
}
}
if let Some(err) = errs.first() {
return Err(err.clone());
}
let function_args = parameters.into_iter().zip(compiled_args).collect();
Ok((function_location, Some(function_args)))
} else {
Ok((function_location, None))
}
}
fn transpile_function(
@ -456,8 +532,30 @@ impl Transpiler {
func: &FunctionCall,
handler: &impl Handler<base::Error>,
) -> TranspileResult<Command> {
let location = self.get_or_transpile_function(&func.identifier().span, handler)?;
Ok(Command::Raw(format!("function {location}")))
let arguments = func.arguments().as_ref().map(|l| {
l.elements()
.map(derive_more::Deref::deref)
.collect::<Vec<_>>()
});
let (location, arguments) =
self.get_or_transpile_function(&func.identifier().span, arguments.as_deref(), handler)?;
let mut function_call = format!("function {location}");
if let Some(arguments) = arguments {
use std::fmt::Write;
let arguments = arguments
.iter()
.map(|(ident, v)| {
format!(
r#"{macro_name}:"{escaped}""#,
macro_name = super::util::identifier_to_macro(ident),
escaped = super::util::escape_str(v)
)
})
.collect::<Vec<_>>()
.join(",");
write!(function_call, " {{{arguments}}}").unwrap();
}
Ok(Command::Raw(function_call))
}
fn transpile_execute_block(

View File

@ -1,3 +1,8 @@
//! Utility methods for transpiling
use chksum_md5 as md5;
use std::borrow::Cow;
fn normalize_program_identifier<S>(identifier: S) -> String
where
S: AsRef<str>,
@ -19,6 +24,8 @@ where
.join("/")
}
/// Calculate the identifier to import the function based on the current identifier and the import path
#[must_use]
pub fn calculate_import_identifier<S, T>(current_identifier: S, import_path: T) -> String
where
S: AsRef<str>,
@ -32,3 +39,53 @@ where
normalize_program_identifier(identifier_elements.join("/") + "/" + import_path.as_ref())
}
}
/// Escapes `"` and `\` in a string.
#[must_use]
pub fn escape_str(s: &str) -> Cow<str> {
if s.contains('"') || s.contains('\\') {
let mut escaped = String::with_capacity(s.len());
for c in s.chars() {
match c {
'"' => escaped.push_str("\\\""),
'\\' => escaped.push_str("\\\\"),
_ => escaped.push(c),
}
}
Cow::Owned(escaped)
} else {
Cow::Borrowed(s)
}
}
/// Transforms an identifier to a macro name that only contains `a-zA-Z0-9_`.
#[must_use]
pub fn identifier_to_macro(ident: &str) -> Cow<str> {
if ident
.chars()
.any(|c| !(c == '_' && c.is_ascii_alphanumeric()))
{
let new_ident = ident
.chars()
.filter(|c| *c == '_' || c.is_ascii_alphanumeric())
.collect::<String>();
let chksum = md5::hash(ident).to_hex_lowercase();
Cow::Owned(new_ident + "_" + &chksum[..8])
} else {
Cow::Borrowed(ident)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_escape_str() {
assert_eq!(escape_str("Hello, world!"), "Hello, world!");
assert_eq!(escape_str(r#"Hello, "world"!"#), r#"Hello, \"world\"!"#);
assert_eq!(escape_str(r"Hello, \world\!"), r"Hello, \\world\\!");
}
}