migrate to axum (#97)

This commit is contained in:
QP Hou 2021-10-24 16:27:25 -07:00 committed by GitHub
parent 3e4bd2ea40
commit 246bc6d0c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 693 additions and 719 deletions

953
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,8 +9,6 @@ members = [
[patch.crates-io]
datafusion = { git = "https://github.com/houqp/arrow-datafusion.git", rev = "af34cec956c8d67b2520a266e74683d7bdcb3099" }
actix-cors = { git = "https://github.com/houqp/actix-extras.git", rev = "ab3bdb6a5924b6d881d204856199e7539e273d2f" }
deltalake = { git = "https://github.com/houqp/delta-rs.git", rev = "cbf1126e70f8578f192dced7286ea9d63d9f629e" }
[patch."https://github.com/apache/arrow-datafusion"]

View File

@ -160,9 +160,9 @@ async fn main() -> anyhow::Result<()> {
let app = clap::App::new("Columnq")
.version("0.0.1")
.author("QP Hou")
.about("OLAP the Unix way.")
.setting(clap::AppSettings::SubcommandRequiredElseHelp)
.setting(clap::AppSettings::DisableVersionForSubcommands)
.subcommand(
clap::App::new("sql")
.about("Query tables with SQL")

View File

@ -19,12 +19,15 @@ columnq = { path = "../columnq", version = "0", default-features = false }
# for datafusion optimization
snmalloc-rs = { version = "0.2", optional = true }
# all actix dependencies are patched to use git source until actix-web 4.x lands
actix-web = { version = "4.0.0-beta.4", default-features = false }
actix-http = { version = "3.0.0-beta.4", default-features = false }
actix-service = "2.0.0-beta.3"
# see https://github.com/actix/actix-extras/pull/144 for cors tokio 1 upgrade
actix-cors = "*"
# dependencies related to axum
tokio = { version = "1", features = ["rt-multi-thread"] }
hyper = { version = "0", features = ["http1", "server", "stream", "runtime"] }
# axum = "0.2.8"
axum = { git = "https://github.com/tokio-rs/axum.git", rev = "7692baf83728775afbb7ed7f1a178d741ca74c40" }
tower-http = { git = "https://github.com/tower-rs/tower-http.git", branch = "cors", features = ["cors"] }
tower-layer = "0"
tracing = "0"
pin-project = "1"
env_logger = "0"
log = "0"
@ -47,7 +50,6 @@ snmalloc = ["snmalloc-rs"]
[dev-dependencies]
actix-rt = "*"
reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"]}
tokio = { version = "1" }
# TODO: uncomment this when we exclude roapi-http from root workspace
# [profile.release]

View File

@ -1,17 +1,22 @@
use actix_web::{web, HttpRequest, HttpResponse};
use std::sync::Arc;
use crate::api::{encode_record_batches, encode_type_from_req, HandlerContext};
use axum::body::Body;
use axum::body::Bytes;
use axum::extract;
use axum::http::header::HeaderMap;
use axum::http::Response;
use crate::api::{encode_record_batches, encode_type_from_hdr, HandlerContext};
use crate::error::ApiErrResp;
pub async fn post(
data: web::Data<HandlerContext>,
req: HttpRequest,
query: web::Bytes,
) -> Result<HttpResponse, ApiErrResp> {
let encode_type = encode_type_from_req(req)?;
let graphq = std::str::from_utf8(&query).map_err(ApiErrResp::read_query)?;
let batches = data.cq.query_graphql(graphq).await?;
state: extract::Extension<Arc<HandlerContext>>,
headers: HeaderMap,
body: Bytes,
) -> Result<Response<Body>, ApiErrResp> {
let ctx = state.0;
let encode_type = encode_type_from_hdr(headers)?;
let graphq = std::str::from_utf8(&body).map_err(ApiErrResp::read_query)?;
let batches = ctx.cq.query_graphql(graphq).await?;
encode_record_batches(encode_type, &batches)
}

View File

@ -1,6 +1,9 @@
use std::convert::TryFrom;
use actix_web::{http, HttpRequest, HttpResponse};
use axum::body::Body;
use axum::http;
use axum::http::header;
use axum::http::Response;
use columnq::datafusion::arrow;
use columnq::encoding;
use columnq::ColumnQ;
@ -32,8 +35,25 @@ impl HandlerContext {
}
}
pub fn encode_type_from_req(req: HttpRequest) -> Result<encoding::ContentType, ApiErrResp> {
match req.headers().get(http::header::ACCEPT) {
#[inline]
pub fn bytes_to_resp(bytes: Vec<u8>, content_type: &'static str) -> Response<Body> {
let mut res = Response::new(Body::from(bytes));
res.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static(content_type),
);
res
}
#[inline]
pub fn bytes_to_json_resp(bytes: Vec<u8>) -> Response<Body> {
bytes_to_resp(bytes, "application/json")
}
pub fn encode_type_from_hdr(
headers: header::HeaderMap,
) -> Result<encoding::ContentType, ApiErrResp> {
match headers.get(header::ACCEPT) {
None => Ok(encoding::ContentType::Json),
Some(hdr_value) => {
encoding::ContentType::try_from(hdr_value.as_bytes()).map_err(|_| ApiErrResp {
@ -48,7 +68,7 @@ pub fn encode_type_from_req(req: HttpRequest) -> Result<encoding::ContentType, A
pub fn encode_record_batches(
content_type: encoding::ContentType,
batches: &[arrow::record_batch::RecordBatch],
) -> Result<HttpResponse, ApiErrResp> {
) -> Result<Response<Body>, ApiErrResp> {
let payload = match content_type {
encoding::ContentType::Json => encoding::json::record_batches_to_bytes(batches)
.map_err(ApiErrResp::json_serialization)?,
@ -64,9 +84,7 @@ pub fn encode_record_batches(
.map_err(ApiErrResp::parquet_serialization)?,
};
let mut resp = HttpResponse::Ok();
let builder = resp.content_type(content_type.to_str());
Ok(builder.body(payload))
Ok(bytes_to_resp(payload, content_type.to_str()))
}
pub mod graphql;

View File

@ -1,28 +1,23 @@
use std::collections::HashMap;
use std::sync::Arc;
use actix_web::{web, HttpRequest, HttpResponse};
use serde_derive::Deserialize;
use axum::body::Body;
use axum::extract;
use axum::http::header::HeaderMap;
use axum::http::Response;
use crate::api::{encode_record_batches, encode_type_from_req, HandlerContext};
use crate::api::HandlerContext;
use crate::api::{encode_record_batches, encode_type_from_hdr};
use crate::error::ApiErrResp;
#[derive(Deserialize)]
pub struct RestTablePath {
table_name: String,
}
pub async fn get_table(
data: web::Data<HandlerContext>,
path: web::Path<RestTablePath>,
req: HttpRequest,
query: web::Query<HashMap<String, String>>,
) -> Result<HttpResponse, ApiErrResp> {
let encode_type = encode_type_from_req(req)?;
let batches = data
.cq
.query_rest_table(&path.table_name, &query.into_inner())
.await?;
state: extract::Extension<Arc<HandlerContext>>,
headers: HeaderMap,
extract::Path(table_name): extract::Path<String>,
extract::Query(params): extract::Query<HashMap<String, String>>,
) -> Result<Response<Body>, ApiErrResp> {
let ctx = &state.0;
let encode_type = encode_type_from_hdr(headers)?;
let batches = ctx.cq.query_rest_table(&table_name, &params).await?;
encode_record_batches(encode_type, &batches)
}

View File

@ -1,30 +1,14 @@
use actix_http::body::MessageBody;
use actix_service::ServiceFactory;
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::{web, App, Error};
use crate::api;
pub fn register_app_routes<T, B>(app: App<T, B>) -> App<T, B>
where
B: MessageBody,
T: ServiceFactory<
ServiceRequest,
Config = (),
Response = ServiceResponse<B>,
Error = Error,
InitError = (),
>,
{
app.route(
"/api/tables/{table_name}",
web::get().to(api::rest::get_table),
)
.route("/api/sql", web::post().to(api::sql::post))
.route("/api/graphql", web::post().to(api::graphql::post))
.route("/api/schema", web::get().to(api::schema::get))
.route(
"/api/schema/{table_name}",
web::get().to(api::schema::get_by_table_name),
)
use axum::{
routing::{get, post},
Router,
};
pub fn register_app_routes() -> Router {
Router::new()
.route("/api/tables/:table_name", get(api::rest::get_table))
.route("/api/sql", post(api::sql::post))
.route("/api/graphql", post(api::graphql::post))
.route("/api/schema", get(api::schema::schema))
}

View File

@ -1,39 +1,32 @@
use std::collections::HashMap;
use std::sync::Arc;
use actix_web::{web, HttpRequest, HttpResponse};
use serde_derive::Deserialize;
use axum::body::Body;
use axum::extract;
use axum::http::Response;
use crate::api::HandlerContext;
use crate::api::{bytes_to_json_resp, HandlerContext};
use crate::error::ApiErrResp;
pub async fn get(
data: web::Data<HandlerContext>,
_req: HttpRequest,
_query: web::Bytes,
) -> Result<HttpResponse, ApiErrResp> {
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(serde_json::to_vec(data.cq.schema_map()).map_err(ApiErrResp::json_serialization)?))
}
#[derive(Deserialize)]
pub struct SchemaTablePath {
table_name: String,
pub async fn schema(
state: extract::Extension<Arc<HandlerContext>>,
) -> Result<Response<Body>, ApiErrResp> {
let ctx = state.0;
let payload =
serde_json::to_vec(ctx.cq.schema_map()).map_err(ApiErrResp::json_serialization)?;
Ok(bytes_to_json_resp(payload))
}
pub async fn get_by_table_name(
data: web::Data<HandlerContext>,
path: web::Path<SchemaTablePath>,
_req: HttpRequest,
_query: web::Query<HashMap<String, String>>,
) -> Result<HttpResponse, ApiErrResp> {
Ok(HttpResponse::Ok().content_type("application/json").body(
serde_json::to_vec(
data.cq
.schema_map()
.get(&path.table_name)
.ok_or_else(|| ApiErrResp::not_found("invalid table name"))?,
)
.map_err(ApiErrResp::json_serialization)?,
))
state: extract::Extension<Arc<HandlerContext>>,
extract::Path(table_name): extract::Path<String>,
) -> Result<Response<Body>, ApiErrResp> {
let ctx = state.0;
let payload = serde_json::to_vec(
ctx.cq
.schema_map()
.get(&table_name)
.ok_or_else(|| ApiErrResp::not_found("invalid table name"))?,
)
.map_err(ApiErrResp::json_serialization)?;
Ok(bytes_to_json_resp(payload))
}

View File

@ -1,17 +1,22 @@
use actix_web::{web, HttpRequest, HttpResponse};
use std::sync::Arc;
use crate::api::{encode_record_batches, encode_type_from_req, HandlerContext};
use axum::body::Body;
use axum::body::Bytes;
use axum::extract;
use axum::http::header::HeaderMap;
use axum::http::Response;
use crate::api::{encode_record_batches, encode_type_from_hdr, HandlerContext};
use crate::error::ApiErrResp;
pub async fn post(
data: web::Data<HandlerContext>,
req: HttpRequest,
query: web::Bytes,
) -> Result<HttpResponse, ApiErrResp> {
let encode_type = encode_type_from_req(req)?;
let sql = std::str::from_utf8(&query).map_err(ApiErrResp::read_query)?;
let batches = data.cq.query_sql(sql).await?;
state: extract::Extension<Arc<HandlerContext>>,
headers: HeaderMap,
body: Bytes,
) -> Result<Response<Body>, ApiErrResp> {
let ctx = state.0;
let encode_type = encode_type_from_hdr(headers)?;
let sql = std::str::from_utf8(&body).map_err(ApiErrResp::read_query)?;
let batches = ctx.cq.query_sql(sql).await?;
encode_record_batches(encode_type, &batches)
}

View File

@ -1,8 +1,7 @@
use std::fmt;
use actix_http::body::Body;
use actix_http::Response;
use actix_web::{http, HttpResponse};
use axum::http;
use axum::http::Response;
use columnq::datafusion::arrow;
use columnq::datafusion::parquet;
use columnq::error::QueryError;
@ -17,6 +16,13 @@ pub struct ApiErrResp {
pub message: String,
}
fn serialize_statuscode<S>(x: &http::StatusCode, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
s.serialize_u16(x.as_u16())
}
impl ApiErrResp {
pub fn not_found(message: &str) -> Self {
Self {
@ -30,7 +36,7 @@ impl ApiErrResp {
Self {
code: http::StatusCode::INTERNAL_SERVER_ERROR,
error: "json_serialization".to_string(),
message: "Failed to serialize record batches into JSON".to_string(),
message: "Failed to serialize payload into JSON".to_string(),
}
}
@ -85,11 +91,14 @@ impl From<QueryError> for ApiErrResp {
}
}
fn serialize_statuscode<S>(x: &http::StatusCode, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
s.serialize_u16(x.as_u16())
impl From<http::Error> for ApiErrResp {
fn from(e: http::Error) -> Self {
ApiErrResp {
error: "http_error".to_string(),
message: e.to_string(),
code: http::StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
impl fmt::Display for ApiErrResp {
@ -98,12 +107,14 @@ impl fmt::Display for ApiErrResp {
}
}
impl actix_web::error::ResponseError for ApiErrResp {
fn status_code(&self) -> http::StatusCode {
self.code
}
impl axum::response::IntoResponse for ApiErrResp {
type Body = axum::body::Body;
type BodyError = <Self::Body as axum::body::HttpBody>::Error;
fn error_response(&self) -> Response<Body> {
HttpResponse::build(self.code).json(self)
fn into_response(self) -> Response<axum::body::Body> {
let payload = serde_json::to_vec(&self).unwrap();
let mut res = Response::new(axum::body::Body::from(payload));
*res.status_mut() = self.code;
res
}
}

118
roapi-http/src/layers.rs Normal file
View File

@ -0,0 +1,118 @@
use axum::http::uri::Uri;
use axum::http::Method;
use axum::http::Request;
use axum::http::Response;
use hyper::service::Service;
use log::error;
use log::info;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Instant;
use tower_layer::Layer;
pub struct HttpLoggerLayer {}
impl HttpLoggerLayer {
pub fn new() -> Self {
Self {}
}
}
impl Default for HttpLoggerLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for HttpLoggerLayer {
type Service = HttpLogger<S>;
fn layer(&self, service: S) -> Self::Service {
HttpLogger::new(service)
}
}
#[derive(Debug, Clone)]
pub struct HttpLogger<Inner> {
inner: Inner,
}
impl<Inner> HttpLogger<Inner> {
fn new(inner: Inner) -> Self {
Self { inner }
}
}
impl<Inner, ReqBody, ResBody> Service<Request<ReqBody>> for HttpLogger<Inner>
where
Inner: Service<Request<ReqBody>, Response = Response<ResBody>>,
Inner::Error: std::fmt::Debug,
{
type Response = Inner::Response;
type Error = Inner::Error;
type Future = LoggerResponseFuture<Inner::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
// TODO: user-agent can be extracted from request.headers()
let method = request.method().to_owned();
let uri = request.uri().to_owned();
let response_future = self.inner.call(request);
LoggerResponseFuture {
response_future,
method,
uri,
start: None,
}
}
}
#[pin_project]
pub struct LoggerResponseFuture<F> {
#[pin]
response_future: F,
method: Method,
uri: Uri,
start: Option<Instant>,
}
impl<F, Body, E> Future for LoggerResponseFuture<F>
where
F: Future<Output = Result<Response<Body>, E>>,
E: std::fmt::Debug,
{
type Output = Result<Response<Body>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let start = this.start.get_or_insert_with(Instant::now);
match this.response_future.poll(cx) {
Poll::Ready(result) => {
match &result {
Ok(resp) => {
let elapsed = start.elapsed();
info!(
"[{}] {:?} {} {:?}",
resp.status(),
this.method,
this.uri,
elapsed,
);
}
Err(err) => {
error!("{:?} {}: {:?}", this.method, this.uri, err);
}
}
Poll::Ready(result)
}
Poll::Pending => Poll::Pending,
}
}
}

View File

@ -3,6 +3,7 @@
pub mod api;
pub mod config;
pub mod error;
pub mod layers;
pub mod startup;
#[cfg(test)]

View File

@ -3,12 +3,15 @@
use roapi_http::config::get_configuration;
use roapi_http::startup::Application;
#[actix_web::main]
#[cfg(snmalloc)]
#[global_allocator]
static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
let config = get_configuration()?;
let application = Application::build(config).await?;
application.run_until_stopped().await?;
Ok(())

View File

@ -1,45 +1,47 @@
use axum::http::Method;
use std::net::TcpListener;
use std::sync::Arc;
use crate::api;
use crate::api::HandlerContext;
use crate::config::Config;
use actix_cors::Cors;
use actix_web::dev::Server;
use actix_web::{middleware, web, App, HttpServer};
use std::net::TcpListener;
use crate::layers::HttpLoggerLayer;
pub struct Application {
port: u16,
server: Server,
server: axum::Server<
hyper::server::conn::AddrIncoming,
axum::routing::IntoMakeService<axum::Router>,
>,
}
impl Application {
pub async fn build(config: Config) -> Result<Self, std::io::Error> {
pub async fn build(config: Config) -> anyhow::Result<Self> {
let addr = (config.addr)
.clone()
.unwrap_or_else(|| "127.0.0.1:8080".to_string());
let listener = TcpListener::bind(addr)?;
let port = listener.local_addr().unwrap().port();
let ctx = web::Data::new(
HandlerContext::new(&config)
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
);
let handler_ctx = HandlerContext::new(&config)
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let server = HttpServer::new(move || {
let app = App::new()
.app_data(ctx.clone())
.wrap(middleware::Logger::default())
.wrap(
Cors::default()
.allowed_methods(vec!["POST", "GET"])
.supports_credentials()
.max_age(3600),
);
api::register_app_routes(app)
})
.listen(listener)?
.run();
let routes = api::routes::register_app_routes();
let cors = tower_http::cors::CorsLayer::new()
.allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
.allow_origin(tower_http::cors::Any)
.allow_credentials(false);
let mut app = routes
.layer(axum::AddExtensionLayer::new(Arc::new(handler_ctx)))
.layer(cors);
if log::log_enabled!(log::Level::Info) {
// only add logger layer if level >= INFO
app = app.layer(HttpLoggerLayer::new());
}
let server = axum::Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service());
Ok(Self { port, server })
}
@ -48,7 +50,7 @@ impl Application {
self.port
}
pub async fn run_until_stopped(self) -> Result<(), std::io::Error> {
self.server.await
pub async fn run_until_stopped(self) -> anyhow::Result<()> {
Ok(self.server.await?)
}
}