use std::{
collections::HashMap,
fmt::{self, Debug},
marker::PhantomData,
sync::Mutex,
};
use futures_util::StreamExt;
use serde::{
de::{self, Error as SeError, Visitor},
ser::SerializeTuple,
Deserialize, Deserializer, Serialize,
};
use zbus::{
proxy::SignalStream,
zvariant::{ObjectPath, Type, Value},
};
use crate::{desktop::HandleToken, proxy::Proxy, Error};
/// Outcome of a request: the payload `T` on success, or a [`ResponseError`].
///
/// The `zvariant` attribute below pins the D-Bus signature of this type to
/// `(ua{sv})`: a `u` status followed by an `a{sv}` dictionary.
#[derive(Debug, Type)]
#[zvariant(signature = "(ua{sv})")]
pub enum Response<T>
where
T: for<'de> Deserialize<'de> + Type,
{
/// Successful outcome carrying the decoded payload.
Ok(T),
/// Unsuccessful outcome carrying the reason.
Err(ResponseError),
}
/// Convenience constructors, compiled only when the `backend` feature is on.
#[cfg(feature = "backend")]
impl<T> Response<T>
where
    T: for<'de> Deserialize<'de> + Type,
{
    /// Wraps `inner` in a successful response.
    pub fn ok(inner: T) -> Self {
        Response::Ok(inner)
    }

    /// Builds a failed response carrying [`ResponseError::Cancelled`].
    pub fn cancelled() -> Self {
        Response::Err(ResponseError::Cancelled)
    }

    /// Builds a failed response carrying [`ResponseError::Other`].
    pub fn other() -> Self {
        Response::Err(ResponseError::Other)
    }
}
impl<'de, T> Deserialize<'de> for Response<T>
where
T: for<'d> Deserialize<'d> + Type,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ResponseVisitor<T>(PhantomData<fn() -> (ResponseType, T)>);
impl<'de, T> Visitor<'de> for ResponseVisitor<T>
where
T: Deserialize<'de>,
{
type Value = (ResponseType, Option<T>);
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"a tuple composed of the response status along with the response"
)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
let type_: ResponseType = seq.next_element()?.ok_or_else(|| A::Error::custom(
"Failed to deserialize the response. Expected a numeric (u) value as the first item of the returned tuple",
))?;
if type_ == ResponseType::Success {
let data: T = seq.next_element()?.ok_or_else(|| A::Error::custom(
"Failed to deserialize the response. Expected a vardict (a{sv}) with the returned results",
))?;
Ok((type_, Some(data)))
} else {
Ok((type_, None))
}
}
}
let visitor = ResponseVisitor::<T>(PhantomData);
let response: (ResponseType, Option<T>) = deserializer.deserialize_tuple(2, visitor)?;
Ok(response.into())
}
}
impl<T> Serialize for Response<T>
where
    T: for<'de> Deserialize<'de> + Serialize + Type,
{
    /// Encodes the response as a two-element tuple: status first, results second.
    ///
    /// A failed response writes an empty string-keyed map in the results slot.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut tuple = serializer.serialize_tuple(2)?;
        match self {
            Self::Ok(payload) => {
                tuple.serialize_element(&ResponseType::Success)?;
                tuple.serialize_element(payload)?;
            }
            Self::Err(reason) => {
                let status = ResponseType::from(*reason);
                let no_results: HashMap<&str, Value<'_>> = HashMap::new();
                tuple.serialize_element(&status)?;
                tuple.serialize_element(&no_results)?;
            }
        }
        tuple.end()
    }
}
#[doc(hidden)]
impl<T> From<(ResponseType, Option<T>)> for Response<T>
where
T: for<'de> Deserialize<'de> + Type,
{
fn from(f: (ResponseType, Option<T>)) -> Self {
match f.0 {
ResponseType::Success => {
Response::Ok(f.1.expect("Expected a valid response, found nothing."))
}
ResponseType::Cancelled => Response::Err(ResponseError::Cancelled),
ResponseType::Other => Response::Err(ResponseError::Other),
}
}
}
/// Reason carried by an unsuccessful [`Response`].
#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone)]
pub enum ResponseError {
    /// The request was reported as cancelled; converts to
    /// `ResponseType::Cancelled`.
    Cancelled,
    /// The request ended with the "other" status; converts to
    /// `ResponseType::Other`.
    Other,
}

