From 61b8f1ffb9034f3d4d2706130023aa24ce651fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20H=C3=B6lting?= <87192362+moritz-hoelting@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:55:48 +0200 Subject: [PATCH] improve lua integration by allowing more flexible return types and introducing globals --- Cargo.toml | 1 + src/base/log.rs | 2 +- src/base/source_file.rs | 6 ++++ src/transpile/error.rs | 14 ++++++++ src/transpile/lua.rs | 72 +++++++++++++++++++++++++++++-------- src/transpile/transpiler.rs | 2 +- 6 files changed, 81 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c6bfcc7..e4173ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ getset = "0.1.2" itertools = "0.13.0" mlua = { version = "0.9.7", features = ["lua54", "vendored"], optional = true } path-absolutize = "3.1.1" +pathdiff = "0.2.1" serde = { version = "1.0.197", features = ["derive", "rc"], optional = true } shulkerbox = { git = "https://github.com/moritz-hoelting/shulkerbox", default-features = false, optional = true, rev = "aff342a64a94981af942223345b5a5f105212957" } strsim = "0.11.1" diff --git a/src/base/log.rs b/src/base/log.rs index a7b853e..d009754 100644 --- a/src/base/log.rs +++ b/src/base/log.rs @@ -291,7 +291,7 @@ fn write_error_line( (line_number == start_line && index >= start_location.column) || (line_number == end_line && (index + 1) - < end_location + <= end_location .map_or(usize::MAX, |end_location| end_location.column)) || (line_number > start_line && line_number < end_line) } else { diff --git a/src/base/source_file.rs b/src/base/source_file.rs index cfcd32f..9de2123 100644 --- a/src/base/source_file.rs +++ b/src/base/source_file.rs @@ -136,6 +136,12 @@ impl SourceFile { None } } + + /// Get the relative path of the source file from the current working directory. + #[must_use] + pub fn path_relative(&self) -> Option { + pathdiff::diff_paths(&self.path, std::env::current_dir().ok()?) + } } /// Represents a range of characters in a source file. diff --git a/src/transpile/error.rs b/src/transpile/error.rs index d003250..b383c3b 100644 --- a/src/transpile/error.rs +++ b/src/transpile/error.rs @@ -128,6 +128,20 @@ impl Display for LuaRuntimeError { impl std::error::Error for LuaRuntimeError {} +#[cfg(feature = "lua")] +impl LuaRuntimeError { + pub fn from_lua_err(err: &mlua::Error, span: Span) -> Self { + let err_string = err.to_string(); + Self { + error_message: err_string + .strip_prefix("runtime error: ") + .unwrap_or(&err_string) + .to_string(), + code_block: span, + } + } +} + /// An error that occurs when a function declaration is missing. #[derive(Debug, Clone, PartialEq, Eq)] pub struct UnexpectedExpression(pub Expression); diff --git a/src/transpile/lua.rs b/src/transpile/lua.rs index 5567cf8..b98e8cb 100644 --- a/src/transpile/lua.rs +++ b/src/transpile/lua.rs @@ -2,7 +2,7 @@ #[cfg(feature = "lua")] mod enabled { - use mlua::Lua; + use mlua::{Lua, Value}; use crate::{ base::{self, source_file::SourceElement, Handler}, @@ -16,7 +16,10 @@ mod enabled { /// # Errors /// - If Lua code evaluation is disabled. #[tracing::instrument(level = "debug", name = "eval_lua", skip_all, ret)] - pub fn eval_string(&self, handler: &impl Handler) -> TranspileResult { + pub fn eval_string( + &self, + handler: &impl Handler, + ) -> TranspileResult> { tracing::debug!("Evaluating Lua code"); let lua = Lua::new(); @@ -24,7 +27,7 @@ mod enabled { let name = { let span = self.span(); let file = span.source_file(); - let path = file.path(); + let path = file.path_relative().unwrap_or_else(|| file.path().clone()); let start = span.start_location(); let end = span.end_location().unwrap_or_else(|| { @@ -43,24 +46,62 @@ mod enabled { ) }; + self.add_globals(&lua).unwrap(); + let lua_result = lua .load(self.code()) .set_name(name) - .eval::() + .eval::() .map_err(|err| { - let err_string = err.to_string(); - let err = TranspileError::from(LuaRuntimeError { - error_message: err_string - .strip_prefix("runtime error: ") - .unwrap_or(&err_string) - .to_string(), - code_block: self.span(), - }); + let err = + TranspileError::from(LuaRuntimeError::from_lua_err(&err, self.span())); handler.receive(crate::Error::from(err.clone())); err })?; - Ok(lua_result) + self.handle_lua_result(dbg!(lua_result)).map_err(|err| { + handler.receive(err.clone()); + err + }) + } + + fn add_globals(&self, lua: &Lua) -> mlua::Result<()> { + let globals = lua.globals(); + + let location = { + let span = self.span(); + let file = span.source_file(); + file.path_relative().unwrap_or_else(|| file.path().clone()) + }; + globals.set("shu_location", location.to_string_lossy())?; + + Ok(()) + } + + fn handle_lua_result(&self, value: Value) -> TranspileResult> { + match value { + Value::Nil => Ok(None), + Value::String(s) => Ok(Some(s.to_string_lossy().into_owned())), + Value::Integer(i) => Ok(Some(i.to_string())), + Value::Number(n) => Ok(Some(n.to_string())), + Value::Function(f) => self.handle_lua_result(f.call(()).map_err(|err| { + TranspileError::LuaRuntimeError(LuaRuntimeError::from_lua_err( + &err, + self.span(), + )) + })?), + Value::Boolean(_) + | Value::Error(_) + | Value::Table(_) + | Value::Thread(_) + | Value::UserData(_) + | Value::LightUserData(_) => { + Err(TranspileError::LuaRuntimeError(LuaRuntimeError { + code_block: self.span(), + error_message: format!("invalid return type {}", value.type_name()), + })) + } + } } } } @@ -79,7 +120,10 @@ mod disabled { /// /// # Errors /// - If Lua code evaluation is disabled. - pub fn eval_string(&self, handler: &impl Handler) -> TranspileResult { + pub fn eval_string( + &self, + handler: &impl Handler, + ) -> TranspileResult> { handler.receive(TranspileError::LuaDisabled); tracing::error!("Lua code evaluation is disabled"); Err(TranspileError::LuaDisabled) diff --git a/src/transpile/transpiler.rs b/src/transpile/transpiler.rs index 64033a7..6e053e6 100644 --- a/src/transpile/transpiler.rs +++ b/src/transpile/transpiler.rs @@ -355,7 +355,7 @@ impl Transpiler { Ok(Some(Command::Raw(string.str_content().to_string()))) } Expression::Primary(Primary::Lua(code)) => { - Ok(Some(Command::Raw(code.eval_string(handler)?))) + Ok(code.eval_string(handler)?.map(Command::Raw)) } }, Statement::Block(_) => {