diff --git a/kernel/socket.h b/kernel/socket.h index 5683d7e1..45f88a2b 100644 --- a/kernel/socket.h +++ b/kernel/socket.h @@ -1,6 +1,7 @@ #pragma once #include "fs/fs.h" +#include "lock.h" #include "ring_buf.h" typedef struct unix_socket { @@ -22,8 +23,10 @@ typedef struct unix_socket { atomic_bool is_connected; file_description* connector_fd; - ring_buf to_connector_buf; - ring_buf to_acceptor_buf; + struct mutex_protected_buf { + ring_buf ring; + mutex lock; + } to_connector_buf, to_acceptor_buf; atomic_bool is_open_for_writing_to_connector; atomic_bool is_open_for_writing_to_acceptor; diff --git a/kernel/unix_socket.c b/kernel/unix_socket.c index 9d994df0..9866aefd 100644 --- a/kernel/unix_socket.c +++ b/kernel/unix_socket.c @@ -8,8 +8,8 @@ static void unix_socket_destroy_inode(struct inode* inode) { unix_socket* socket = (unix_socket*)inode; - ring_buf_destroy(&socket->to_connector_buf); - ring_buf_destroy(&socket->to_acceptor_buf); + ring_buf_destroy(&socket->to_connector_buf.ring); + ring_buf_destroy(&socket->to_acceptor_buf.ring); kfree(socket); } @@ -31,13 +31,13 @@ static bool is_open_for_reading(file_description* desc) { : socket->is_open_for_writing_to_acceptor; } -static ring_buf* buf_to_read(file_description* desc) { +static struct mutex_protected_buf* buf_to_read(file_description* desc) { unix_socket* socket = (unix_socket*)desc->inode; return is_connector(desc) ? &socket->to_connector_buf : &socket->to_acceptor_buf; } -static ring_buf* buf_to_write(file_description* desc) { +static struct mutex_protected_buf* buf_to_write(file_description* desc) { unix_socket* socket = (unix_socket*)desc->inode; return is_connector(desc) ? &socket->to_acceptor_buf : &socket->to_connector_buf; @@ -46,8 +46,8 @@ static ring_buf* buf_to_write(file_description* desc) { static bool read_should_unblock(file_description* desc) { if (!is_open_for_reading(desc)) return true; - ring_buf* buf = buf_to_read(desc); - return !ring_buf_is_empty(buf); + struct mutex_protected_buf* buf = buf_to_read(desc); + return !ring_buf_is_empty(&buf->ring); } static ssize_t unix_socket_read(file_description* desc, void* buffer, @@ -56,19 +56,19 @@ static ssize_t unix_socket_read(file_description* desc, void* buffer, if (!socket->is_connected) return -EINVAL; - ring_buf* buf = buf_to_read(desc); + struct mutex_protected_buf* buf = buf_to_read(desc); for (;;) { int rc = file_description_block(desc, read_should_unblock); if (IS_ERR(rc)) return rc; - mutex_lock(&socket->lock); - if (!ring_buf_is_empty(buf)) { - ssize_t nread = ring_buf_read(buf, buffer, count); - mutex_unlock(&socket->lock); + mutex_lock(&buf->lock); + if (!ring_buf_is_empty(&buf->ring)) { + ssize_t nread = ring_buf_read(&buf->ring, buffer, count); + mutex_unlock(&buf->lock); return nread; } - mutex_unlock(&socket->lock); + mutex_unlock(&buf->lock); if (!is_open_for_reading(desc)) return 0; @@ -83,8 +83,8 @@ static bool write_should_unblock(file_description* desc) { } else if (!socket->is_open_for_writing_to_connector) { return true; } - ring_buf* buf = buf_to_write(desc); - return !ring_buf_is_full(buf); + struct mutex_protected_buf* buf = buf_to_write(desc); + return !ring_buf_is_full(&buf->ring); } static ssize_t unix_socket_write(file_description* desc, const void* buffer, @@ -93,7 +93,7 @@ static ssize_t unix_socket_write(file_description* desc, const void* buffer, if (!socket->is_connected) return -ENOTCONN; - ring_buf* buf = buf_to_write(desc); + struct mutex_protected_buf* buf = buf_to_write(desc); for (;;) { int rc = file_description_block(desc, write_should_unblock); if (IS_ERR(rc)) @@ -106,13 +106,13 @@ static ssize_t unix_socket_write(file_description* desc, const void* buffer, return -EPIPE; } - mutex_lock(&socket->lock); - if (!ring_buf_is_full(buf)) { - ssize_t nwritten = ring_buf_write(buf, buffer, count); - mutex_unlock(&socket->lock); + mutex_lock(&buf->lock); + if (!ring_buf_is_full(&buf->ring)) { + ssize_t nwritten = ring_buf_write(&buf->ring, buffer, count); + mutex_unlock(&buf->lock); return nwritten; } - mutex_unlock(&socket->lock); + mutex_unlock(&buf->lock); } } @@ -156,14 +156,14 @@ unix_socket* unix_socket_create(void) { socket->is_open_for_writing_to_connector = true; socket->is_open_for_writing_to_acceptor = true; - int rc = ring_buf_init(&socket->to_acceptor_buf, PAGE_SIZE); + int rc = ring_buf_init(&socket->to_acceptor_buf.ring, PAGE_SIZE); if (IS_ERR(rc)) { kfree(socket); return ERR_PTR(rc); } - rc = ring_buf_init(&socket->to_connector_buf, PAGE_SIZE); + rc = ring_buf_init(&socket->to_connector_buf.ring, PAGE_SIZE); if (IS_ERR(rc)) { - ring_buf_destroy(&socket->to_acceptor_buf); + ring_buf_destroy(&socket->to_acceptor_buf.ring); kfree(socket); return ERR_PTR(rc); }