#[cfg(feature = "axum")]
use {
    axum::Router,
    std::net::Ipv4Addr,
    tower_http::{
        cors::CorsLayer,
        trace::{HttpMakeClassifier, TraceLayer},
    },
};
// Names needed only by the serving code, which additionally requires `tokio`.
// Gating them separately avoids unused-import warnings in `axum`-only builds.
#[cfg(all(feature = "axum", feature = "tokio"))]
use {
    axum::{extract::Request, handler::Handler, ServiceExt},
    std::{io, net::SocketAddr},
    tokio::net::TcpListener,
    tower::layer::Layer,
    tower_http::{normalize_path::NormalizePathLayer, trace},
    tracing::{info, Level},
};

// TODO: let `create_app!` optionally apply trailing-slash normalization, i.e.
// produce `NormalizePathLayer::trim_trailing_slash().layer(create_app!(routes))`.

/// Builds an application from a router and an optional list of layers.
///
/// `create_app!(router)` evaluates to `router` unchanged, while
/// `create_app!(router, a, b)` expands to `router.layer(a).layer(b)`; layers are
/// applied in the order they are listed. A trailing comma is accepted.
#[macro_export]
#[cfg(feature = "axum")]
macro_rules! create_app {
    ($router:expr) => {
        $router
    };
    ($router:expr, $($layer:expr),* $(,)?) => {
        $router$(.layer($layer))*
    };
}

/// Builder that collects an axum [`Router`] together with its serving options.
///
/// Every option is optional; options left unset fall back to the defaults used
/// when the application is served (the serving API requires the `tokio` feature).
#[cfg(feature = "axum")]
#[derive(Default)]
pub struct AppBuilder {
    /// Routes merged so far, plus any fallback handler.
    router: Router,
    /// Address to bind; `None` means `0.0.0.0:8000`.
    socket: Option<(Ipv4Addr, u16)>,
    /// CORS layer; only applied when set.
    cors: Option<CorsLayer>,
    /// Whether to trim trailing slashes before routing; `None` means `true`.
    normalize_path: Option<bool>,
    /// Custom tracing layer; `None` means an INFO-level HTTP trace layer.
    tracing: Option<TraceLayer<HttpMakeClassifier>>,
}

#[cfg(all(feature = "axum", feature = "tokio"))]
impl AppBuilder {
    /// Address used when [`AppBuilder::socket`] was never called.
    const DEFAULT_SOCKET: (Ipv4Addr, u16) = (Ipv4Addr::UNSPECIFIED, 8000);

    /// Creates an empty builder; equivalent to `AppBuilder::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Merges `route` into the routes collected so far.
    #[must_use]
    pub fn route(mut self, route: Router) -> Self {
        self.router = self.router.merge(route);
        self
    }

    /// Merges every router yielded by `routes`, in iteration order.
    #[must_use]
    pub fn routes(mut self, routes: impl IntoIterator<Item = Router>) -> Self {
        self.router = routes.into_iter().fold(self.router, Router::merge);
        self
    }

    /// Sets the IPv4 address and port to listen on (default: `0.0.0.0:8000`).
    #[must_use]
    pub fn socket(mut self, socket: impl Into<(Ipv4Addr, u16)>) -> Self {
        self.socket = Some(socket.into());
        self
    }

    /// Installs `fallback` on the router via [`Router::fallback`].
    #[must_use]
    pub fn fallback<H, T>(mut self, fallback: H) -> Self
    where
        H: Handler<T, ()>,
        T: 'static,
    {
        self.router = self.router.fallback(fallback);
        self
    }

    /// Sets the CORS layer to wrap the router with; no CORS layer is added otherwise.
    #[must_use]
    pub fn cors(mut self, cors: CorsLayer) -> Self {
        self.cors = Some(cors);
        self
    }

    /// Enables or disables trailing-slash trimming before routing (default: enabled).
    #[must_use]
    pub fn normalize_path(mut self, normalize_path: bool) -> Self {
        self.normalize_path = Some(normalize_path);
        self
    }

