Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 173 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ use crate::Reference;

use crate::errors::{OciDistributionError, Result};
use crate::token_cache::{RegistryOperation, RegistryToken, RegistryTokenType, TokenCache};
use futures_util::future;
use futures_util::future::Future;
use futures_util::stream::{self, StreamExt, TryStreamExt};
use futures_util::{future, FutureExt};
use http::HeaderValue;
use http_auth::{parser::ChallengeParser, ChallengeRef};
use olpc_cjson::CanonicalFormatter;
use reqwest::header::HeaderMap;
use reqwest::{RequestBuilder, Url};
use serde::Serialize;
use sha2::Digest;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{debug, trace, warn};
Expand Down Expand Up @@ -179,6 +182,7 @@ impl Config {
///
/// For true anonymous access, you can skip `auth()`. This is not recommended
/// unless you are sure that the remote registry does not require Oauth2.
#[derive(Clone)]
pub struct Client {
config: ClientConfig,
tokens: TokenCache,
Expand Down Expand Up @@ -987,6 +991,61 @@ impl Client {
))
}

/// Pushes a single chunk of a blob to a registry,
/// as part of a chunked blob upload.
///
/// Returns the URL location for the next chunk
async fn push_chunk_request(
self,
location: String,
image: Reference,
blob_data: Vec<u8>,
start_byte: usize,
offset: usize,
) -> Result<(String, usize)> {
if blob_data.is_empty() {
return Err(OciDistributionError::PushNoDataError);
};
let end_byte = if (start_byte - offset + self.push_chunk_size) < blob_data.len() {
start_byte - offset + self.push_chunk_size - 1
} else {
blob_data.len() - 1
};
let body = blob_data[(start_byte - offset)..(end_byte - offset) + 1].to_vec();
let mut headers = HeaderMap::new();
headers.insert(
"Content-Range",
format!("{}-{}", start_byte, end_byte).parse().unwrap(),
);
headers.insert("Content-Length", format!("{}", body.len()).parse().unwrap());
headers.insert("Content-Type", "application/octet-stream".parse().unwrap());

debug!(
?start_byte,
?end_byte,
blob_data_len = blob_data.len(),
body_len = body.len(),
?location,
?headers,
"Pushing chunk"
);

let res = RequestBuilderWrapper::from_client(&self, |client| client.patch(&location))
.apply_auth(&image, RegistryOperation::Push)?
.into_request_builder()
.headers(headers)
.body(body)
.send()
.await?;

// Returns location for next chunk and the start byte for the next range
Ok((
self.extract_location_header(&image, res, &reqwest::StatusCode::ACCEPTED)
.await?,
end_byte + 1,
))
}

/// Pushes the manifest for a specified image
///
/// Returns pullable manifest URL
Expand Down Expand Up @@ -1138,6 +1197,104 @@ impl Client {
"uploads/",
)
}

/// Creates an async blob upload session.
pub async fn stream_blob_upload(&self, image: &Reference) -> Result<StreamBlobWriter> {
StreamBlobWriter::new(self.clone(), image.clone()).await
}
}

/// A client for pushing OCI artifacts to a registry from a stream.
pub struct StreamBlobWriter {
client: Client,
start_byte: usize,
location: String,
image: Reference,
hasher: Sha256,
pending: Option<Pin<Box<dyn Future<Output = Result<(String, usize)>> + Send>>>,
}

impl StreamBlobWriter {
async fn new(client: Client, image: Reference) -> Result<Self> {
let location = client.begin_push_chunked_session(&image).await?;

Ok(Self {
client,
start_byte: 0,
location,
image,
hasher: Sha256::new(),
pending: None,
})
}

/// Finalizes the upload and returns the URL of the uploaded blob.
pub async fn finalize(self) -> Result<String> {
self.client
.end_push_chunked_session(
&self.location,
&self.image,
&format!("{:x}", self.hasher.finalize()),
)
.await
}
}

impl AsyncWrite for StreamBlobWriter {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::result::Result<usize, std::io::Error>> {
let this = self.get_mut();

// Update hasher
this.hasher.update(buf);

this.pending = match this.pending.take() {
Some(mut pending) => match pending.poll_unpin(cx) {
Poll::Ready(Ok((location, next_start))) => {
this.start_byte = next_start;
this.location = location;
return Poll::Ready(Ok(buf.len()));
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
}
Poll::Pending => Some(pending),
},
None => Some(
this.client
.clone()
.push_chunk_request(
this.location.clone(),
this.image.clone(),
buf.into(),
this.start_byte,
this.start_byte,
)
.boxed(),
),
};
Poll::Pending
}

fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), std::io::Error>> {
match self.pending {
Some(_) => Poll::Pending,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this poll_unpin too while pending?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're "flushing" inside the pending future, as long as there's something there, it's not flushed – that's my thinking

None => Poll::Ready(Ok(())),
}
}

fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), std::io::Error>> {
todo!()
}
}

/// The OCI spec technically does not allow any codes but 200, 500, 401, and 404.
Expand Down Expand Up @@ -1328,6 +1485,20 @@ pub struct ClientConfig {
pub max_concurrent_download: usize,
}

impl Clone for ClientConfig {
fn clone(&self) -> Self {
Self {
protocol: self.protocol.clone(),
accept_invalid_hostnames: self.accept_invalid_hostnames.clone(),
accept_invalid_certificates: self.accept_invalid_certificates.clone(),
extra_root_certificates: self.extra_root_certificates.clone(),
platform_resolver: Some(Box::new(current_platform_resolver)),
max_concurrent_upload: self.max_concurrent_upload.clone(),
max_concurrent_download: self.max_concurrent_download.clone(),
}
}
}

impl Default for ClientConfig {
fn default() -> Self {
Self {
Expand Down
4 changes: 2 additions & 2 deletions src/token_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl fmt::Debug for RegistryToken {
}
}

#[derive(Debug)]
#[derive(Clone, Debug)]
pub(crate) enum RegistryTokenType {
Bearer(RegistryToken),
Basic(String, String),
Expand Down Expand Up @@ -57,7 +57,7 @@ pub enum RegistryOperation {
Pull,
}

#[derive(Default)]
#[derive(Clone, Default)]
pub(crate) struct TokenCache {
// (registry, repository, scope) -> (token, expiration)
tokens: BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>,
Expand Down