refactor: replace tokio lock with std lock in some sync scenarios #694

Merged: 2 commits, Mar 3, 2023
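The rule applied throughout the diff: every converted call site creates its lock guard and drops it inside a synchronous scope, never holding it across an `.await`. In that situation a `tokio::sync` lock only adds overhead and forces callers to be `async`, while a `std::sync` lock can be used directly (the tokio documentation itself suggests the std types when the guard is not held across an await). A minimal sketch of the before/after shape, using a hypothetical counter rather than code from this repository:

```rust
use std::sync::Mutex;

/// Hypothetical example, not taken from the repository: the guard is created
/// and dropped inside the method body and is never held across an `.await`,
/// so a `std::sync::Mutex` is enough and the method no longer needs `async`.
struct Counter {
    // Before this kind of change: `tokio::sync::Mutex<u64>`, an `async fn incr`,
    // and `self.inner.lock().await`.
    inner: Mutex<u64>,
}

impl Counter {
    fn incr(&self) -> u64 {
        // `lock()` only errors if another thread panicked while holding the
        // lock (poisoning), so `unwrap()` mirrors what this PR does.
        let mut guard = self.inner.lock().unwrap();
        *guard += 1;
        *guard
    }
}
```

Async callers can still use such a type; they just have to finish with the guard before the next await point, which is the property each change below preserves.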
common_util/src/partitioned_lock.rs (32 changes: 15 additions & 17 deletions)
@@ -6,11 +6,9 @@ use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
num::NonZeroUsize,
- sync::Arc,
+ sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard},
};

- use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

/// Simple partitioned `RwLock`
pub struct PartitionedRwLock<T> {
partitions: Vec<Arc<RwLock<T>>>,
@@ -28,16 +26,16 @@ impl<T> PartitionedRwLock<T> {
}
}

- pub async fn read<K: Eq + Hash>(&self, key: &K) -> RwLockReadGuard<'_, T> {
+ pub fn read<K: Eq + Hash>(&self, key: &K) -> RwLockReadGuard<'_, T> {
let rwlock = self.get_partition(key);

- rwlock.read().await
+ rwlock.read().unwrap()
}

- pub async fn write<K: Eq + Hash>(&self, key: &K) -> RwLockWriteGuard<'_, T> {
+ pub fn write<K: Eq + Hash>(&self, key: &K) -> RwLockWriteGuard<'_, T> {
let rwlock = self.get_partition(key);

- rwlock.write().await
+ rwlock.write().unwrap()
}

fn get_partition<K: Eq + Hash>(&self, key: &K) -> &RwLock<T> {
@@ -66,10 +64,10 @@ impl<T> PartitionedMutex<T> {
}
}

- pub async fn lock<K: Eq + Hash>(&self, key: &K) -> MutexGuard<'_, T> {
+ pub fn lock<K: Eq + Hash>(&self, key: &K) -> MutexGuard<'_, T> {
let mutex = self.get_partition(key);

- mutex.lock().await
+ mutex.lock().unwrap()
}

fn get_partition<K: Eq + Hash>(&self, key: &K) -> &Mutex<T> {
@@ -87,37 +85,37 @@ mod tests {

use super::*;

- #[tokio::test]
- async fn test_partitioned_rwlock() {
+ #[test]
+ fn test_partitioned_rwlock() {
let test_locked_map =
PartitionedRwLock::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
let test_key = "test_key".to_string();
let test_value = "test_value".to_string();

{
- let mut map = test_locked_map.write(&test_key).await;
+ let mut map = test_locked_map.write(&test_key);
map.insert(test_key.clone(), test_value.clone());
}

{
- let map = test_locked_map.read(&test_key).await;
+ let map = test_locked_map.read(&test_key);
assert_eq!(map.get(&test_key).unwrap(), &test_value);
}
}

- #[tokio::test]
- async fn test_partitioned_mutex() {
+ #[test]
+ fn test_partitioned_mutex() {
let test_locked_map = PartitionedMutex::new(HashMap::new(), NonZeroUsize::new(10).unwrap());
let test_key = "test_key".to_string();
let test_value = "test_value".to_string();

{
- let mut map = test_locked_map.lock(&test_key).await;
+ let mut map = test_locked_map.lock(&test_key);
map.insert(test_key.clone(), test_value.clone());
}

{
- let map = test_locked_map.lock(&test_key).await;
+ let map = test_locked_map.lock(&test_key);
assert_eq!(map.get(&test_key).unwrap(), &test_value);
}
}
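After this change the partitioned locks are purely synchronous. A rough usage sketch under the new signatures (illustrative only; it assumes `PartitionedMutex` from the file above is in scope and mirrors the shapes used in its tests):

```rust
use std::{collections::HashMap, num::NonZeroUsize};

// Illustrative sketch, not a test from the repository.
fn example() {
    let locked_map: PartitionedMutex<HashMap<String, String>> =
        PartitionedMutex::new(HashMap::new(), NonZeroUsize::new(10).unwrap());

    // No async context is needed any more; the guard is a plain
    // std::sync::MutexGuard that is released at the end of each scope.
    {
        let mut map = locked_map.lock(&"k".to_string());
        map.insert("k".to_string(), "v".to_string());
    }

    assert_eq!(
        locked_map.lock(&"k".to_string()).get("k"),
        Some(&"v".to_string())
    );
}
```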
components/object_store/src/mem_cache.rs (76 changes: 33 additions & 43 deletions)
@@ -10,15 +10,15 @@ use std::{
hash::{Hash, Hasher},
num::NonZeroUsize,
ops::Range,
- sync::Arc,
+ sync::{Arc, Mutex},
};

use async_trait::async_trait;
use bytes::Bytes;
use clru::{CLruCache, CLruCacheConfig, WeightScale};
use futures::stream::BoxStream;
use snafu::{OptionExt, Snafu};
- use tokio::{io::AsyncWrite, sync::Mutex};
+ use tokio::io::AsyncWrite;
use upstream::{path::Path, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result};

use crate::ObjectStoreRef;
@@ -52,26 +52,26 @@ impl Partition {
}

impl Partition {
- async fn get(&self, key: &str) -> Option<Bytes> {
- let mut guard = self.inner.lock().await;
+ fn get(&self, key: &str) -> Option<Bytes> {
+ let mut guard = self.inner.lock().unwrap();
guard.get(key).cloned()
}

- async fn peek(&self, key: &str) -> Option<Bytes> {
+ fn peek(&self, key: &str) -> Option<Bytes> {
// FIXME: actually, here write lock is not necessary.
- let guard = self.inner.lock().await;
+ let guard = self.inner.lock().unwrap();
guard.peek(key).cloned()
}

- async fn insert(&self, key: String, value: Bytes) {
- let mut guard = self.inner.lock().await;
+ fn insert(&self, key: String, value: Bytes) {
+ let mut guard = self.inner.lock().unwrap();
// don't care error now.
_ = guard.put_with_weight(key, value);
}

#[cfg(test)]
- async fn keys(&self) -> Vec<String> {
- let guard = self.inner.lock().await;
+ fn keys(&self) -> Vec<String> {
+ let guard = self.inner.lock().unwrap();
guard
.iter()
.map(|(key, _)| key)
@@ -115,34 +115,32 @@ impl MemCache {
self.partitions[hasher.finish() as usize & self.partition_mask].clone()
}

- async fn get(&self, key: &str) -> Option<Bytes> {
+ fn get(&self, key: &str) -> Option<Bytes> {
let partition = self.locate_partition(key);
- partition.get(key).await
+ partition.get(key)
}

- async fn peek(&self, key: &str) -> Option<Bytes> {
+ fn peek(&self, key: &str) -> Option<Bytes> {
let partition = self.locate_partition(key);
- partition.peek(key).await
+ partition.peek(key)
}

- async fn insert(&self, key: String, value: Bytes) {
+ fn insert(&self, key: String, value: Bytes) {
let partition = self.locate_partition(&key);
- partition.insert(key, value).await;
+ partition.insert(key, value);
}

/// Give a description of the cache state.
#[cfg(test)]
- async fn to_string(&self) -> String {
- futures::future::join_all(
- self.partitions
- .iter()
- .map(|part| async { part.keys().await.join(",") }),
- )
- .await
- .into_iter()
- .enumerate()
- .map(|(part_no, keys)| format!("{part_no}: [{keys}]"))
- .collect::<Vec<_>>()
- .join("\n")
+ fn state_desc(&self) -> String {
+ self.partitions
+ .iter()
+ .map(|part| part.keys().join(","))
+ .into_iter()
+ .enumerate()
+ .map(|(part_no, keys)| format!("{part_no}: [{keys}]"))
+ .collect::<Vec<_>>()
+ .join("\n")
}
}

@@ -195,21 +193,21 @@ impl MemCacheStore {
// TODO(chenxiang): What if there are some overlapping range in cache?
// A request with range [5, 10) can also use [0, 20) cache
let cache_key = Self::cache_key(location, &range);
- if let Some(bytes) = self.cache.get(&cache_key).await {
+ if let Some(bytes) = self.cache.get(&cache_key) {
return Ok(bytes);
}

// TODO(chenxiang): What if two threads reach here? It's better to
// pend one thread, and only let one to fetch data from underlying store.
let bytes = self.underlying_store.get_range(location, range).await?;
- self.cache.insert(cache_key, bytes.clone()).await;
+ self.cache.insert(cache_key, bytes.clone());

Ok(bytes)
}

async fn get_range_with_ro_cache(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
let cache_key = Self::cache_key(location, &range);
- if let Some(bytes) = self.cache.peek(&cache_key).await {
+ if let Some(bytes) = self.cache.peek(&cache_key) {
return Ok(bytes);
}

@@ -297,7 +295,7 @@ mod test {

use super::*;

- async fn prepare_store(bits: usize, mem_cap: usize) -> MemCacheStore {
+ fn prepare_store(bits: usize, mem_cap: usize) -> MemCacheStore {
let local_path = tempdir().unwrap();
let local_store = Arc::new(LocalFileSystem::new_with_prefix(local_path.path()).unwrap());

@@ -309,7 +307,7 @@
#[tokio::test]
async fn test_mem_cache_evict() {
// single partition
- let store = prepare_store(0, 13).await;
+ let store = prepare_store(0, 13);

// write date
let location = Path::from("1.sst");
@@ -324,7 +322,6 @@
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
- .await
.is_some());

// get bytes from [5, 10), insert to cache
@@ -333,12 +330,10 @@
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
- .await
.is_some());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range5_10))
- .await
.is_some());

// get bytes from [10, 15), insert to cache
@@ -351,24 +346,21 @@
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
- .await
.is_none());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range5_10))
- .await
.is_some());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range10_15))
- .await
.is_some());
}

#[tokio::test]
async fn test_mem_cache_partition() {
// 4 partitions
- let store = prepare_store(2, 100).await;
+ let store = prepare_store(2, 100);
let location = Path::from("partition.sst");
store
.put(&location, Bytes::from_static(&[1; 1024]))
@@ -388,18 +380,16 @@
1: [partition.sst-100-105]
2: []
3: [partition.sst-0-5]"#,
- store.cache.as_ref().to_string().await
+ store.cache.as_ref().state_desc()
);

assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range0_5))
- .await
.is_some());
assert!(store
.cache
.get(&MemCacheStore::cache_key(&location, &range100_105))
- .await
.is_some());
}
}
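The mem_cache changes all follow one discipline: the `Partition` methods take the `std::sync::Mutex`, do their work, and return with the guard already dropped, so the `async` `ObjectStore` methods never hold a lock across an await. A self-contained sketch of that shape, with hypothetical types rather than this crate's API:

```rust
use std::{collections::HashMap, sync::Mutex};

// Hypothetical cache wrapper illustrating the rule the mem_cache change
// follows: take the std lock, finish with it, and let the guard drop before
// any `.await`.
struct SmallCache {
    inner: Mutex<HashMap<String, Vec<u8>>>,
}

impl SmallCache {
    // Synchronous: the guard never escapes this function.
    fn get(&self, key: &str) -> Option<Vec<u8>> {
        self.inner.lock().unwrap().get(key).cloned()
    }

    fn insert(&self, key: String, value: Vec<u8>) {
        // The previous value (if any) is not needed.
        let _ = self.inner.lock().unwrap().insert(key, value);
    }
}

// Stand-in for the underlying object store call; also hypothetical.
async fn fetch_from_store(_key: &str) -> Vec<u8> {
    Vec::new()
}

async fn get_or_fetch(cache: &SmallCache, key: &str) -> Vec<u8> {
    if let Some(hit) = cache.get(key) {
        return hit; // the guard was already dropped inside `get`
    }
    // No lock is held here, so awaiting is safe for the executor. The race
    // noted in the TODO above remains: two tasks can both miss and both fetch.
    let fetched = fetch_from_store(key).await;
    cache.insert(key.to_string(), fetched.clone());
    fetched
}
```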
remote_engine_client/src/cached_router.rs (9 changes: 4 additions & 5 deletions)
@@ -2,14 +2,13 @@

//! Cached router

- use std::collections::HashMap;
+ use std::{collections::HashMap, sync::RwLock};

use ceresdbproto::storage::{self, RequestContext};
use log::debug;
use router::RouterRef;
use snafu::{OptionExt, ResultExt};
use table_engine::remote::model::TableIdentifier;
- use tokio::sync::RwLock;
use tonic::transport::Channel;

use crate::{channel::ChannelPool, config::Config, error::*};
@@ -40,7 +39,7 @@ impl CachedRouter {
pub async fn route(&self, table_ident: &TableIdentifier) -> Result<Channel> {
// Find in cache first.
let channel_opt = {
- let cache = self.cache.read().await;
+ let cache = self.cache.read().unwrap();
cache.get(table_ident).cloned()
};

@@ -62,7 +61,7 @@
let channel = self.do_route(table_ident).await?;

{
- let mut cache = self.cache.write().await;
+ let mut cache = self.cache.write().unwrap();
// Double check here, if still not found, we put it.
let channel_opt = cache.get(table_ident).cloned();
if channel_opt.is_none() {
@@ -81,7 +80,7 @@
}

pub async fn evict(&self, table_ident: &TableIdentifier) {
- let mut cache = self.cache.write().await;
+ let mut cache = self.cache.write().unwrap();
let _ = cache.remove(table_ident);
}

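`CachedRouter::route` keeps its double-checked shape: read-lock and drop for the fast path, do the async routing with no lock held, then write-lock and re-check before inserting. A hypothetical sketch of that pattern with `std::sync::RwLock` (names and the string payload are made up, not this crate's API):

```rust
use std::{collections::HashMap, sync::RwLock};

// Hypothetical sketch of the double-checked caching pattern, with a formatted
// string standing in for the real async routing call and channel.
struct Cache {
    inner: RwLock<HashMap<String, String>>,
}

async fn compute(key: &str) -> String {
    // Placeholder for the real async lookup (`do_route` in the diff above).
    format!("route-for-{key}")
}

impl Cache {
    async fn get_or_compute(&self, key: &str) -> String {
        // Fast path: the read guard lives only inside this `if let` and is
        // dropped before anything is awaited.
        if let Some(v) = self.inner.read().unwrap().get(key).cloned() {
            return v;
        }

        // Slow path: do the async work with no lock held at all.
        let computed = compute(key).await;

        // Double check under the write lock in case another task raced us.
        let mut cache = self.inner.write().unwrap();
        cache.entry(key.to_string()).or_insert(computed).clone()
    }
}
```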
remote_engine_client/src/channel.rs (7 changes: 3 additions & 4 deletions)
@@ -2,11 +2,10 @@

//! Channel pool

- use std::collections::HashMap;
+ use std::{collections::HashMap, sync::RwLock};

use router::endpoint::Endpoint;
use snafu::ResultExt;
- use tokio::sync::RwLock;
use tonic::transport::{Channel, Endpoint as TonicEndpoint};

use crate::{config::Config, error::*};
@@ -30,7 +29,7 @@

pub async fn get(&self, endpoint: &Endpoint) -> Result<Channel> {
{
- let inner = self.channels.read().await;
+ let inner = self.channels.read().unwrap();
if let Some(channel) = inner.get(endpoint) {
return Ok(channel.clone());
}
@@ -40,7 +39,7 @@
.builder
.build(endpoint.clone().to_string().as_str())
.await?;
- let mut inner = self.channels.write().await;
+ let mut inner = self.channels.write().unwrap();
// Double check here.
if let Some(channel) = inner.get(endpoint) {
return Ok(channel.clone());
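Why the title says "some sync scenarios": the swap is only safe because none of the converted call sites keep the guard alive across an `.await`. A hypothetical sketch of the shape to avoid and the shape this PR keeps (not code from the repository):

```rust
use std::sync::Mutex;

async fn some_async_io() {}

// Anti-pattern: holding a std guard across an await. `std::sync::MutexGuard`
// is !Send, so a future like this cannot be passed to `tokio::spawn`, and even
// where it is accepted it keeps other threads blocked on the lock for the
// whole duration of the await.
async fn bad(m: &Mutex<Vec<u8>>) {
    let guard = m.lock().unwrap();
    some_async_io().await; // guard still alive here
    drop(guard);
}

// The shape kept throughout this PR: finish with the guard, then await.
async fn good(m: &Mutex<Vec<u8>>) {
    let len = {
        let guard = m.lock().unwrap();
        guard.len()
    }; // guard dropped here
    some_async_io().await;
    let _ = len;
}
```

Where a guard genuinely must live across an await, the tokio lock types remain the appropriate choice; this PR only converts the call sites where that is not the case.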