    /// Replaces the default INFO-level HTTP tracing layer.
    #[must_use]
    pub fn tracing(mut self, tracing: TraceLayer<HttpMakeClassifier>) -> Self {
        self.tracing = Some(tracing);
        self
    }

    /// Binds the configured socket and serves the assembled application.
    ///
    /// Before binding, this tries to install a compact `tracing_subscriber`
    /// formatter as the global subscriber. If that fails (for example because the
    /// host application already installed one), the failure is ignored.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from binding the listener or from `axum::serve`.
    pub async fn serve(self) -> io::Result<()> {
        // Deliberately best-effort: the caller may own the global subscriber.
        let _ = fmt_trace();
        let listener = self.listener().await?;
        if self.normalize_path.unwrap_or(true) {
            // The normalization layer must wrap the whole router (not be added via
            // `Router::layer`) so it rewrites the path before routing happens.
            let app = NormalizePathLayer::trim_trailing_slash().layer(self.create_app());
            axum::serve(listener, ServiceExt::<Request>::into_make_service(app)).await
        } else {
            axum::serve(listener, self.create_app().into_make_service()).await
        }
    }

    /// Binds a TCP listener on the configured socket, or on `DEFAULT_SOCKET`.
    async fn listener(&self) -> io::Result<TcpListener> {
        let addr = SocketAddr::from(self.socket.unwrap_or(Self::DEFAULT_SOCKET));
        info!("Initializing server on: {addr}");
        TcpListener::bind(&addr).await
    }

    /// Consumes the builder and returns the router with its layers applied.
    ///
    /// The CORS layer (if any) is added first, then the tracing layer: either the
    /// custom one, or a default that logs spans and responses at INFO level.
    fn create_app(self) -> Router {
        let router = match self.cors {
            Some(cors) => self.router.layer(cors),
            None => self.router,
        };
        // `unwrap_or_else` so the default layer is only built when it is needed.
        let tracing = self.tracing.unwrap_or_else(|| {
            TraceLayer::new_for_http()
                .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
                .on_response(trace::DefaultOnResponse::new().level(Level::INFO))
        });
        router.layer(tracing)
    }
}

/// Installs a compact `tracing_subscriber` formatter, without targets, as the
/// global default subscriber.
///
/// # Errors
///
/// Returns the stringified `try_init` error, e.g. when a global subscriber has
/// already been set.
#[cfg(all(feature = "axum", feature = "tokio"))]
fn fmt_trace() -> Result<(), String> {
    tracing_subscriber::fmt()
        .with_target(false)
        .compact()
        .try_init()
        .map_err(|error| error.to_string())
}

#[cfg(all(test, feature = "axum"))]
mod tests {
    use axum::Router;

    use super::*;

    #[cfg(feature = "tokio")]
    mod tokio_tests {
        use std::time::Duration;

        use tokio::time::sleep;

        use super::*;

        #[tokio::test]
        async fn test_app_builder_serve() {
            let handler = tokio::spawn(async {
                AppBuilder::new().serve().await.unwrap();
            });
            sleep(Duration::from_secs(1)).await;
            // A finished task means `serve` returned early (e.g. the port was
            // taken); its `unwrap` panic would otherwise go unnoticed after `abort`.
            assert!(!handler.is_finished(), "server task exited early");
            handler.abort();
        }

        #[tokio::test]
        async fn test_app_builder_all() {
            let handler = tokio::spawn(async {
                AppBuilder::new()
                    .socket((Ipv4Addr::LOCALHOST, 8080))
                    .routes([Router::new()])
                    .fallback(|| async { "Fallback" })
                    .cors(CorsLayer::new())
                    .normalize_path(true)
                    .tracing(TraceLayer::new_for_http())
                    .serve()
                    .await
                    .unwrap();
            });
            sleep(Duration::from_secs(1)).await;
            assert!(!handler.is_finished(), "server task exited early");
            handler.abort();
        }
    }

    #[test]
    fn test_create_app_router_only() {
        let _app: Router<()> = create_app!(Router::new());
    }
}