Skip to content

feat!: Make blobs more cheaply cloneable by by giving it an Inner #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 41 additions & 34 deletions src/net_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,21 @@ impl Default for GcState {
}
}

#[derive(Debug, Clone)]
pub struct Blobs<S> {
#[derive(Debug)]
struct BlobsInner<S> {
rt: LocalPoolHandle,
pub(crate) store: S,
events: EventSender,
downloader: Downloader,
#[cfg(feature = "rpc")]
batches: Arc<tokio::sync::Mutex<BlobBatches>>,
endpoint: Endpoint,
gc_state: Arc<std::sync::Mutex<GcState>>,
gc_state: std::sync::Mutex<GcState>,
#[cfg(feature = "rpc")]
pub(crate) rpc_handler: Arc<std::sync::OnceLock<crate::rpc::RpcHandler>>,
batches: tokio::sync::Mutex<BlobBatches>,
}

#[derive(Debug, Clone)]
pub struct Blobs<S> {
inner: Arc<BlobsInner<S>>,
}

/// Keeps track of all the currently active batch operations of the blobs api.
Expand Down Expand Up @@ -178,40 +181,44 @@ impl<S: crate::store::Store> Blobs<S> {
endpoint: Endpoint,
) -> Self {
Self {
rt,
store,
events,
downloader,
endpoint,
#[cfg(feature = "rpc")]
batches: Default::default(),
gc_state: Default::default(),
#[cfg(feature = "rpc")]
rpc_handler: Default::default(),
inner: Arc::new(BlobsInner {
rt,
store,
events,
downloader,
endpoint,
#[cfg(feature = "rpc")]
batches: Default::default(),
gc_state: Default::default(),
}),
}
}

pub fn store(&self) -> &S {
&self.store
&self.inner.store
}

pub fn events(&self) -> &EventSender {
&self.inner.events
}

pub fn rt(&self) -> &LocalPoolHandle {
&self.rt
&self.inner.rt
}

pub fn downloader(&self) -> &Downloader {
&self.downloader
&self.inner.downloader
}

pub fn endpoint(&self) -> &Endpoint {
&self.endpoint
&self.inner.endpoint
}

/// Add a callback that will be called before the garbage collector runs.
///
/// This can only be called before the garbage collector has started, otherwise it will return an error.
pub fn add_protected(&self, cb: ProtectCb) -> Result<()> {
let mut state = self.gc_state.lock().unwrap();
let mut state = self.inner.gc_state.lock().unwrap();
match &mut *state {
GcState::Initial(cbs) => {
cbs.push(cb);
Expand All @@ -225,7 +232,7 @@ impl<S: crate::store::Store> Blobs<S> {

/// Start garbage collection with the given settings.
pub fn start_gc(&self, config: GcConfig) -> Result<()> {
let mut state = self.gc_state.lock().unwrap();
let mut state = self.inner.gc_state.lock().unwrap();
let protected = match state.deref_mut() {
GcState::Initial(items) => std::mem::take(items),
GcState::Started(_) => bail!("gc already started"),
Expand All @@ -241,17 +248,17 @@ impl<S: crate::store::Store> Blobs<S> {
set
}
};
let store = self.store.clone();
let store = self.store().clone();
let run = self
.rt
.rt()
.spawn(move || async move { store.gc_run(config, protected_cb).await });
*state = GcState::Started(Some(run));
Ok(())
}

#[cfg(feature = "rpc")]
pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> {
self.batches.lock().await
self.inner.batches.lock().await
}

pub(crate) async fn download(
Expand All @@ -268,7 +275,7 @@ impl<S: crate::store::Store> Blobs<S> {
mode,
} = req;
let hash_and_format = HashAndFormat { hash, format };
let temp_tag = self.store.temp_tag(hash_and_format);
let temp_tag = self.store().temp_tag(hash_and_format);
let stats = match mode {
DownloadMode::Queued => {
self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
Expand All @@ -283,10 +290,10 @@ impl<S: crate::store::Store> Blobs<S> {
progress.send(DownloadProgress::AllDone(stats)).await.ok();
match tag {
SetTagOption::Named(tag) => {
self.store.set_tag(tag, Some(hash_and_format)).await?;
self.store().set_tag(tag, Some(hash_and_format)).await?;
}
SetTagOption::Auto => {
self.store.create_tag(hash_and_format).await?;
self.store().create_tag(hash_and_format).await?;
}
}
drop(temp_tag);
Expand Down Expand Up @@ -316,7 +323,7 @@ impl<S: crate::store::Store> Blobs<S> {
let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
anyhow::ensure!(can_download, "no way to reach a node for download");
let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
let handle = self.downloader.queue(req).await;
let handle = self.downloader().queue(req).await;
let stats = handle.await?;
Ok(stats)
}
Expand All @@ -334,7 +341,7 @@ impl<S: crate::store::Store> Blobs<S> {
let mut nodes_iter = nodes.into_iter();
'outer: loop {
match crate::get::db::get_to_db_in_steps(
self.store.clone(),
self.store().clone(),
hash_and_format,
progress.clone(),
)
Expand Down Expand Up @@ -393,9 +400,9 @@ impl<S: crate::store::Store> Blobs<S> {

impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
fn accept(&self, conn: Connecting) -> BoxedFuture<Result<()>> {
let db = self.store.clone();
let events = self.events.clone();
let rt = self.rt.clone();
let db = self.store().clone();
let events = self.events().clone();
let rt = self.rt().clone();

Box::pin(async move {
crate::provider::handle_connection(conn.await?, db, events, rt).await;
Expand All @@ -404,7 +411,7 @@ impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
}

fn shutdown(&self) -> BoxedFuture<()> {
let store = self.store.clone();
let store = self.store().clone();
Box::pin(async move {
store.shutdown().await;
})
Expand Down
30 changes: 20 additions & 10 deletions src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

use std::{
io,
ops::Deref,
sync::{Arc, Mutex},
};

use anyhow::anyhow;
use client::{
blobs::{self, BlobInfo, BlobStatus, IncompleteBlobInfo, WrapOption},
blobs::{BlobInfo, BlobStatus, IncompleteBlobInfo, MemClient, WrapOption},
tags::TagInfo,
MemConnector,
};
Expand Down Expand Up @@ -62,13 +63,8 @@ const RPC_BLOB_GET_CHANNEL_CAP: usize = 2;

impl<D: crate::store::Store> Blobs<D> {
/// Get a client for the blobs protocol
pub fn client(&self) -> blobs::MemClient {
let client = self
.rpc_handler
.get_or_init(|| RpcHandler::new(self))
.client
.clone();
blobs::Client::new(client)
pub fn client(&self) -> RpcHandler {
RpcHandler::new(self)
}

/// Handle an RPC request
Expand Down Expand Up @@ -874,20 +870,34 @@ impl<D: crate::store::Store> Blobs<D> {
}
}

/// A rpc handler for the blobs rpc protocol
///
/// This struct contains both a task that handles rpc requests and a client
/// that can be used to send rpc requests. Dropping it will stop the handler task,
/// so you need to put it somewhere where it will be kept alive.
#[derive(Debug)]
pub(crate) struct RpcHandler {
pub struct RpcHandler {
/// Client to hand out
client: RpcClient<RpcService, MemConnector>,
client: MemClient,
/// Handler task
_handler: AbortOnDropHandle<()>,
}

impl Deref for RpcHandler {
type Target = MemClient;

fn deref(&self) -> &Self::Target {
&self.client
}
}

impl RpcHandler {
fn new<D: crate::store::Store>(blobs: &Blobs<D>) -> Self {
let blobs = blobs.clone();
let (listener, connector) = quic_rpc::transport::flume::channel(1);
let listener = RpcServer::new(listener);
let client = RpcClient::new(connector);
let client = MemClient::new(client);
let _handler = listener
.spawn_accept_loop(move |req, chan| blobs.clone().handle_rpc_request(req, chan));
Self { client, _handler }
Expand Down
7 changes: 1 addition & 6 deletions tests/blobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@ async fn blobs_gc_protected() -> TestResult<()> {
let pool = LocalPool::default();
let endpoint = Endpoint::builder().bind().await?;
let blobs = Blobs::memory().build(pool.handle(), &endpoint);
let client: iroh_blobs::rpc::client::blobs::Client<
quic_rpc::transport::flume::FlumeConnector<
iroh_blobs::rpc::proto::Response,
iroh_blobs::rpc::proto::Request,
>,
> = blobs.clone().client();
let client = blobs.clone().client();
let h1 = client.add_bytes(b"test".to_vec()).await?;
let protected = Arc::new(Mutex::new(Vec::new()));
blobs.add_protected(Box::new({
Expand Down
6 changes: 3 additions & 3 deletions tests/gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use iroh::{protocol::Router, Endpoint, NodeAddr, NodeId};
use iroh_blobs::{
hashseq::HashSeq,
net_protocol::Blobs,
rpc::client::{blobs, tags},
rpc::{client::tags, RpcHandler},
store::{
bao_tree, BaoBatchWriter, ConsistencyCheckProgress, EntryStatus, GcConfig, MapEntryMut,
MapMut, ReportLevel, Store,
Expand Down Expand Up @@ -66,8 +66,8 @@ impl<S: Store> Node<S> {
}

/// Returns an in-memory blobs client
pub fn blobs(&self) -> blobs::MemClient {
self.blobs.clone().client()
pub fn blobs(&self) -> RpcHandler {
self.blobs.client()
}

/// Returns an in-memory tags client
Expand Down
Loading