use std::{
fmt::{self, Debug, Formatter},
pin::Pin,
task::{Context, Poll},
};
use futures_util::{stream::Stream, StreamExt};
use serde::de::DeserializeOwned;
use super::{ReceiverStream, Task};
use crate::error;
/// A [`Stream`] of `Result<T, error::Receiver>` items fed by a background [`Task`].
///
/// Items are read from an internal `flume` channel that the task writes into;
/// use [`Receiver::finish`] / [`Receiver::close`] to control the task's lifecycle.
pub struct Receiver<T: 'static> {
    /// Receiving end of the channel the forwarding task sends items into.
    receiver: flume::r#async::RecvStream<'static, Result<T, error::Receiver>>,
    /// Handle to the task that forwards items into `receiver`'s channel.
    task: Task<Result<(), error::AlreadyClosed>>,
}

// `Clone` is written by hand rather than derived: `#[derive(Clone)]` would add an
// implicit `T: Clone` bound, but cloning a `Receiver` only clones the channel handle
// and the task handle — no `T` value is ever cloned. Clones consume from the same
// (multi-consumer) `flume` channel, so each item is delivered to only one of them.
impl<T> Clone for Receiver<T> {
    fn clone(&self) -> Self {
        Self {
            receiver: self.receiver.clone(),
            task: self.task.clone(),
        }
    }
}
/// Debug output shows a placeholder for the channel stream (which is not itself
/// `Debug` for arbitrary `T`) and the task's own `Debug` representation.
impl<T> Debug for Receiver<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Receiver");
        builder.field("receiver", &"RecvStream");
        builder.field("task", &self.task);
        builder.finish()
    }
}
/// Construction and lifecycle control for [`Receiver`].
impl<T> Receiver<T> {
/// Wraps `stream` in a new `Receiver`.
///
/// A [`Task`] is created whose body forwards every item of `stream` into an
/// unbounded `flume` channel (no backpressure); this `Receiver` exposes the
/// receiving side of that channel as its stream.
///
/// The forwarding loop stops when any of the following happens:
/// - `stream` ends (`next()` yields `None`);
/// - `stream` yields an `Err` — the error is forwarded to the channel first;
/// - sending fails, i.e. every receiving end of the channel has been dropped;
/// - the shutdown signal resolves with `Ok` — `stream.stop()` is called first,
///   and a failure there is propagated with `?` as the task's result;
/// - the shutdown signal resolves with `Err`, or both branches are exhausted
///   (`complete`) — the loop ends without calling `stream.stop()`.
///
/// `select_biased!` polls `stream` before `shutdown`, so ready data is
/// preferred over a pending shutdown request.
pub(super) fn new(mut stream: ReceiverStream<T>) -> Self
where
T: DeserializeOwned + Send,
{
let (sender, receiver) = flume::unbounded();
let receiver = receiver.into_stream();
// NOTE(review): `shutdown` is assumed to be the task's close signal, triggered
// by `Task::close` — confirm against `Task::new`.
let task = Task::new(|mut shutdown| async move {
// Events the forwarding loop reacts to.
enum Message<T> {
Data(Result<T, error::Receiver>),
Close,
}
// `None` from the select (stream ended, shutdown errored, or `complete`)
// terminates the loop.
while let Some(message) = futures_util::select_biased! {
message = stream.next() => message.map(Message::Data),
shutdown = shutdown => shutdown.ok().map(|_| Message::Close),
complete => None,
} {
match message {
Message::Data(message) => {
// Remember whether this item was an error before it is moved
// into the channel; an error ends forwarding after delivery.
let failed = message.is_err();
// Send fails only when no receiver is left to read it.
if sender.send(message).is_err() {
break;
}
if failed {
break;
}
}
Message::Close => {
stream.stop()?;
break;
}
}
}
Ok(())
});
Self { receiver, task }
}
/// Waits for the forwarding task to end and returns the task's own result.
///
/// # Errors
/// Returns `error::AlreadyClosed` either from awaiting the task handle or as
/// the task's result (e.g. a propagated `stream.stop()` failure). The exact
/// conditions depend on `Task`'s `Future` impl — TODO confirm.
pub async fn finish(&self) -> Result<(), error::AlreadyClosed> {
(&self.task).await?
}
/// Requests shutdown of the forwarding task via `Task::close(())` and waits
/// for its result.
///
/// # Errors
/// Returns `error::AlreadyClosed` from the close request or as the task's
/// result. Presumably this fires when the task has already finished — confirm
/// against `Task::close`.
pub async fn close(&self) -> Result<(), error::AlreadyClosed> {
self.task.close(()).await?
}
}
/// Yields the items forwarded by the background task, in channel order.
impl<T> Stream for Receiver<T> {
    type Item = Result<T, error::Receiver>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Receiver` is `Unpin`, so the pinned reference can be unwrapped directly.
        let this = self.get_mut();
        this.receiver.poll_next_unpin(cx)
    }
}