impl std::error::Error for ResponseError {}

impl std::fmt::Display for ResponseError {
    /// Writes the bare variant name; width and alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Cancelled => "Cancelled",
            Self::Other => "Other",
        };
        f.write_str(label)
    }
}
/// Status discriminant read as the first element of a response tuple.
///
/// Internal counterpart of [`ResponseError`] that additionally has a success
/// value.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Type)]
#[doc(hidden)]
enum ResponseType {
/// The request succeeded; the deserializer expects a payload to follow.
Success = 0,
/// Counterpart of [`ResponseError::Cancelled`].
Cancelled = 1,
/// Counterpart of [`ResponseError::Other`].
Other = 2,
}
#[doc(hidden)]
impl From<ResponseError> for ResponseType {
fn from(err: ResponseError) -> Self {
match err {
ResponseError::Other => Self::Other,
ResponseError::Cancelled => Self::Cancelled,
}
}
}
/// Handle to an `org.freedesktop.portal.Request` object together with the
/// response it eventually produces.
#[doc(alias = "org.freedesktop.portal.Request")]
pub struct Request<T>(
// Proxy for the request object.
Proxy<'static>,
// Stream of `Response` signals obtained from the proxy.
SignalStream<'static>,
// Slot holding the decoded response until `response()` takes it out.
Mutex<Option<Result<T, Error>>>,
// Marker for the payload type.
PhantomData<T>,
)
where
T: for<'de> Deserialize<'de> + Type + Debug;
impl<T> Request<T>
where
T: for<'de> Deserialize<'de> + Type + Debug,
{
pub(crate) async fn new<P>(path: P) -> Result<Request<T>, Error>
where
P: TryInto<ObjectPath<'static>>,
P::Error: Into<zbus::Error>,
{
let proxy = Proxy::new_desktop_with_path("org.freedesktop.portal.Request", path).await?;
let stream = proxy.receive_signal("Response").await?;
Ok(Self(proxy, stream, Default::default(), PhantomData))
}
pub(crate) async fn from_unique_name(handle_token: &HandleToken) -> Result<Request<T>, Error> {
let path =
Proxy::unique_name("/org/freedesktop/portal/desktop/request", handle_token).await?;
#[cfg(feature = "tracing")]
tracing::info!("Creating a org.freedesktop.portal.Request {}", path);
Self::new(path).await
}
pub(crate) async fn prepare_response(&mut self) -> Result<(), Error> {
let message = self.1.next().await.ok_or(Error::NoResponse)?;
#[cfg(feature = "tracing")]
tracing::info!("Received signal 'Response' on '{}'", self.0.interface());
let response = match message.body().deserialize::<Response<T>>()? {
Response::Err(e) => Err(e.into()),
Response::Ok(r) => Ok(r),
};
#[cfg(feature = "tracing")]
tracing::debug!("Received response {:#?}", response);
let r = response as Result<T, Error>;
*self.2.get_mut().unwrap() = Some(r);
Ok(())
}
pub fn response(&self) -> Result<T, Error> {
self.2.lock().unwrap().take().unwrap()
}
#[doc(alias = "Close")]
pub async fn close(&self) -> Result<(), Error> {
self.0.call("Close", &()).await
}
pub(crate) fn path(&self) -> &ObjectPath<'_> {
self.0.path()
}
}
impl<T> Debug for Request<T>
where
T: for<'de> Deserialize<'de> + Type + Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Request")
.field(&self.path().as_str())
.finish()
}
}
#[cfg(test)]
mod tests {
    use zbus::zvariant::Value;

    use super::*;

    /// The signature of `Response<T>` must agree with the equivalent
    /// `(status, results)` tuples and with the literal `(ua{sv})`.
    #[test]
    fn response_signature() {
        use crate::desktop::account::UserInformation;

        let vardict_tuple = <(ResponseType, HashMap<&str, Value<'_>>)>::signature();
        let unit_response = Response::<()>::signature();
        assert_eq!(vardict_tuple, unit_response);

        let user_tuple = <(ResponseType, UserInformation)>::signature();
        let user_response = Response::<UserInformation>::signature();
        assert_eq!(user_tuple, user_response);

        assert_eq!(unit_response, "(ua{sv})");
    }
}