---
url: https://docs.rs/axum/latest/axum/
description: Complete Axum Web Framework documentation - a high-performance async Rust web framework built on Tokio and the Tower ecosystem, providing type-safe routing, extractors, middleware, and WebSocket support.
---

# Axum Web Framework: Complete Documentation

Axum is a modern asynchronous Rust web framework developed by the Tokio team, with a focus on ergonomics and modular design. It is built on top of Tower and Hyper, and provides powerful middleware support and fully asynchronous request handling.

---

# Introduction

Axum is a high-performance asynchronous HTTP web framework written in Rust. It is developed as part of the Tokio project and uses Rust's type system to catch mistakes at compile time.

## Core Features

- **Type safety**: the compiler checks handler signatures and extractors
- **Modular design**: applications are built through composition rather than inheritance
- **Async first**: native support for async/await and the Tokio runtime
- **Tower compatible**: full integration with the Tower ecosystem and its rich set of middleware
- **Extractor pattern**: extractors pull data out of requests safely
- **Zero-cost abstractions**: high performance through compile-time optimization

---

# Quick Start

## Requirements

- Rust 1.75 or later
- The Tokio async runtime

## Installation

```bash
# Create the project directory
mkdir myproject && cd myproject

# Initialize a Cargo project
cargo init

# Add dependencies
cargo add axum tokio --features tokio/full
```

## Your First Application

Create `src/main.rs`:

```rust
use axum::{
    routing::get,
    Router,
};

#[tokio::main]
async fn main() {
    // Build the router
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }));

    // Bind the listening address
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    println!("Server running on http://127.0.0.1:3000");

    // Start the server
    axum::serve(listener, app).await.unwrap();
}
```

Run it:

```bash
cargo run
```

Visiting `http://localhost:3000` returns `Hello, World!`.

---

# Core Architecture

## How Axum Fits into the Tower Ecosystem

```
┌──────────────────────────────────────┐
│          Axum Web Framework          │
│ (high-level web service abstraction) │
└──────────────────┬───────────────────┘
                   │
┌──────────────────▼───────────────────┐
│         Tower Service Layer          │
│ (middleware and service composition) │
├──────────────────────────────────────┤
│  - Tower Middleware                  │
│  - ServiceBuilder                    │
│  - Layer Trait                       │
└──────────────────┬───────────────────┘
                   │
┌──────────────────▼───────────────────┐
│          Hyper HTTP Runtime          │
│        (HTTP/1.1 and HTTP/2)         │
└──────────────────┬───────────────────┘
                   │
┌──────────────────▼───────────────────┐
│         Tokio Async Runtime          │
│       (async execution engine)       │
└──────────────────────────────────────┘
```

## The Service Trait

The `Service` trait is the core of Tower and the foundation of Axum:

```rust
pub trait Service<Request> {
    type Response;
    type Error;
    type Future: Future<Output = Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>;

    fn call(&mut self, req: Request) -> Self::Future;
}
```

## The Layer Trait

A `Layer` wraps and decorates a `Service`:

```rust
pub trait Layer<S> {
    type Service;

    fn layer(&self, inner: S) -> Self::Service;
}
```

---

# Routing

## Basic Routes

```rust
use axum::{
    routing::{get, post, put, delete},
    Router,
};

let app = Router::new()
    .route("/", get(root_handler))
    .route("/users", get(list_users).post(create_user))
    .route("/users/:id", get(get_user).put(update_user).delete(delete_user));
```

## Route Parameters

### Single Parameter

```rust
use axum::extract::Path;

async fn get_user(Path(id): Path<u32>) -> String {
    format!("User ID: {}", id)
}

// Route: GET /users/:id
```

### Multiple Parameters

```rust
use axum::extract::Path;

async fn get_user_post(
    Path((user_id, post_id)): Path<(u32, u32)>
) -> String {
    format!("User: {}, Post: {}", user_id, post_id)
}

// Route: GET /users/:user_id/posts/:post_id
```

### Named Parameters (Recommended)

```rust
use axum::extract::Path;
use serde::Deserialize;

// Field names must match the parameter names in the route
#[derive(Deserialize)]
struct UserPostParams {
    user_id: u32,
    post_id: u32,
}

async fn get_user_post(
    Path(params): Path<UserPostParams>
) -> String {
    format!("User: {}, Post: {}", params.user_id, params.post_id)
}

// Route: GET /users/:user_id/posts/:post_id
```

### Wildcard Parameters

```rust
use axum::extract::Path;

async fn serve_file(Path(path): Path<String>) -> String {
    format!("Serving file: {}", path)
}

// Route: GET /files/*path
// GET /files/docs/a/b/c.txt -> "Serving file: docs/a/b/c.txt"
```

## Nested Routes

```rust
use axum::{Router, routing::get};

// User sub-router
let user_routes = Router::new()
    .route("/", get(list_users).post(create_user))
    .route("/:id", get(get_user).put(update_user).delete(delete_user));

// Post sub-router
let post_routes = Router::new()
    .route("/", get(list_posts).post(create_post))
    .route("/:id", get(get_post));

// Main application
let app = Router::new()
    .nest("/api/users", user_routes)
    .nest("/api/posts", post_routes)
    .route("/health", get(|| async { "OK" }));

// Resulting routes:
// GET    /api/users      -> list_users
// POST   /api/users      -> create_user
// GET    /api/users/:id  -> get_user
// PUT    /api/users/:id  -> update_user
// DELETE /api/users/:id  -> delete_user
// GET    /api/posts      -> list_posts
// POST   /api/posts      -> create_post
// GET    /api/posts/:id  -> get_post
// GET    /health         -> "OK"
```

## Combining Routers

Several routers can be combined into one application. The example below mounts two API versions under different prefixes with `nest`. To combine routers without adding a prefix, use `Router::merge`.

```rust
let api_v1 = Router::new()
    .route("/users", get(list_users_v1))
    .route("/posts", get(list_posts_v1));

let api_v2 = Router::new()
    .route("/users", get(list_users_v2))
    .route("/posts", get(list_posts_v2));

let app = Router::new()
    .nest("/api/v1", api_v1)
    .nest("/api/v2", api_v2);
```

## Fallback Handling

```rust
use axum::{
    http::StatusCode,
    response::IntoResponse,
    routing::get,
    Json, Router,
};
use serde_json::json;

async fn not_found_handler() -> impl IntoResponse {
    (
        StatusCode::NOT_FOUND,
        Json(json!({
            "error": "Not Found",
            "status": 404
        })),
    )
}

let app = Router::new()
    .route("/api/users", get(list_users))
    .fallback(not_found_handler);
```

---

# Extractors

Extractors are the core abstraction for pulling data out of an HTTP request.

## Path - Path Parameters

```rust
use axum::extract::Path;
use serde::Deserialize;

// Single parameter
async fn get_user(Path(id): Path<u32>) -> String {
    format!("User ID: {}", id)
}

// Multiple parameters
async fn get_item(
    Path((category, id)): Path<(String, u32)>
) -> String {
    format!("Category: {}, ID: {}", category, id)
}

// Extracting into a struct
#[derive(Deserialize)]
struct PathParams {
    category: String,
    id: u32,
}

async fn get_item_struct(Path(params): Path<PathParams>) -> String {
    format!("Category: {}, ID: {}", params.category, params.id)
}
```

## Query - Query Parameters

```rust
use axum::extract::Query;
use serde::Deserialize;

#[derive(Deserialize)]
struct Pagination {
    page: Option<u32>,
    limit: Option<u32>,
    search: Option<String>,
}

async fn list_items(Query(params): Query<Pagination>) -> String {
    let page = params.page.unwrap_or(1);
    let limit = params.limit.unwrap_or(10);
    format!("Page: {}, Limit: {}, Search: {:?}", page, limit, params.search)
}

// GET /items?page=2&limit=20&search=rust
```

## Json - JSON Request Body

```rust
use axum::{extract::Json, http::StatusCode};
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct CreateUser {
    name: String,
    email: String,
}

#[derive(Serialize)]
struct User {
    id: u32,
    name: String,
    email: String,
}

async fn create_user(
    Json(payload): Json<CreateUser>
) -> (StatusCode, Json<User>) {
    let user = User {
        id: 1,
        name: payload.name,
        email: payload.email,
    };
    (StatusCode::CREATED, Json(user))
}
```

## Form - Form Data

```rust
use axum::extract::Form;
use serde::Deserialize;

#[derive(Deserialize)]
struct LoginForm {
    username: String,
    password: String,
}

async fn login(Form(form): Form<LoginForm>) -> String {
    format!("Login: {}", form.username)
}
```

## HeaderMap - Request Headers

```rust
use axum::http::HeaderMap;

async fn show_headers(headers: HeaderMap) -> String {
    let user_agent = headers
        .get("user-agent")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("Unknown");
    format!("User-Agent: {}", user_agent)
}
```

## State - Shared State

```rust
use axum::{extract::State, routing::get, Router};

#[derive(Clone)]
struct AppState {
    db_url: String,
}

async fn handler(State(state): State<AppState>) -> String {
    format!("DB: {}", state.db_url)
}

let state = AppState {
    db_url: "postgresql://localhost/mydb".to_string(),
};

let app = Router::new()
    .route("/", get(handler))
    .with_state(state);
```

## Extension - Extension Data

```rust
use axum::Extension;
use std::sync::Arc;

#[derive(Clone)]
struct RequestContext {
    request_id: String,
}

// The value must be inserted beforehand, e.g. with `.layer(Extension(Arc::new(ctx)))`
// or from a middleware; if it is missing, the request is rejected with a 500 error.
async fn handler(
    Extension(ctx): Extension<Arc<RequestContext>>
) -> String {
    format!("Request ID: {}", ctx.request_id)
}
```

## Multipart - File Uploads

```rust
// Requires axum's `multipart` feature
use axum::extract::Multipart;

async fn upload(mut multipart: Multipart) -> String {
    while let Some(field) = multipart.next_field().await.unwrap() {
        let name = field.name().unwrap_or("unknown").to_string();
        let data = field.bytes().await.unwrap();
        println!("Field: {}, Size: {} bytes", name, data.len());
    }
    "Upload complete".to_string()
}
```

## Custom Extractors

```rust
use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    response::{IntoResponse, Response},
};

struct AuthUser {
    user_id: u32,
}

#[async_trait]
impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get("Authorization")
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| {
                (StatusCode::UNAUTHORIZED, "Missing auth").into_response()
            })?;

        // Parse the token to obtain the user_id
        let user_id =
            parse_token(auth_header)?;

        Ok(AuthUser { user_id })
    }
}

// Hypothetical helper for illustration only: accepts "Bearer <numeric id>".
// A real application would verify a signed token (for example a JWT) here.
fn parse_token(header: &str) -> Result<u32, Response> {
    header
        .strip_prefix("Bearer ")
        .and_then(|token| token.parse::<u32>().ok())
        .ok_or_else(|| {
            (axum::http::StatusCode::UNAUTHORIZED, "Invalid token").into_response()
        })
}

async fn protected_handler(user: AuthUser) -> String {
    format!("User ID: {}", user.user_id)
}
```

## Extractor Order

Extractors run in the order of the handler's parameters. **Note: only one extractor can consume the request body, and it must be the last parameter.**

```rust
use axum::extract::{Json, Path, Query};
use serde_json::Value;
use std::collections::HashMap;

// Correct: Path and Query come before Json
async fn handler(
    Path(id): Path<u32>,                           // 1. extracted from the URL path
    Query(params): Query<HashMap<String, String>>, // 2. extracted from the query string
    Json(body): Json<Value>,                       // 3. consumes the request body (must be last)
) -> String {
    format!("id: {}, params: {:?}, body: {}", id, params, body)
}

// Wrong: two extractors both try to consume the request body
// async fn bad_handler(
//     Json(body1): Json<Value>,
//     Json(body2): Json<Value>, // compile error!
// ) {}
```

---

# Responses

## The IntoResponse Trait

Every handler's return value must implement `IntoResponse`:

```rust
pub trait IntoResponse {
    fn into_response(self) -> Response;
}
```

## Built-in Response Types

### Json Responses

```rust
use axum::Json;
use serde::Serialize;

#[derive(Serialize)]
struct User {
    id: u32,
    name: String,
}

async fn get_user() -> Json<User> {
    Json(User {
        id: 1,
        name: "Alice".to_string(),
    })
}
```

### Html Responses

```rust
use axum::response::Html;

async fn index() -> Html<&'static str> {
    Html("<h1>Welcome</h1>")
}
```

### Redirect Responses

```rust
use axum::response::Redirect;

// Temporary redirect (307)
async fn redirect_temp() -> Redirect {
    Redirect::temporary("/new-location")
}

// Permanent redirect (308)
async fn redirect_perm() -> Redirect {
    Redirect::permanent("/new-location")
}
```

### StatusCode Responses

```rust
use axum::http::StatusCode;

async fn health_check() -> StatusCode {
    StatusCode::OK
}

async fn not_implemented() -> StatusCode {
    StatusCode::NOT_IMPLEMENTED
}
```

### Tuple Responses

```rust
use axum::{http::StatusCode, Json};

// (StatusCode, Body)
async fn created() -> (StatusCode, Json<User>) {
    let user = User { id: 1, name: "Alice".to_string() };
    (StatusCode::CREATED, Json(user))
}

// (StatusCode, Headers, Body)
async fn with_headers() -> (StatusCode, [(&'static str, &'static str); 1], String) {
    (
        StatusCode::OK,
        [("X-Custom-Header", "value")],
        "Response body".to_string(),
    )
}
```

### Result Responses

```rust
use axum::{extract::Path, http::StatusCode, Json};

async fn maybe_user(Path(id): Path<u32>) -> Result<Json<User>, StatusCode> {
    if id == 0 {
        Err(StatusCode::NOT_FOUND)
    } else {
        Ok(Json(User { id, name: "Alice".to_string() }))
    }
}
```

## Custom Response Types

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde::Serialize;

#[derive(Debug)]
enum AppError {
    NotFound,
    BadRequest(String),
    Internal,
}

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
    message: String,
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // Use owned Strings so every match arm has the same type
        let (status, error_message) = match self {
            AppError::NotFound => (StatusCode::NOT_FOUND, "Not found".to_string()),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
            AppError::Internal => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error".to_string()),
        };

        (
            status,
            Json(ErrorResponse {
                error: status.to_string(),
                message: error_message,
            }),
        )
            .into_response()
    }
}

// Usage
async fn handler() -> Result<Json<User>, AppError> {
    Err(AppError::NotFound)
}
```

## Streaming Responses

```rust
use axum::{
    body::Body,
    http::StatusCode,
    response::IntoResponse,
};
use futures::stream;
use tokio_util::io::ReaderStream;

async fn stream_response() -> impl IntoResponse {
    let stream = stream::iter(vec![
        Ok::<_, std::io::Error>("Hello "),
        Ok("World"),
    ]);
    // axum 0.7 replaced `StreamBody` with `Body::from_stream`
    Body::from_stream(stream)
}

// File download
async fn download_file() -> Result<impl IntoResponse, StatusCode> {
    let file = tokio::fs::File::open("file.txt")
        .await
        .map_err(|_| StatusCode::NOT_FOUND)?;
    let stream = ReaderStream::new(file);
    let body = Body::from_stream(stream);

    Ok((
        [
            ("content-type", "application/octet-stream"),
            ("content-disposition", "attachment; filename=\"file.txt\""),
        ],
        body,
    ))
}
```

## SSE (Server-Sent Events)

```rust
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::stream::{self, Stream};
use std::convert::Infallible;
use std::time::Duration;
// `map` and `throttle` come from tokio-stream (`throttle` needs its `time` feature)
use tokio_stream::StreamExt as _;

async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = stream::repeat_with(|| Event::default().data("heartbeat"))
        .map(Ok)
        .throttle(Duration::from_secs(1));

    Sse::new(stream).keep_alive(KeepAlive::default())
}

// With event types
async fn sse_with_events() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let events = vec![
        Event::default().event("user_joined").data("Alice"),
        Event::default().event("message").data("Hello!"),
    ];

    let stream = stream::iter(events).map(Ok);
    Sse::new(stream)
}
```

---

# Middleware

## Built-in Middleware

The layers below come from the `tower-http` crate; enable the matching Cargo features (for example `trace`, `cors`, `compression-gzip`, `timeout`, `limit`).

### TraceLayer - Request Tracing

```rust
use tower_http::trace::TraceLayer;

let app = Router::new()
    .route("/", get(handler))
    .layer(TraceLayer::new_for_http());
```

### CorsLayer - CORS

```rust
use axum::http::Method;
use tower_http::cors::{Any, CorsLayer};

let cors = CorsLayer::new()
    .allow_origin(Any)
    .allow_methods([Method::GET, Method::POST])
    .allow_headers(Any);

let app = Router::new()
    .route("/api", get(handler))
    .layer(cors);
```

### CompressionLayer - Compression

```rust
use tower_http::compression::CompressionLayer;

let app = Router::new()
    .route("/", get(handler))
    .layer(CompressionLayer::new());
```

### TimeoutLayer - Timeouts

```rust
use std::time::Duration;
use tower_http::timeout::TimeoutLayer;

let app = Router::new()
    .route("/", get(handler))
    .layer(TimeoutLayer::new(Duration::from_secs(30)));
```

### RequestBodyLimitLayer - Request Body Limits

```rust
use tower_http::limit::RequestBodyLimitLayer;

let app = Router::new()
    .route("/upload", post(upload))
    // Disable axum's own 2 MB default for body extractors so the 10 MB limit below applies
    .layer(axum::extract::DefaultBodyLimit::disable())
    .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)); // 10 MB
```

## Writing Middleware with from_fn

```rust
use axum::{
    extract::Request,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};

// axum 0.7: `Request` is `axum::extract::Request` and `Next` is not generic
async fn logging_middleware(request: Request, next: Next) -> Response {
    let method = request.method().clone();
    let uri = request.uri().clone();

    println!("Request: {} {}", method, uri);

    let response = next.run(request).await;

    println!("Response: {}", response.status());

    response
}

let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn(logging_middleware));
```

## Middleware with State

```rust
use axum::{
    extract::{Request, State},
    middleware::{self, Next},
    response::Response,
};
use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Clone)]
struct RequestCounter {
    count: Arc<Mutex<u64>>,
}

async fn counting_middleware(
    State(counter): State<RequestCounter>,
    request: Request,
    next: Next,
) -> Response {
    let mut count = counter.count.lock().await;
    *count += 1;
    println!("Request count: {}", *count);
    drop(count); // release the lock before calling the inner service

    next.run(request).await
}

let counter = RequestCounter {
    count: Arc::new(Mutex::new(0)),
};

let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn_with_state(
        counter.clone(),
        counting_middleware,
    ))
    .with_state(counter);
```

## Combining Middleware with ServiceBuilder

```rust
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{
    compression::CompressionLayer,
    cors::CorsLayer,
    timeout::TimeoutLayer,
    trace::TraceLayer,
};

// ServiceBuilder applies layers top to bottom: TraceLayer is the outermost
let middleware_stack = ServiceBuilder::new()
    .layer(TraceLayer::new_for_http())
    .layer(CorsLayer::permissive())
    .layer(CompressionLayer::new())
    .layer(TimeoutLayer::new(Duration::from_secs(30)));

let app = Router::new()
    .route("/", get(handler))
    .layer(middleware_stack);
```

## Middleware Execution Order

With `Router::layer`, each call wraps everything added before it, so the layer added **last** is the outermost one: it sees the request first and the response last. `ServiceBuilder` works the other way round: its first `.layer()` is the outermost.

```rust
use std::time::Duration;
use tower_http::{
    catch_panic::CatchPanicLayer,
    compression::CompressionLayer,
    cors::CorsLayer,
    timeout::TimeoutLayer,
    trace::TraceLayer,
};

// Execution order:
// Request:  CatchPanic -> Trace -> Cors -> Compression -> Timeout -> Handler
// Response: Handler -> Timeout -> Compression -> Cors -> Trace -> CatchPanic
let app = Router::new()
    .route("/", get(handler))
    .layer(TimeoutLayer::new(Duration::from_secs(30))) // innermost: handles the request last
    .layer(CompressionLayer::new())
    .layer(CorsLayer::permissive())
    .layer(TraceLayer::new_for_http())
    .layer(CatchPanicLayer::new()); // outermost: handles the request first
```

---

# State Management

## Basic State

```rust
use axum::{extract::State, routing::get, Router};

#[derive(Clone)]
struct AppState {
    db_url: String,
}

async fn handler(State(state): State<AppState>) -> String {
    format!("DB: {}", state.db_url)
}

let state = AppState {
    db_url: "postgresql://localhost/mydb".to_string(),
};

let app = Router::new()
    .route("/", get(handler))
    .with_state(state);
```

## Wrapping State in Arc

Wrap large data structures, and state that needs interior mutability, in `Arc`:

```rust
use axum::extract::State;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

// `DatabasePool` and `Config` are placeholder types for illustration
#[derive(Clone)]
struct AppState {
    // Database connection pool
    db: Arc<DatabasePool>,
    // Cache
    cache: Arc<RwLock<HashMap<String, String>>>,
    // Configuration
    config: Arc<Config>,
}

async fn handler(State(state): State<AppState>) -> String {
    // Read access
    let cache = state.cache.read().await;
    if let Some(value) = cache.get("key") {
        return value.clone();
    }
    drop(cache); // release the read lock before taking the write lock

    // Write access
    let mut cache = state.cache.write().await;
    cache.insert("key".to_string(), "value".to_string());

    "cached".to_string()
}
```

## State vs Extension

| Feature | State | Extension |
|------|-------|-----------|
| Recommended for app state | ✅ Yes | ⚠️ Still supported, but `State` is preferred |
| Type safety | Checked at compile time | Checked at runtime (a missing value yields a 500 response) |
| Performance | Slightly cheaper (no per-request lookup) | Type-map lookup on every request |
| Usage | `with_state()` | `Extension(...)` layer |

---

# Error Handling

## Defining an Application Error Type

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde::Serialize;

#[derive(Debug)]
pub enum AppError {
    NotFound(String),
    BadRequest(String),
    Unauthorized,
    Internal(String),
}

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
    message: String,
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, message) = match self {
            AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
            AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
            AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
            AppError::Internal(msg) => {
                // Log the details on the server; send only a generic message to the client
                eprintln!("Internal error: {}", msg);
                (StatusCode::INTERNAL_SERVER_ERROR, "Internal error".to_string())
            }
        };

        (
            status,
            Json(ErrorResponse {
                error: status.to_string(),
                message,
            }),
        )
            .into_response()
    }
}
```

## Using thiserror

```rust
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
};
use thiserror::Error;

#[derive(Error, Debug)]
pub enum AppError {
    #[error("Not found: {0}")]
    NotFound(String),

    #[error("Bad request: {0}")]
    BadRequest(String),

    #[error("Database error: {0}")]
    Database(#[from] sqlx::Error),

    #[error("Internal error")]
    Internal,
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let status = match &self {
            AppError::NotFound(_) => StatusCode::NOT_FOUND,
            AppError::BadRequest(_) => StatusCode::BAD_REQUEST,
            AppError::Database(_) | AppError::Internal => StatusCode::INTERNAL_SERVER_ERROR,
        };

        // Note: this sends the Display text, including database error details, to the client
        (status, self.to_string()).into_response()
    }
}
```

## Using It in Handlers

```rust
pub type Result<T> = std::result::Result<T, AppError>;

// `find_user` and `insert_user` are application-defined functions
async fn get_user(Path(id): Path<u32>) -> Result<Json<User>> {
    let user = find_user(id)
        .await
        .ok_or_else(|| AppError::NotFound(format!("User {} not found", id)))?;

    Ok(Json(user))
}

async fn create_user(Json(req): Json<CreateUser>) -> Result<(StatusCode, Json<User>)> {
    if req.email.is_empty() {
        return Err(AppError::BadRequest("Email is required".to_string()));
    }

    let user = insert_user(req).await?; // sqlx::Error is converted via #[from]
    Ok((StatusCode::CREATED, Json(user)))
}
```

---

# WebSocket

## Basic WebSocket Handling

```rust
// Requires axum's `ws` feature
use axum::{
    extract::ws::{Message, WebSocket, WebSocketUpgrade},
    response::IntoResponse,
    routing::get,
    Router,
};

async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse {
    ws.on_upgrade(handle_socket)
}

async fn handle_socket(mut socket: WebSocket) {
    while let Some(msg) = socket.recv().await {
        match msg {
            Ok(Message::Text(text)) => {
                println!("Received: {}", text);
                let _ = socket.send(Message::Text(format!("Echo: {}", text))).await;
            }
            Ok(Message::Binary(data)) => {
                let _ = socket.send(Message::Binary(data)).await;
            }
            Ok(Message::Ping(data)) => {
                // Pings are also answered automatically; this explicit Pong is optional
                let _ = socket.send(Message::Pong(data)).await;
            }
            Ok(Message::Close(_)) => {
                println!("Client disconnected");
                break;
            }
            Err(e) => {
                eprintln!("Error: {}", e);
                break;
            }
            _ => {}
        }
    }
}

let app = Router::new()
    .route("/ws", get(ws_handler));
```

## Splitting Send and Receive

```rust
use axum::extract::ws::{Message, WebSocket};
use futures::{SinkExt, StreamExt};
use std::time::Duration;

async fn handle_socket(socket: WebSocket) {
    let (mut tx, mut rx) = socket.split();

    // Receive task
    let recv_task = tokio::spawn(async move {
        while let Some(msg) = rx.next().await {
            match msg {
                Ok(Message::Text(text)) => {
                    println!("Received: {}", text);
                }
                Ok(Message::Close(_)) => break,
                _ => {}
            }
        }
    });

    // Send task
    let send_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(1));
        for i in 0..10 {
            interval.tick().await;
            if tx.send(Message::Text(format!("Tick {}", i))).await.is_err() {
                break;
            }
        }
    });

    let _ = tokio::join!(recv_task, send_task);
}
```

## Broadcast Pattern

```rust
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        State,
    },
    response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::broadcast;

#[derive(Clone)]
struct ChatState {
    tx: broadcast::Sender<String>,
}

async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<ChatState>>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| handle_chat(socket, state))
}

async fn handle_chat(socket: WebSocket, state: Arc<ChatState>) {
    let (mut tx, mut rx) = socket.split();
    let mut broadcast_rx = state.tx.subscribe();

    loop {
        tokio::select! {
            msg = rx.next() => {
                match msg {
                    Some(Ok(Message::Text(text))) => {
                        let _ = state.tx.send(text);
                    }
                    // Stop on close frames, errors, or when the stream ends (None)
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    _ => {}
                }
            }
            result = broadcast_rx.recv() => {
                if let Ok(msg) = result {
                    if tx.send(Message::Text(msg)).await.is_err() {
                        break;
                    }
                }
            }
        }
    }
}
```

## Heartbeats

```rust
use axum::extract::ws::{Message, WebSocket};
use std::time::{Duration, Instant};

async fn handle_socket_with_heartbeat(mut socket: WebSocket) {
    let mut last_pong = Instant::now();
    let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(30));

    loop {
        tokio::select! {
            _ = heartbeat_interval.tick() => {
                if last_pong.elapsed() > Duration::from_secs(90) {
                    println!("Client timeout");
                    break;
                }
                if socket.send(Message::Ping(vec![])).await.is_err() {
                    break;
                }
            }
            msg = socket.recv() => {
                match msg {
                    Some(Ok(Message::Pong(_))) => {
                        last_pong = Instant::now();
                    }
                    Some(Ok(Message::Text(_text))) => {
                        // Handle the message
                    }
                    // Stop on close frames, errors, or when the connection ends (None)
                    Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
                    _ => {}
                }
            }
        }
    }
}
```

---

# Testing

## Using tower::ServiceExt

```rust
use axum::{
    body::Body,
    http::{Request, StatusCode},
    routing::get,
    Router,
};
use tower::ServiceExt; // provides `oneshot` (tower's `util` feature)

#[tokio::test]
async fn test_handler() {
    let app = Router::new()
        .route("/", get(|| async { "Hello" }));

    let request = Request::builder()
        .uri("/")
        .body(Body::empty())
        .unwrap();

    let response = app
        .oneshot(request)
        .await
        .unwrap();

    assert_eq!(response.status(), StatusCode::OK);
}
```

## Using axum-test

```rust
use axum::http::StatusCode;
use axum_test::TestServer;

#[tokio::test]
async fn test_with_axum_test() {
    let app = create_app(); // application-defined function that builds the Router
    let server = TestServer::new(app).unwrap();

    // GET request
    let response = server.get("/users").await;
    response.assert_status_ok();

    // POST request
    let response = server
        .post("/users")
        .json(&serde_json::json!({
            "name": "Alice",
            "email": "alice@example.com"
        }))
        .await;
    response.assert_status(StatusCode::CREATED);

    // With a header
    let response = server
        .get("/protected")
        .add_header("Authorization", "Bearer token")
        .await;
    response.assert_status_ok();
}
```

## Testing Routes with State

```rust
#[tokio::test]
async fn test_with_state() {
    // `AppState`, `MockDatabase`, and `list_users` are application-defined
    let state = AppState {
        db: Arc::new(MockDatabase::new()),
    };

    let app = Router::new()
        .route("/users", get(list_users))
        .with_state(state);

    let server = TestServer::new(app).unwrap();
    let response = server.get("/users").await;
    response.assert_status_ok();
}
```

---

# Project Structure Best Practices

## Recommended Directory Layout

```
project/
├── src/
│   ├── main.rs           # Entry point
│   ├── config.rs         # Configuration management
│   ├── error.rs          # Error type definitions
│   ├── state.rs          # Application state
│   │
│   ├── routes/
│   │   ├── mod.rs        # Route exports
│   │   ├── users.rs      # User routes
│   │   └── posts.rs      # Post routes
│   │
│   ├── handlers/
│   │   ├── mod.rs
│   │   └── users.rs
│   │
│   ├── models/
│   │   ├── mod.rs
│   │   └── user.rs
│   │
│   ├── services/
│   │   ├── mod.rs
│   │   └── user_service.rs
│   │
│   └── middleware/
│       ├── mod.rs
│       └── auth.rs
│
├── tests/
│   └── integration_test.rs
│
└── Cargo.toml
```

## Modular Routes

```rust
// src/routes/mod.rs
use axum::Router;
use crate::state::AppState;

mod users;
mod posts;

pub fn create_routes() -> Router<AppState> {
    Router::new()
        .nest("/api/users", users::routes())
        .nest("/api/posts", posts::routes())
}

// src/routes/users.rs
use axum::{Router, routing::get};
use crate::handlers::users;
use crate::state::AppState;

pub fn routes() -> Router<AppState> {
    Router::new()
        .route("/", get(users::list).post(users::create))
        .route("/:id", get(users::get).put(users::update).delete(users::delete))
}
```

---

# Security Best Practices

## JWT Authentication

```rust
use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,
    exp: i64,
}

struct AuthUser(Claims);

#[async_trait]
impl<S> FromRequestParts<S> for AuthUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get("Authorization")
            .and_then(|h| h.to_str().ok())
            .ok_or((StatusCode::UNAUTHORIZED, "Missing auth".to_string()))?;

        let token = auth_header
            .strip_prefix("Bearer ")
            .ok_or((StatusCode::UNAUTHORIZED, "Invalid format".to_string()))?;

        // Load the secret from configuration in production instead of hard-coding it
        let claims = decode::<Claims>(
            token,
            &DecodingKey::from_secret(b"secret"),
            &Validation::default(),
        )
        .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".to_string()))?
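        .claims;

        Ok(AuthUser(claims))
    }
}

async fn protected_handler(AuthUser(claims): AuthUser) -> String {
    format!("Hello, {}", claims.sub)
}
```

## Input Validation

```rust
use axum::{http::StatusCode, Json};
use serde::Deserialize;
use validator::Validate; // validator with its `derive` feature

#[derive(Deserialize, Validate)]
struct CreateUserRequest {
    #[validate(length(min = 1, max = 100))]
    name: String,

    #[validate(email)]
    email: String,

    #[validate(length(min = 8))]
    password: String,
}

async fn create_user(
    Json(req): Json<CreateUserRequest>
) -> Result<StatusCode, AppError> {
    req.validate()
        .map_err(|e| AppError::BadRequest(e.to_string()))?;

    // Process the request...
    Ok(StatusCode::CREATED)
}
```

## Rate Limiting

```rust
use axum::{
    extract::{Request, State},
    http::StatusCode,
    middleware::{self, Next},
    response::Response,
};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use std::{num::NonZeroU32, sync::Arc};

async fn rate_limit_middleware(
    State(limiter): State<Arc<DefaultDirectRateLimiter>>,
    req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Check the rate limit
    if limiter.check().is_err() {
        return Err(StatusCode::TOO_MANY_REQUESTS);
    }
    Ok(next.run(req).await)
}

// One global limiter; 100 requests per second is an example value
let limiter = Arc::new(RateLimiter::direct(Quota::per_second(NonZeroU32::new(100).unwrap())));

let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn_with_state(limiter, rate_limit_middleware));
```

## Security Headers

```rust
use axum::http::{header, HeaderValue};
use tower_http::set_header::SetResponseHeaderLayer;

let app = Router::new()
    .route("/", get(handler))
    .layer(SetResponseHeaderLayer::if_not_present(
        header::X_CONTENT_TYPE_OPTIONS,
        HeaderValue::from_static("nosniff"),
    ))
    .layer(SetResponseHeaderLayer::if_not_present(
        header::X_FRAME_OPTIONS,
        HeaderValue::from_static("DENY"),
    ))
    .layer(SetResponseHeaderLayer::if_not_present(
        header::STRICT_TRANSPORT_SECURITY,
        HeaderValue::from_static("max-age=31536000; includeSubDomains"),
    ));
```

---

# Performance Optimization

## Async Best Practices

```rust
use std::time::Duration;

// Bad: a blocking call inside an async handler
// (if blocking work is unavoidable, run it with tokio::task::spawn_blocking)
async fn bad_handler() -> String {
    std::thread::sleep(Duration::from_secs(1)); // blocks the runtime worker thread!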
    "OK".to_string()
}

// Good: an asynchronous operation
async fn good_handler() -> String {
    tokio::time::sleep(Duration::from_secs(1)).await;
    "OK".to_string()
}

// Concurrent execution
// (`get_user`, `get_posts`, and `get_comments` are application-defined async functions)
async fn concurrent_handler() -> String {
    let (user, posts, comments) = tokio::join!(
        get_user(),
        get_posts(),
        get_comments(),
    );

    format!("{:?}, {:?}, {:?}", user, posts, comments)
}
```

## Connection Pool Configuration

```rust
use sqlx::postgres::PgPoolOptions;
use std::time::Duration;

let pool = PgPoolOptions::new()
    .min_connections(5)
    .max_connections(20)
    .acquire_timeout(Duration::from_secs(30)) // formerly `connect_timeout`
    .idle_timeout(Duration::from_secs(600))
    .connect(&database_url)
    .await?;
```

## Graceful Shutdown

```rust
use tokio::signal;

#[tokio::main]
async fn main() {
    let app = create_app(); // application-defined function that builds the Router

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    axum::serve(listener, app)
        .with_graceful_shutdown(async {
            signal::ctrl_c().await.expect("Failed to listen for Ctrl+C");
            println!("Shutting down...");
        })
        .await
        .unwrap();
}
```

---

# Complete Example

## RESTful API

```rust
use axum::{
    extract::{Json, Path, State},
    http::StatusCode,
    routing::get,
    Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

// Data models
#[derive(Clone, Serialize, Deserialize)]
struct User {
    id: u32,
    name: String,
    email: String,
}

#[derive(Deserialize)]
struct CreateUser {
    name: String,
    email: String,
}

// Application state
#[derive(Clone)]
struct AppState {
    users: Arc<RwLock<HashMap<u32, User>>>,
    next_id: Arc<RwLock<u32>>,
}

// Handlers
async fn list_users(State(state): State<AppState>) -> Json<Vec<User>> {
    let users = state.users.read().await;
    Json(users.values().cloned().collect())
}

async fn create_user(
    State(state): State<AppState>,
    Json(req): Json<CreateUser>,
) -> (StatusCode, Json<User>) {
    let mut users = state.users.write().await;
    let mut next_id = state.next_id.write().await;

    let id = *next_id;
    *next_id += 1;

    let user = User {
        id,
        name: req.name,
        email: req.email,
    };

    users.insert(id, user.clone());
    (StatusCode::CREATED, Json(user))
}

async fn get_user(
    State(state): State<AppState>,
    Path(id): Path<u32>,
) -> Result<Json<User>, StatusCode> {
    let users =
        state.users.read().await;
    users
        .get(&id)
        .cloned()
        .map(Json)
        .ok_or(StatusCode::NOT_FOUND)
}

async fn update_user(
    State(state): State<AppState>,
    Path(id): Path<u32>,
    Json(req): Json<CreateUser>,
) -> Result<Json<User>, StatusCode> {
    let mut users = state.users.write().await;

    if let Some(user) = users.get_mut(&id) {
        user.name = req.name;
        user.email = req.email;
        Ok(Json(user.clone()))
    } else {
        Err(StatusCode::NOT_FOUND)
    }
}

async fn delete_user(
    State(state): State<AppState>,
    Path(id): Path<u32>,
) -> StatusCode {
    let mut users = state.users.write().await;

    if users.remove(&id).is_some() {
        StatusCode::NO_CONTENT
    } else {
        StatusCode::NOT_FOUND
    }
}

#[tokio::main]
async fn main() {
    let state = AppState {
        users: Arc::new(RwLock::new(HashMap::new())),
        next_id: Arc::new(RwLock::new(1)),
    };

    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/users/:id", get(get_user).put(update_user).delete(delete_user))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();

    println!("Server running on http://127.0.0.1:3000");
    axum::serve(listener, app).await.unwrap();
}
```

---

# axum-extra Extensions

## TypedHeader - Type-Safe Headers

```rust
use axum_extra::TypedHeader;
use headers::{authorization::Bearer, Authorization, UserAgent};

async fn handler(
    TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
    TypedHeader(user_agent): TypedHeader<UserAgent>,
) -> String {
    let token = auth.token();
    format!("Token: {}, UA: {}", token, user_agent.as_str())
}

// Cargo.toml:
// axum-extra = { version = "0.9", features = ["typed-header"] }
// headers = "0.4"
```

## CookieJar - Cookie Management

```rust
use axum_extra::extract::cookie::{Cookie, CookieJar};

// Read a cookie
async fn read_cookies(jar: CookieJar) -> String {
    let session = jar
        .get("session_id")
        .map(|c| c.value().to_string())
        .unwrap_or_default();
    format!("Session: {}", session)
}

// Set a cookie (the updated jar must be returned for the change to take effect)
async fn set_cookie(jar: CookieJar) -> (CookieJar, &'static str) {
    let updated_jar = jar.add(Cookie::new("session_id", "abc123"));
    (updated_jar, "Cookie set")
}

// Remove a cookie
async fn remove_cookie(jar: CookieJar) -> (CookieJar, &'static str) {
    let updated_jar = jar.remove(Cookie::from("session_id"));
    (updated_jar, "Cookie removed")
}

// Cargo.toml:
// axum-extra = { version = "0.9", features = ["cookie"] }
```

## PrivateCookieJar - Encrypted Cookies

```rust
use axum::extract::FromRef;
use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar};

#[derive(Clone)]
struct AppState {
    // For example `Key::generate()`, or a key loaded from configuration
    key: Key,
}

// PrivateCookieJar obtains the Key from the router state through FromRef
impl FromRef<AppState> for Key {
    fn from_ref(state: &AppState) -> Self {
        state.key.clone()
    }
}

async fn handler(jar: PrivateCookieJar) -> (PrivateCookieJar, String) {
    // Cookie values are encrypted and decrypted automatically
    let user_id = jar
        .get("user_id")
        .map(|c| c.value().to_string())
        .unwrap_or_default();

    let jar = jar.add(Cookie::new("user_id", "12345"));
    (jar, format!("User: {}", user_id))
}

// Cargo.toml:
// axum-extra = { version = "0.9", features = ["cookie-private"] }
```

---

# Serving Static Files

## ServeFile - Serving a Single File

```rust
use axum::Router;
use tower_http::services::ServeFile;

let app = Router::new()
    // /favicon.ico serves static/favicon.ico
    .route_service("/favicon.ico", ServeFile::new("static/favicon.ico"))
    // /robots.txt serves robots.txt
    .route_service("/robots.txt", ServeFile::new("robots.txt"));
```

## ServeDir - Serving a Directory

```rust
use axum::{routing::get, Router};
use tower_http::services::{ServeDir, ServeFile};

let app = Router::new()
    // /static/* maps to the ./static directory
    .nest_service("/static", ServeDir::new("static"))
    // Other routes
    .route("/api/users", get(list_users));

// Static file serving with a fallback page
let serve_dir = ServeDir::new("static")
    .not_found_service(ServeFile::new("static/404.html"));

// Serve files for every path that no other route matches
let app = Router::new()
    .fallback_service(serve_dir);
```

## Serving a SPA (Single-Page Application)

```rust
use axum::{routing::get, Router};
use tower_http::services::{ServeDir, ServeFile};

// SPA mode: every unmatched path returns index.html
let app = Router::new()
    .route("/api/users", get(list_users))
    .fallback_service(
        ServeDir::new("dist")
            .not_found_service(ServeFile::new("dist/index.html")),
    );

// Cargo.toml:
// tower-http = { version = "0.5", features = ["fs"] }
```

---

# Advanced Extractors

## ConnectInfo - Connection Info

```rust
use axum::{extract::ConnectInfo, routing::get, Router};
use std::net::SocketAddr;

async fn handler(
    ConnectInfo(addr): ConnectInfo<SocketAddr>
) -> String {
    format!("Client IP: {}", addr)
}

// Connection info must be enabled when calling serve
let app = Router::new().route("/", get(handler));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(
    listener,
    app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
```

## OriginalUri - Original URI

```rust
use axum::extract::OriginalUri;

async fn handler(OriginalUri(uri): OriginalUri) -> String {
    // Gets the full original URI inside a nested router
    format!("Original URI: {}", uri)
}

// Nested router example
let api = Router::new()
    .route("/users", get(handler));

let app = Router::new()
    .nest("/api/v1", api);

// When requesting /api/v1/users:
// Uri (regular extractor) = /users
// OriginalUri             = /api/v1/users
```

## MatchedPath - Matched Route Pattern

```rust
use axum::extract::MatchedPath;

// MatchedPath's inner field is private, so use `as_str()` instead of destructuring
async fn handler(path: MatchedPath) -> String {
    // Returns the route pattern, not the actual request path
    format!("Matched pattern: {}", path.as_str())
}

// GET /users/123 -> "Matched pattern: /users/:id"
let app = Router::new()
    .route("/users/:id", get(handler));
```

## Host - Host Name

```rust
use axum::extract::Host;

async fn handler(Host(host): Host) -> String {
    format!("Host: {}", host)
}

// Routing by domain name
async fn virtual_host(Host(host): Host) -> String {
    match host.as_str() {
        "api.example.com" => "API endpoint".to_string(),
        "www.example.com" => "Web endpoint".to_string(),
        _ => "Unknown host".to_string(),
    }
}
```

## RawQuery - Raw Query String

```rust
use axum::extract::RawQuery;

async fn handler(RawQuery(query): RawQuery) -> String {
    match query {
        Some(q) => format!("Query string: {}", q),
        None => "No query string".to_string(),
    }
}

// GET /search?foo=bar&baz=qux -> "Query string: foo=bar&baz=qux"
```

---

# The debug_handler Macro

`#[debug_handler]` produces friendlier compiler error messages for handlers.

## Basic Usage

```rust
use axum::debug_handler; // requires axum's `macros` feature

#[debug_handler]
async fn handler() -> &'static str {
    "Hello"
}

// If the handler signature is invalid, the compiler reports a much clearer error
```

## debug_handler with State

```rust
use axum::{debug_handler, extract::State};

#[derive(Clone)]
struct AppState {
    inner: InnerState,
}

#[derive(Clone)]
struct InnerState {}

impl axum::extract::FromRef<AppState> for InnerState {
    fn from_ref(state: &AppState) -> Self {
        state.inner.clone()
    }
}

// Specify the state type explicitly
#[debug_handler(state = AppState)]
async fn handler(
    State(_app_state): State<AppState>,
    State(_inner_state): State<InnerState>,
) -> &'static str {
    "Hello"
}

// Cargo.toml (either option):
// axum = { version = "0.7", features = ["macros"] }  -> use axum::debug_handler
// axum-macros = "0.4"                                -> use axum_macros::debug_handler
```

---

# Nested state with FromRef

When the application state is made up of several sub-states, implement `FromRef` so that each handler can extract only the sub-state it needs.

## Implementing FromRef manually

```rust
use axum::extract::{FromRef, State};
use axum::{routing::get, Router};

#[derive(Clone)]
struct AppState {
    db: DatabasePool,
    cache: CacheClient,
}

#[derive(Clone)]
struct DatabasePool { /* ... */ }

#[derive(Clone)]
struct CacheClient { /* ... */ }

// Implement FromRef for each sub-state
impl FromRef<AppState> for DatabasePool {
    fn from_ref(state: &AppState) -> Self {
        state.db.clone()
    }
}

impl FromRef<AppState> for CacheClient {
    fn from_ref(state: &AppState) -> Self {
        state.cache.clone()
    }
}

// A handler can extract just the sub-state it needs
async fn db_handler(State(_db): State<DatabasePool>) -> String {
    "Using database".to_string()
}

async fn cache_handler(State(_cache): State<CacheClient>) -> String {
    "Using cache".to_string()
}

async fn full_handler(State(_state): State<AppState>) -> String {
    "Using full state".to_string()
}

let state = AppState {
    db: DatabasePool {},
    cache: CacheClient {},
};

let app = Router::new()
    .route("/db", get(db_handler))
    .route("/cache", get(cache_handler))
    .route("/full", get(full_handler))
    .with_state(state);
```

## Deriving FromRef automatically

```rust
use axum::extract::FromRef;

#[derive(Clone, FromRef)]
struct AppState {
    db: DatabasePool,
    cache: CacheClient,
    #[from_ref(skip)] // do not generate a FromRef impl for this field
    secret_key: String,
}

// DatabasePool and CacheClient can now be extracted from AppState automatically

// Cargo.toml:
// axum = { version = "0.7", features = ["macros"] }
```

## Different states for nested routers

```rust
use axum::{extract::State, routing::get, Router};

#[derive(Clone)]
struct InnerState {
    value: String,
}

#[derive(Clone)]
struct OuterState {
    name: String,
}

async fn inner_handler(State(state): State<InnerState>) -> String {
    state.value
}

async fn outer_handler(State(state): State<OuterState>) -> String {
    state.name
}

// The inner router has its own state
let inner_router = Router::new()
    .route("/", get(inner_handler))
    .with_state(InnerState {
        value: "inner".to_string(),
    });

// The outer router uses a different state. After `with_state` has been called,
// the inner router no longer needs InnerState, so it can be nested here.
let app = Router::new()
    .route("/", get(outer_handler))
    .nest("/inner", inner_router)
    .with_state(OuterState {
        name: "outer".to_string(),
    });
```

---

# Resources

- Official documentation: https://docs.rs/axum/latest/axum/
- GitHub: https://github.com/tokio-rs/axum
- Tower: https://docs.rs/tower/latest/tower/
- Tower HTTP: https://docs.rs/tower-http/latest/tower_http/
- Tokio: https://tokio.rs/
- Axum Examples: https://github.com/tokio-rs/axum/tree/main/examples