diff --git a/src/syntax/syntax_tree/expression.rs b/src/syntax/syntax_tree/expression.rs index e53a7e5..c290185 100644 --- a/src/syntax/syntax_tree/expression.rs +++ b/src/syntax/syntax_tree/expression.rs @@ -349,7 +349,7 @@ pub struct LuaCode { left_parenthesis: Punctuation, /// The arguments of the lua code. #[get = "pub"] - variables: Option>, + inputs: Option>, /// The right parenthesis of the lua code. #[get = "pub"] right_parenthesis: Punctuation, @@ -550,10 +550,7 @@ impl<'a> Parser<'a> { Delimiter::Parenthesis, ',', |parser| match parser.next_significant_token() { - Reading::Atomic(Token::Identifier(identifier)) => { - parser.forward(); - Ok(identifier) - } + Reading::Atomic(Token::Identifier(identifier)) => Ok(identifier), unexpected => { let err = Error::UnexpectedSyntax(UnexpectedSyntax { expected: syntax::error::SyntaxKind::Identifier, @@ -597,7 +594,7 @@ impl<'a> Parser<'a> { Ok(Primary::Lua(Box::new(LuaCode { lua_keyword, left_parenthesis: variables.open, - variables: variables.list, + inputs: variables.list, right_parenthesis: variables.close, left_brace: tree.open, code: tree.tree?, diff --git a/src/transpile/expression.rs b/src/transpile/expression.rs index 95122aa..e6688d6 100644 --- a/src/transpile/expression.rs +++ b/src/transpile/expression.rs @@ -296,7 +296,7 @@ impl Primary { Self::Lua(lua) => { cfg_if::cfg_if! { if #[cfg(feature = "lua")] { - lua.eval(&VoidHandler).map_or(false, |value| match value { + lua.eval(scope, &VoidHandler).map_or(false, |(value, _)| match value { mlua::Value::Boolean(_) => matches!(r#type, ValueType::Boolean), mlua::Value::Integer(_) => matches!(r#type, ValueType::Integer), mlua::Value::String(_) => matches!(r#type, ValueType::String), @@ -345,7 +345,7 @@ impl Primary { }) } Self::Lua(lua) => lua - .eval_comptime(&VoidHandler) + .eval_comptime(scope, &VoidHandler) .inspect_err(|err| { handler.receive(err.clone()); }) @@ -621,7 +621,7 @@ impl Transpiler { Primary::Lua(lua) => { #[expect(clippy::option_if_let_else)] - if let Some(value) = lua.eval_comptime(handler)? { + if let Some(value) = lua.eval_comptime(scope, handler)? { self.store_comptime_value(&value, target, lua, handler) } else { let err = TranspileError::MissingValue(MissingValue { @@ -1033,7 +1033,7 @@ impl Transpiler { Err(err) } }, - Primary::Lua(lua) => match lua.eval_comptime(handler)? { + Primary::Lua(lua) => match lua.eval_comptime(scope, handler)? { Some(ComptimeValue::String(value)) => Ok(( Vec::new(), ExtendedCondition::Runtime(Condition::Atom(value.into())), diff --git a/src/transpile/lua.rs b/src/transpile/lua.rs index c8fbc97..081f668 100644 --- a/src/transpile/lua.rs +++ b/src/transpile/lua.rs @@ -2,14 +2,15 @@ #[cfg(feature = "lua")] mod enabled { + use std::sync::Arc; + use mlua::{Lua, Value}; use crate::{ base::{self, source_file::SourceElement, Handler}, syntax::syntax_tree::expression::LuaCode, transpile::{ - error::{LuaRuntimeError, TranspileError, TranspileResult}, - expression::ComptimeValue, + Scope, VariableData, }, }; @@ -19,7 +20,11 @@ mod enabled { /// # Errors /// - If evaluation fails #[tracing::instrument(level = "debug", name = "eval_lua", skip_all, ret)] - pub fn eval(&self, handler: &impl Handler) -> TranspileResult { + pub fn eval( + &self, + scope: &Arc, + handler: &impl Handler, + ) -> TranspileResult<(mlua::Value, mlua::Lua)> { tracing::debug!("Evaluating Lua code"); let lua = Lua::new(); @@ -46,9 +51,17 @@ mod enabled { ) }; - self.add_globals(&lua).unwrap(); + if let Err(err) = self.add_globals(&lua, scope) { + let err = TranspileError::LuaRuntimeError(LuaRuntimeError::from_lua_err( + &err, + self.span(), + )); + handler.receive(crate::Error::from(err.clone())); + return Err(err); + } - lua.load(self.code()) + let res = lua + .load(self.code()) .set_name(name) .eval::() .map_err(|err| { @@ -56,7 +69,12 @@ mod enabled { TranspileError::from(LuaRuntimeError::from_lua_err(&err, self.span())); handler.receive(crate::Error::from(err.clone())); err - }) + }); + + res.map(|v| { + tracing::debug!("Lua code evaluated successfully"); + (v, lua) + }) } /// Evaluates the Lua code and returns the resulting [`ComptimeValue`]. @@ -66,17 +84,16 @@ mod enabled { #[tracing::instrument(level = "debug", name = "eval_lua", skip_all, ret)] pub fn eval_comptime( &self, + scope: &Arc, handler: &impl Handler, ) -> TranspileResult> { - let lua_result = self.eval(handler)?; + // required to keep the lua instance alive + let (lua_result, _lua) = self.eval(scope, handler)?; self.handle_lua_result(lua_result, handler) - .inspect_err(|err| { - handler.receive(err.clone()); - }) } - fn add_globals(&self, lua: &Lua) -> mlua::Result<()> { + fn add_globals(&self, lua: &Lua, scope: &Arc) -> mlua::Result<()> { let globals = lua.globals(); let location = { @@ -86,6 +103,32 @@ mod enabled { }; globals.set("shu_location", location.to_string_lossy())?; + if let Some(inputs) = self.inputs() { + for x in inputs.elements() { + let name = x.span.str(); + let value = match scope.get_variable(name).as_deref() { + Some(VariableData::MacroParameter { macro_name, .. }) => { + Value::String(lua.create_string(format!("$({macro_name})"))?) + } + Some(VariableData::ScoreboardValue { objective, target }) => { + let table = lua.create_table()?; + table.set("objective", lua.create_string(objective)?)?; + table.set("target", lua.create_string(target)?)?; + Value::Table(table) + } + Some(VariableData::BooleanStorage { storage_name, path }) => { + let table = lua.create_table()?; + table.set("storage", lua.create_string(storage_name)?)?; + table.set("path", lua.create_string(path)?)?; + Value::Table(table) + } + Some(_) => todo!("allow other types"), + None => todo!("throw correct error"), + }; + globals.set(name, value)?; + } + } + Ok(()) } @@ -145,7 +188,11 @@ mod disabled { /// # Errors /// - Always, as the lua feature is disabled #[tracing::instrument(level = "debug", name = "eval_lua", skip_all, ret)] - pub fn eval(&self, handler: &impl Handler) -> TranspileResult<()> { + pub fn eval( + &self, + scope: &Arc, + handler: &impl Handler, + ) -> TranspileResult<()> { handler.receive(TranspileError::LuaDisabled); tracing::error!("Lua code evaluation is disabled"); Err(TranspileError::LuaDisabled) @@ -158,6 +205,7 @@ mod disabled { /// - If Lua code evaluation is disabled. pub fn eval_comptime( &self, + scope: &Arc, handler: &impl Handler, ) -> TranspileResult> { handler.receive(TranspileError::LuaDisabled); diff --git a/src/transpile/transpiler.rs b/src/transpile/transpiler.rs index 4fdbac8..1594a70 100644 --- a/src/transpile/transpiler.rs +++ b/src/transpile/transpiler.rs @@ -439,7 +439,7 @@ impl Transpiler { for expression in arguments.iter().flat_map(|x| x.iter()) { let value = match expression { Expression::Primary(Primary::Lua(lua)) => { - lua.eval_comptime(handler).and_then(|val| match val { + lua.eval_comptime(scope, handler).and_then(|val| match val { Some(ComptimeValue::MacroString(s)) => Ok(Parameter::Static(s)), Some(val) => Ok(Parameter::Static(val.to_string().into())), None => { @@ -782,7 +782,7 @@ impl Transpiler { Expression::Primary(Primary::MacroStringLiteral(string)) => { Ok(vec![Command::UsesMacro(string.into())]) } - Expression::Primary(Primary::Lua(code)) => match code.eval_comptime(handler)? { + Expression::Primary(Primary::Lua(code)) => match code.eval_comptime(scope, handler)? { Some(ComptimeValue::String(cmd)) => Ok(vec![Command::Raw(cmd)]), Some(ComptimeValue::MacroString(cmd)) => Ok(vec![Command::UsesMacro(cmd)]), Some(ComptimeValue::Boolean(_) | ComptimeValue::Integer(_)) => {