Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor onchain OHLC ws endpoint #55

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pragma-common/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use chrono::{NaiveDateTime, Timelike};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

#[derive(Default, Debug, Deserialize, ToSchema, Clone, Copy)]
#[derive(Default, Debug, Serialize, Deserialize, ToSchema, Clone, Copy)]
pub enum AggregationMode {
#[serde(rename = "median")]
#[default]
Expand All @@ -13,7 +13,7 @@ pub enum AggregationMode {
Twap,
}

#[derive(Default, Debug, Deserialize, ToSchema, Clone, Copy)]
#[derive(Default, Debug, Serialize, Deserialize, ToSchema, Clone, Copy)]
pub enum Network {
#[serde(rename = "testnet")]
#[default]
Expand All @@ -32,7 +32,7 @@ pub enum DataType {
}

// Supported Aggregation Intervals
#[derive(Default, Debug, Deserialize, ToSchema, Clone, Copy)]
#[derive(Default, Debug, Serialize, Deserialize, ToSchema, Clone, Copy)]
pub enum Interval {
#[serde(rename = "1min")]
#[default]
Expand Down
223 changes: 162 additions & 61 deletions pragma-node/src/handlers/entries/get_onchain/ohlc.rs
Original file line number Diff line number Diff line change
@@ -1,99 +1,200 @@
use std::time::Duration;

use axum::extract::{Query, State};
use axum::extract::State;
use axum::response::IntoResponse;
use pragma_entities::InfraError;
use serde::{Deserialize, Serialize};
use serde_json::json;

use pragma_common::types::{Interval, Network};
use tokio::time::interval;

use crate::handlers::entries::utils::currency_pair_to_pair_id;
use crate::handlers::entries::GetOnchainOHLCParams;
use crate::handlers::entries::utils::is_onchain_existing_pair;
use crate::infra::repositories::entry_repository::OHLCEntry;
use crate::infra::repositories::onchain_repository::get_ohlc;
use crate::utils::PathExtractor;
use crate::AppState;

use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};

pub const WS_UPDATING_INTERVAL_IN_SECONDS: u64 = 10;
#[derive(Default, Debug, Serialize, Deserialize)]
enum SubscriptionType {
#[serde(rename = "subscribe")]
#[default]
Subscribe,
#[serde(rename = "unsubscribe")]
Unsubscribe,
}

#[derive(Debug, Serialize, Deserialize)]
struct SubscriptionRequest {
msg_type: SubscriptionType,
pair: String,
network: Network,
interval: Interval,
}

#[derive(Debug, Serialize, Deserialize)]
struct SubscriptionAck {
msg_type: SubscriptionType,
pair: String,
network: Network,
interval: Interval,
}

/// Interval in milliseconds that the channel will update the client with the latest prices.
const CHANNEL_UPDATE_INTERVAL_IN_MS: u64 = 500;

#[utoipa::path(
get,
path = "/node/v1/onchain/ws/ohlc/{base}/{quote}",
path = "/node/v1/onchain/ohlc",
responses(
(
status = 200,
description = "Get OHLC data for a pair continuously updated through a ws connection",
body = GetOnchainOHLCResponse
description = "Subscribe to a list of OHLC entries",
body = [SubscribeToEntryResponse]
)
),
params(
("base" = String, Path, description = "Base Asset"),
("quote" = String, Path, description = "Quote Asset"),
("network" = Network, Query, description = "Network"),
("interval" = Interval, Query, description = "Interval of the OHLC data"),
),
)
)]
pub async fn get_onchain_ohlc_ws(
pub async fn subscribe_to_onchain_ohlc(
ws: WebSocketUpgrade,
State(state): State<AppState>,
PathExtractor(pair): PathExtractor<(String, String)>,
Query(params): Query<GetOnchainOHLCParams>,
) -> impl IntoResponse {
let pair_id = currency_pair_to_pair_id(&pair.0, &pair.1);
ws.on_upgrade(move |socket| {
handle_ohlc_ws(socket, state, pair_id, params.network, params.interval)
})
ws.on_upgrade(move |socket| handle_channel(socket, state))
}

async fn handle_ohlc_ws(
mut socket: WebSocket,
state: AppState,
pair_id: String,
network: Network,
interval: Interval,
) {
// Initial OHLC to compute
let mut ohlc_to_compute = 10;
let mut update_interval =
tokio::time::interval(Duration::from_secs(WS_UPDATING_INTERVAL_IN_SECONDS));
/// Handle the WebSocket channel.
async fn handle_channel(mut socket: WebSocket, state: AppState) {
let waiting_duration = Duration::from_millis(CHANNEL_UPDATE_INTERVAL_IN_MS);
let mut update_interval = interval(waiting_duration);
let mut subscribed_pair: Option<String> = None;
let mut network = Network::default();
let mut interval = Interval::default();

let mut ohlc_to_compute = 10;
let mut ohlc_data: Vec<OHLCEntry> = Vec::new();

loop {
update_interval.tick().await;
match get_ohlc(
&mut ohlc_data,
&state.postgres_pool,
network,
pair_id.clone(),
interval,
ohlc_to_compute,
)
.await
{
Ok(()) => {
if socket
.send(Message::Text(serde_json::to_string(&ohlc_data).unwrap()))
.await
.is_err()
{
break;
tokio::select! {
Some(msg) = socket.recv() => {
if let Ok(Message::Text(text)) = msg {
handle_message_received(&mut socket, &state, &mut subscribed_pair, &mut network, &mut interval, text).await;
}
},
_ = update_interval.tick() => {
match send_ohlc_data(&mut socket, &state, &subscribed_pair, &mut ohlc_data, network, interval, ohlc_to_compute).await {
Ok(_) => {
// After the first request, we only get the latest interval
if !ohlc_data.is_empty() {
ohlc_to_compute = 1;
}
},
Err(_) => break
};
}
Err(e) => {
if socket
.send(Message::Text(json!({ "error": e.to_string() }).to_string()))
.await
.is_err()
{
break;
}
}
}

/// Handle the message received from the client.
/// Subscribe or unsubscribe to the pairs requested.
async fn handle_message_received(
socket: &mut WebSocket,
state: &AppState,
subscribed_pair: &mut Option<String>,
network: &mut Network,
interval: &mut Interval,
message: String,
) {
if let Ok(subscription_msg) = serde_json::from_str::<SubscriptionRequest>(&message) {
match subscription_msg.msg_type {
SubscriptionType::Subscribe => {
let pair_exists = is_onchain_existing_pair(
&state.postgres_pool,
&subscription_msg.pair,
subscription_msg.network,
)
.await;
if !pair_exists {
let error_msg = "Pair does not exist in the onchain database.";
send_error_message(socket, error_msg).await;
return;
}

*network = subscription_msg.network;
*subscribed_pair = Some(subscription_msg.pair.clone());
*interval = subscription_msg.interval;
}
SubscriptionType::Unsubscribe => {
*subscribed_pair = None;
}
};
// We send an ack message to the client with the subscribed pairs (so
// the client knows which pairs are successfully subscribed).
if let Ok(ack_message) = serde_json::to_string(&SubscriptionAck {
msg_type: subscription_msg.msg_type,
pair: subscription_msg.pair,
network: subscription_msg.network,
interval: subscription_msg.interval,
}) {
if socket.send(Message::Text(ack_message)).await.is_err() {
let error_msg = "Message received but could not send ack message.";
send_error_message(socket, error_msg).await;
}
} else {
let error_msg = "Could not serialize ack message.";
send_error_message(socket, error_msg).await;
}
// After the first request, we only get the latest interval
if !ohlc_data.is_empty() {
ohlc_to_compute = 1;
} else {
let error_msg = "Invalid message type. Please check the documentation for more info.";
send_error_message(socket, error_msg).await;
}
}

/// Send the current median entries to the client.
async fn send_ohlc_data(
socket: &mut WebSocket,
state: &AppState,
subscribed_pair: &Option<String>,
ohlc_data: &mut Vec<OHLCEntry>,
network: Network,
interval: Interval,
ohlc_to_compute: i64,
) -> Result<(), InfraError> {
if subscribed_pair.is_none() {
return Ok(());
}

let pair_id = subscribed_pair.as_ref().unwrap();

let entries = match get_ohlc(
ohlc_data,
&state.postgres_pool,
network,
pair_id.clone(),
interval,
ohlc_to_compute,
)
.await
{
Ok(()) => ohlc_data,
Err(e) => {
send_error_message(socket, &e.to_string()).await;
return Err(e);
}
};
if let Ok(json_response) = serde_json::to_string(&entries) {
if socket.send(Message::Text(json_response)).await.is_err() {
send_error_message(socket, "Could not send prices.").await;
}
} else {
send_error_message(socket, "Could not serialize prices.").await;
}
Ok(())
}

/// Send an error message to the client.
/// (Does not close the connection)
async fn send_error_message(socket: &mut WebSocket, error: &str) {
let error_msg = json!({ "error": error }).to_string();
socket.send(Message::Text(error_msg)).await.unwrap();
}
16 changes: 15 additions & 1 deletion pragma-node/src/handlers/entries/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use bigdecimal::{BigDecimal, ToPrimitive};
use chrono::NaiveDateTime;
use deadpool_diesel::postgres::Pool;
use pragma_common::types::Network;
use std::collections::HashMap;

use crate::infra::repositories::entry_repository::MedianEntry;
use crate::infra::repositories::{
entry_repository::MedianEntry, onchain_repository::get_existing_pairs,
};

const ONE_YEAR_IN_SECONDS: f64 = 3153600_f64;

Expand Down Expand Up @@ -68,6 +72,16 @@ pub(crate) fn compute_median_price_and_time(
Some((median_price, latest_time))
}

/// Given a pair and a network, returns if it exists in the
/// onchain database.
pub(crate) async fn is_onchain_existing_pair(pool: &Pool, pair: &String, network: Network) -> bool {
let existings_pairs = get_existing_pairs(pool, network)
.await
.expect("Couldn't get the existing pairs from the database.");

existings_pairs.into_iter().any(|p| p.pair_id == *pair)
}

/// Computes the volatility from a list of entries.
/// The volatility is computed as the annualized standard deviation of the log returns.
/// The log returns are computed as the natural logarithm of the ratio between two consecutive median prices.
Expand Down
31 changes: 31 additions & 0 deletions pragma-node/src/infra/repositories/onchain_repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,37 @@ pub async fn get_last_updated_timestamp(
Ok(most_recent_entry.timestamp.and_utc().timestamp() as u64)
}

#[derive(Queryable, QueryableByName)]
pub struct EntryPairId {
#[diesel(sql_type = VarChar)]
pub pair_id: String,
}

// TODO(0xevolve): Only works for Spot entries
pub async fn get_existing_pairs(
pool: &Pool,
network: Network,
) -> Result<Vec<EntryPairId>, InfraError> {
let raw_sql = format!(
r#"
SELECT DISTINCT
pair_id
FROM
{table_name};
"#,
table_name = get_table_name(network, DataType::SpotEntry)
);

let conn = pool.get().await.map_err(adapt_infra_error)?;
let raw_entries = conn
.interact(move |conn| diesel::sql_query(raw_sql).load::<EntryPairId>(conn))
.await
.map_err(adapt_infra_error)?
.map_err(adapt_infra_error)?;

Ok(raw_entries)
}

#[derive(Queryable, QueryableByName)]
struct RawCheckpoint {
#[diesel(sql_type = VarChar)]
Expand Down
2 changes: 1 addition & 1 deletion pragma-node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn main() {
handlers::entries::get_onchain::get_onchain,
handlers::entries::get_onchain::checkpoints::get_onchain_checkpoints,
handlers::entries::get_onchain::publishers::get_onchain_publishers,
handlers::entries::get_onchain::ohlc::get_onchain_ohlc_ws,
handlers::entries::get_onchain::ohlc::subscribe_to_onchain_ohlc,
),
components(
schemas(pragma_entities::dto::Entry, pragma_entities::EntryError),
Expand Down
4 changes: 2 additions & 2 deletions pragma-node/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use utoipa::OpenApi as OpenApiT;
use utoipa_swagger_ui::SwaggerUi;

use crate::handlers::entries::get_onchain::{
checkpoints::get_onchain_checkpoints, get_onchain, ohlc::get_onchain_ohlc_ws,
checkpoints::get_onchain_checkpoints, get_onchain, ohlc::subscribe_to_onchain_ohlc,
publishers::get_onchain_publishers,
};
use crate::handlers::entries::{create_entries, get_entry, get_ohlc, get_volatility};
Expand Down Expand Up @@ -47,7 +47,7 @@ fn onchain_routes(state: AppState) -> Router<AppState> {
.route("/:base/:quote", get(get_onchain))
.route("/checkpoints/:base/:quote", get(get_onchain_checkpoints))
.route("/publishers", get(get_onchain_publishers))
.route("/ws/ohlc/:base/:quote", get(get_onchain_ohlc_ws))
.route("/ws/ohlc", get(subscribe_to_onchain_ohlc))
.with_state(state)
}

Expand Down
Loading