diff --git a/src/base/source_file.rs b/src/base/source_file.rs index a19d576..caf7475 100644 --- a/src/base/source_file.rs +++ b/src/base/source_file.rs @@ -145,7 +145,6 @@ impl SourceFile { } /// Represents a range of characters in a source file. -#[cfg_attr(feature = "serde", derive(serde::Deserialize))] #[derive(Clone, Getters, CopyGetters)] pub struct Span { /// Get the start byte index of the span. diff --git a/src/serde.rs b/src/serde.rs index d09f070..badfcd4 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -2,17 +2,25 @@ use std::{ collections::HashMap, + marker::PhantomData, sync::{Arc, LazyLock, Mutex, RwLock}, }; -use serde::{de::{self, Visitor}, ser::SerializeStruct, Deserialize, Serialize}; +use serde::{ + de::{self, Visitor}, + ser::SerializeStruct, + Deserialize, Serialize, +}; use crate::base::source_file::{SourceFile, Span}; +static DEDUPLICATE_SOURCE_FILES: LazyLock> = LazyLock::new(|| RwLock::new(false)); + static SERIALIZE_DATA: LazyLock> = LazyLock::new(|| Mutex::new(SerializeData::default())); -static DEDUPLICATE_SOURCE_FILES: LazyLock> = LazyLock::new(|| RwLock::new(false)); +static DESERIALIZE_DATA: LazyLock>> = + LazyLock::new(|| RwLock::new(None)); /// Wrapper to remove duplicate source file data during (de-)serialization #[derive(Debug)] @@ -39,6 +47,102 @@ where } } +impl<'de, T> Deserialize<'de> for SerdeWrapper +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum Field { + Data, + SourceFiles, + } + + struct WrapperVisitor(PhantomData); + + impl<'de, T> Visitor<'de> for WrapperVisitor + where + T: Deserialize<'de>, + { + type Value = SerdeWrapper; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct SerdeWrapper") + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: de::SeqAccess<'de>, + { + let source_files: HashMap = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + *DESERIALIZE_DATA.write().unwrap() = Some(DeserializeData { + id_to_source_file: source_files + .into_iter() + .map(|(k, v)| (k, Arc::new(v))) + .collect(), + }); + let data = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + Ok(SerdeWrapper(data)) + } + + fn visit_map(self, mut map: V) -> Result + where + V: de::MapAccess<'de>, + { + let mut source_files: Option> = None; + let mut data = None; + + while let Some(key) = map.next_key()? { + match key { + Field::Data => { + if data.is_some() { + return Err(de::Error::duplicate_field("data")); + } + *DESERIALIZE_DATA.write().unwrap() = + source_files.as_ref().map(|source_files| DeserializeData { + id_to_source_file: source_files + .iter() + .map(|(&k, v)| (k, Arc::new(v.clone()))) + .collect(), + }); + data = Some(map.next_value()?); + } + Field::SourceFiles => { + if source_files.is_some() { + return Err(de::Error::duplicate_field("source_files")); + } + source_files = Some(map.next_value()?); + } + } + } + + let data = data.ok_or_else(|| de::Error::missing_field("data"))?; + Ok(SerdeWrapper(data)) + } + } + + *DEDUPLICATE_SOURCE_FILES.write().unwrap() = true; + *DESERIALIZE_DATA.write().unwrap() = None; + let res = deserializer.deserialize_struct( + "SerdeWrapper", + &["source_files", "data"], + WrapperVisitor(PhantomData::::default()), + ); + *DEDUPLICATE_SOURCE_FILES.write().unwrap() = false; + + res + } +} + /// Internally used for Serialization #[derive(Debug, Default)] struct SerializeData { @@ -91,3 +195,131 @@ impl Serialize for Span { s.end() } } + +#[derive(Debug, Default)] +struct DeserializeData { + id_to_source_file: HashMap>, +} + +impl<'de> Deserialize<'de> for Span { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum Field { + Start, + End, + SourceFile, + } + + struct SpanVisitor; + + impl<'de> Visitor<'de> for SpanVisitor { + type Value = Span; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + if *DEDUPLICATE_SOURCE_FILES.read().unwrap() { + formatter.write_str("struct Span with deduplicated SourceFiles") + } else { + formatter.write_str("struct Span") + } + } + + fn visit_seq(self, mut seq: V) -> Result + where + V: serde::de::SeqAccess<'de>, + { + let start = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let end = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + let source_file = if *DEDUPLICATE_SOURCE_FILES.read().unwrap() { + DESERIALIZE_DATA + .read() + .unwrap() + .as_ref() + .ok_or_else(|| { + de::Error::custom("SourceFiles do not have been loaded yet") + })? + .id_to_source_file + .get( + &seq.next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?, + ) + .ok_or_else(|| de::Error::custom("invalid source_file id"))? + .clone() + } else { + Arc::new( + seq.next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?, + ) + }; + + Ok(Span::new(source_file, start, end) + .ok_or_else(|| de::Error::custom("Invalid data"))?) + } + + fn visit_map(self, mut map: V) -> Result + where + V: de::MapAccess<'de>, + { + let mut start = None; + let mut end = None; + let mut source_file = None; + + while let Some(key) = map.next_key()? { + match key { + Field::Start => { + if start.is_some() { + return Err(de::Error::duplicate_field("start")); + } + start = Some(map.next_value()?); + } + Field::End => { + if end.is_some() { + return Err(de::Error::duplicate_field("end")); + } + end = Some(map.next_value()?); + } + Field::SourceFile => { + if source_file.is_some() { + return Err(de::Error::duplicate_field("source_file")); + } + source_file = if *DEDUPLICATE_SOURCE_FILES.read().unwrap() { + Some( + DESERIALIZE_DATA + .read() + .unwrap() + .as_ref() + .ok_or_else(|| { + de::Error::custom( + "SourceFiles do not have been loaded yet", + ) + })? + .id_to_source_file + .get(&map.next_value()?) + .ok_or_else(|| de::Error::custom("invalid source_file id"))? + .clone(), + ) + } else { + Some(Arc::new(map.next_value()?)) + }; + } + } + } + let start = start.ok_or_else(|| de::Error::missing_field("start"))?; + let end = end.ok_or_else(|| de::Error::missing_field("end"))?; + let source_file = source_file.ok_or_else(|| de::Error::missing_field("source"))?; + + Ok(Span::new(source_file, start, end) + .ok_or_else(|| de::Error::custom("Invalid data"))?) + } + } + + deserializer.deserialize_struct("Span", &["start", "end", "source_file"], SpanVisitor) + } +}