Rust Axum Middleware Authentication: 6 Production Patterns from JWT to RBAC

Why Is Authentication in Rust Web Services Always a Minefield

You deployed an Axum endpoint with zero auth; you added JWT validation, only to find expired tokens give users raw 401s; you wanted RBAC, but the middleware can't access user roles; you added rate limiting, and unauthenticated requests are burning through your quota. In 2026, Axum 0.8 provides FromRequestParts extractors, Tower Layer middleware, and type-safe state management — but authentication is never something the framework can do for you.

This article starts from JWT validation and walks you through JWT auth → API Key auth → RBAC → Rate limiting → Session management → Production auth service — 6 production patterns to take Axum auth from "it runs" to "it withstands".


Core Concepts

FromRequestParts: Axum extractor trait; pulls auth info out of the request Parts (headers, URI, extensions) without touching the body
JWT (JSON Web Token): Stateless, signed token carrying the user's identity and expiration time
API Key: Secret key passed in a request header; suited to service-to-service calls
RBAC: Role-Based Access Control, a three-level user → role → permission model
Tower Layer: Middleware abstraction for composing cross-cutting concerns
Rate Limiting: Throttling that prevents API abuse
Session: Stateful login state stored on the server (Redis, etc.)
Claims: Fields in the JWT payload (sub, exp, role, etc.)

Auth Request Flow

1. Client sends a request with an Authorization header or an API key
2. Middleware or an extractor pulls out the credentials
3. Credentials are validated (signature verification, expiry check, key match)
4. A user context is built (UserId/Role/Permissions)
5. RBAC middleware checks whether the user may access the current route
6. Rate-limiting middleware checks request frequency
7. The handler runs the business logic and reads the user context through an extractor or request extensions (State holds app-wide shared data, not per-request identity)
8. The response is returned to the client
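The eight steps can be modeled without any framework to show why every stage returns early on failure. This is a hypothetical, std-only sketch: the fixed token table in `validate` stands in for real JWT verification, and `Reject`, `UserContext` and `handle` are illustrative names, not Axum APIs. Each stage returns a `Result`, so `?` stops the pipeline at the first failure, much as an Axum extractor or middleware does by returning its rejection.

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum Reject {
    Unauthenticated,
    Forbidden,
    RateLimited,
}

struct UserContext {
    user_id: String,
    role: &'static str,
}

// Steps 1-2: pull the bearer credential out of the headers.
fn extract(headers: &HashMap<&str, &str>) -> Result<String, Reject> {
    headers
        .get("authorization")
        .and_then(|v| v.strip_prefix("Bearer "))
        .map(str::to_string)
        .ok_or(Reject::Unauthenticated)
}

// Steps 3-4: validate it and build the user context (hypothetical token table).
fn validate(token: &str) -> Result<UserContext, Reject> {
    match token {
        "admin-token" => Ok(UserContext { user_id: "u1".into(), role: "admin" }),
        "viewer-token" => Ok(UserContext { user_id: "u2".into(), role: "viewer" }),
        _ => Err(Reject::Unauthenticated),
    }
}

// Step 5: route-level RBAC.
fn authorize(ctx: &UserContext, route: &str) -> Result<(), Reject> {
    if route.starts_with("/admin") && ctx.role != "admin" {
        return Err(Reject::Forbidden);
    }
    Ok(())
}

// Step 6: per-user request counter with a fixed limit (no time window here).
fn rate_limit(ctx: &UserContext, counts: &mut HashMap<String, u32>, limit: u32) -> Result<(), Reject> {
    let n = counts.entry(ctx.user_id.clone()).or_insert(0);
    *n += 1;
    if *n > limit { Err(Reject::RateLimited) } else { Ok(()) }
}

// Steps 7-8: the handler body only runs with an authorized context.
fn handle(headers: &HashMap<&str, &str>, route: &str, counts: &mut HashMap<String, u32>) -> Result<String, Reject> {
    let token = extract(headers)?;
    let ctx = validate(&token)?;
    authorize(&ctx, route)?;
    rate_limit(&ctx, counts, 2)?;
    Ok(format!("{route} served for {}", ctx.user_id))
}
```

In this model a forbidden request never reaches the rate limiter, mirroring the order of steps 5 and 6; swapping the two calls changes which requests consume quota.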

Problem Analysis: 5 Major Challenges in Axum Auth Development

  1. JWT validation disconnected from user context: Middleware validates the token, but Handlers can't access user info — you end up parsing it again
  2. Multiple auth methods hard to coexist: JWT and API Key both need support, middleware becomes if-else spaghetti
  3. Chaotic RBAC permission models: Roles and permissions are hardcoded as strings, so adding a permission means changing dozens of places
  4. Rate limiting vs auth ordering conflict: Rate limiting before auth wastes quota on unauthenticated requests; after auth, malicious requests hit the auth layer directly
  5. Session management lacks production solutions: JWT is stateless, so a token cannot be revoked before it expires without extra server-side state; sessions are stateful, but Redis connection reuse and expiry cleanup are easy to get wrong

Step-by-Step: 6 Production Auth Patterns

Pattern 1: JWT Authentication Middleware with FromRequestParts

use axum::extract::{FromRequestParts, State};
use axum::http::request::Parts;
use axum::{routing::get, Json, Router};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    pub sub: String,
    pub role: String,
    pub exp: i64,
    pub iat: i64,
}

#[derive(Clone)]
pub struct AuthConfig {
    pub jwt_secret: String,
    pub jwt_expiration_hours: i64,
}

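// Minimal state for this pattern; DbPool stands in for your database pool type.
// Later patterns read extra fields (api_key_state, session_store, ...); the full
// AppState is defined in Pattern 6.
pub struct AppState {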
    pub auth_config: AuthConfig,
    pub db: DbPool,
}

pub type SharedState = Arc<AppState>;

impl Claims {
    pub fn new(user_id: &str, role: &str, expiration_hours: i64) -> Self {
        let now = Utc::now();
        Self {
            sub: user_id.to_string(),
            role: role.to_string(),
            iat: now.timestamp(),
            exp: (now + Duration::hours(expiration_hours)).timestamp(),
        }
    }

    pub fn encode(&self, secret: &str) -> Result<String, jsonwebtoken::errors::Error> {
        encode(
            &Header::default(),
            self,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
    }

    pub fn decode(token: &str, secret: &str) -> Result<Self, jsonwebtoken::errors::Error> {
        let token_data = decode::<Claims>(
            token,
            &DecodingKey::from_secret(secret.as_bytes()),
            &Validation::default(),
        )?;
        Ok(token_data.claims)
    }
}

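// Clone is required to store AuthUser in request extensions or in a cache.
#[derive(Debug, Clone)]
pub struct AuthUser {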
    pub user_id: String,
    pub role: String,
}

// Axum 0.8 uses native async fn in traits, so no #[async_trait] attribute is needed.
impl FromRequestParts<SharedState> for AuthUser {
    type Rejection = AuthError;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &SharedState,
    ) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get("authorization")
            .and_then(|v| v.to_str().ok())
            .ok_or(AuthError::MissingToken)?;

        let token = auth_header
            .strip_prefix("Bearer ")
            .ok_or(AuthError::InvalidTokenFormat)?;

        let claims = Claims::decode(token, &state.auth_config.jwt_secret)
            .map_err(|_| AuthError::InvalidToken)?;

        Ok(AuthUser {
            user_id: claims.sub,
            role: claims.role,
        })
    }
}

async fn protected_handler(
    auth_user: AuthUser,
) -> Result<Json<serde_json::Value>, AuthError> {
    Ok(Json(serde_json::json!({
        "user_id": auth_user.user_id,
        "role": auth_user.role,
    })))
}

pub fn auth_router() -> Router<SharedState> {
    Router::new()
        .route("/me", get(protected_handler))
        .route("/refresh", get(refresh_token))
}

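// Re-issues a token from any still-valid token, so a stolen token can be renewed
// indefinitely; production systems usually issue a separate, revocable refresh token.
async fn refresh_token(
    auth_user: AuthUser,
    State(state): State<SharedState>,
) -> Result<Json<serde_json::Value>, AuthError> {
    let claims = Claims::new(
        &auth_user.user_id,
        &auth_user.role,
        state.auth_config.jwt_expiration_hours,
    );
    let token = claims
        .encode(&state.auth_config.jwt_secret)
        .map_err(|_| AuthError::InternalServerError("Failed to encode token".into()))?;
    Ok(Json(serde_json::json!({ "token": token })))
}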
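Every pattern returns `AuthError`, but its definition is not shown. Below is one possible shape, assuming the variant names used throughout this article; the status codes and client-facing messages are choices made for this sketch. It is std-only so it can be checked in isolation, and the `Display` text deliberately hides internal details such as Redis errors.

```rust
use std::fmt;

// Variants match the ones referenced in Patterns 1-6; codes and messages are assumptions.
#[derive(Debug)]
pub enum AuthError {
    MissingToken,
    InvalidTokenFormat,
    InvalidToken,
    MissingApiKey,
    InvalidApiKey,
    NoAuthMethod,
    InvalidRole,
    InvalidCredentials,
    MissingSession,
    SessionExpired,
    Forbidden(String),
    RateLimited,
    InternalServerError(String),
}

impl AuthError {
    /// HTTP status code this error maps to.
    pub fn status(&self) -> u16 {
        match self {
            AuthError::MissingToken
            | AuthError::InvalidTokenFormat
            | AuthError::InvalidToken
            | AuthError::MissingApiKey
            | AuthError::InvalidApiKey
            | AuthError::NoAuthMethod
            | AuthError::InvalidCredentials
            | AuthError::MissingSession
            | AuthError::SessionExpired => 401,
            AuthError::InvalidRole | AuthError::Forbidden(_) => 403,
            AuthError::RateLimited => 429,
            AuthError::InternalServerError(_) => 500,
        }
    }
}

impl fmt::Display for AuthError {
    // Client-facing text: the detail inside InternalServerError is not exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            AuthError::MissingToken => "missing bearer token",
            AuthError::InvalidTokenFormat => "authorization header must be 'Bearer <token>'",
            AuthError::InvalidToken => "invalid or expired token",
            AuthError::MissingApiKey => "missing x-api-key header",
            AuthError::InvalidApiKey => "invalid API key",
            AuthError::NoAuthMethod => "no supported credentials supplied",
            AuthError::InvalidRole => "unknown role",
            AuthError::InvalidCredentials => "invalid username or password",
            AuthError::MissingSession => "missing session cookie",
            AuthError::SessionExpired => "session expired",
            AuthError::Forbidden(reason) => return write!(f, "forbidden: {reason}"),
            AuthError::RateLimited => "too many requests",
            AuthError::InternalServerError(_) => "internal server error",
        };
        f.write_str(msg)
    }
}
```

To use it as an Axum rejection, implement `IntoResponse` by building `(StatusCode::from_u16(self.status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), self.to_string())` and calling `.into_response()` on it; Axum already implements `IntoResponse` for a `(StatusCode, String)` pair.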

Pattern 2: API Key Authentication with Header Extraction

use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::Json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Debug, Clone)]
pub struct ApiKeyInfo {
    pub key_id: String,
    pub client_name: String,
    pub permissions: Vec<String>,
    pub rate_limit: u32,
}

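// Maps the raw key to its metadata. Production systems usually store a hash of
// each key instead, so a leaked table does not leak usable credentials.
pub struct ApiKeyState {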
    pub keys: Arc<RwLock<HashMap<String, ApiKeyInfo>>>,
}

pub struct AuthenticatedApi {
    pub key_info: ApiKeyInfo,
}

impl FromRequestParts<SharedState> for AuthenticatedApi {
    type Rejection = AuthError;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &SharedState,
    ) -> Result<Self, Self::Rejection> {
        let api_key = parts
            .headers
            .get("x-api-key")
            .and_then(|v| v.to_str().ok())
            .ok_or(AuthError::MissingApiKey)?;

        let keys = state.api_key_state.keys.read().await;
        let key_info = keys
            .get(api_key)
            .ok_or(AuthError::InvalidApiKey)?
            .clone();

        Ok(AuthenticatedApi { key_info })
    }
}

async fn api_endpoint(
    auth: AuthenticatedApi,
) -> Result<Json<serde_json::Value>, AuthError> {
    Ok(Json(serde_json::json!({
        "client": auth.key_info.client_name,
        "permissions": auth.key_info.permissions,
    })))
}

pub enum AuthMethod {
    Jwt(AuthUser),
    ApiKey(AuthenticatedApi),
}

impl FromRequestParts<SharedState> for AuthMethod {
    type Rejection = AuthError;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &SharedState,
    ) -> Result<Self, Self::Rejection> {
        if parts.headers.contains_key("x-api-key") {
            let api_auth = AuthenticatedApi::from_request_parts(parts, state).await?;
            Ok(AuthMethod::ApiKey(api_auth))
        } else if parts.headers.contains_key("authorization") {
            let jwt_auth = AuthUser::from_request_parts(parts, state).await?;
            Ok(AuthMethod::Jwt(jwt_auth))
        } else {
            Err(AuthError::NoAuthMethod)
        }
    }
}

async fn multi_auth_handler(
    auth: AuthMethod,
) -> Json<serde_json::Value> {
    match auth {
        AuthMethod::Jwt(user) => Json(serde_json::json!({
            "method": "jwt",
            "user_id": user.user_id,
        })),
        AuthMethod::ApiKey(api) => Json(serde_json::json!({
            "method": "api_key",
            "client": api.key_info.client_name,
        })),
    }
}
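Pattern 2 matches the presented key with a HashMap lookup on the raw key. One common hardening, sketched here under assumptions, splits a key into a public id and a secret (the `<key_id>.<secret>` format is hypothetical), looks the entry up by id, and compares only the secret without exiting early on the first mismatch. The loop is best effort: the compiler gives no timing guarantees, which is why crates such as `subtle` exist, and a production store would hold a hash of the secret rather than the secret itself.

```rust
use std::collections::HashMap;

/// Compares every byte instead of stopping at the first difference, so the
/// time taken does not reveal how long the matching prefix was.
/// (Length is still revealed; comparing fixed-length hashes avoids that.)
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b) {
        diff |= x ^ y;
    }
    diff == 0
}

/// Hypothetical key format "<key_id>.<secret>": the public key_id selects the
/// entry, and only the secret part goes through the constant-time comparison.
fn verify_api_key(presented: &str, store: &HashMap<String, String>) -> bool {
    let Some((key_id, secret)) = presented.split_once('.') else {
        return false;
    };
    match store.get(key_id) {
        Some(expected) => constant_time_eq(secret.as_bytes(), expected.as_bytes()),
        None => false,
    }
}
```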

Pattern 3: RBAC Role-Based Access Control with Enum Permissions

use strum::{Display, EnumString};
use std::collections::HashSet;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Display, EnumString)]
pub enum Permission {
    #[strum(to_string = "users:read")]
    UsersRead,
    #[strum(to_string = "users:write")]
    UsersWrite,
    #[strum(to_string = "users:delete")]
    UsersDelete,
    #[strum(to_string = "products:read")]
    ProductsRead,
    #[strum(to_string = "products:write")]
    ProductsWrite,
    #[strum(to_string = "orders:read")]
    OrdersRead,
    #[strum(to_string = "orders:write")]
    OrdersWrite,
    #[strum(to_string = "admin:full")]
    AdminFull,
}

#[derive(Debug, Clone, PartialEq, Eq, Display, EnumString)]
pub enum Role {
    #[strum(to_string = "viewer")]
    Viewer,
    #[strum(to_string = "editor")]
    Editor,
    #[strum(to_string = "admin")]
    Admin,
    #[strum(to_string = "superadmin")]
    SuperAdmin,
}

impl Role {
    pub fn permissions(&self) -> HashSet<Permission> {
        match self {
            Role::Viewer => HashSet::from([
                Permission::UsersRead,
                Permission::ProductsRead,
                Permission::OrdersRead,
            ]),
            Role::Editor => HashSet::from([
                Permission::UsersRead,
                Permission::ProductsRead,
                Permission::ProductsWrite,
                Permission::OrdersRead,
                Permission::OrdersWrite,
            ]),
            Role::Admin => HashSet::from([
                Permission::UsersRead,
                Permission::UsersWrite,
                Permission::ProductsRead,
                Permission::ProductsWrite,
                Permission::OrdersRead,
                Permission::OrdersWrite,
            ]),
            Role::SuperAdmin => HashSet::from([Permission::AdminFull]),
        }
    }

    pub fn has_permission(&self, permission: &Permission) -> bool {
        let perms = self.permissions();
        perms.contains(&Permission::AdminFull) || perms.contains(permission)
    }
}
use axum::extract::{FromRequestParts, Path};
use axum::http::request::Parts;
use axum::http::StatusCode;

// This extractor always checks Permission::UsersWrite; the wrapped value only
// records which permission was verified. For per-route permissions, make the
// required permission part of the type, or check it in the handler as below.
pub struct RequirePermission(pub Permission);

impl FromRequestParts<SharedState> for RequirePermission {
    type Rejection = AuthError;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &SharedState,
    ) -> Result<Self, Self::Rejection> {
        let auth_user = AuthUser::from_request_parts(parts, state).await?;
        let role: Role = auth_user.role.parse().map_err(|_| AuthError::InvalidRole)?;

        let required = Permission::UsersWrite;
        if !role.has_permission(&required) {
            return Err(AuthError::Forbidden(
                format!("Missing permission: {}", required),
            ));
        }

        Ok(RequirePermission(required))
    }
}

async fn delete_user_handler(
    auth_user: AuthUser,
    Path(user_id): Path<String>,
) -> Result<StatusCode, AuthError> {
    let role: Role = auth_user.role.parse().map_err(|_| AuthError::InvalidRole)?;
    if !role.has_permission(&Permission::UsersDelete) {
        return Err(AuthError::Forbidden("Missing users:delete permission".into()));
    }
    tracing::info!("Deleting user: {}", user_id);
    Ok(StatusCode::NO_CONTENT)
}
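Because `RequirePermission` above always checks `UsersWrite`, one way to make the required permission vary per route is to move it into the type. This std-only sketch uses hypothetical marker types (`CanWriteUsers`, `CanDeleteUsers`) with an associated constant; it trims `Permission` and `Role` to a few variants and drops strum so it stands alone.

```rust
use std::marker::PhantomData;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Permission { UsersRead, UsersWrite, UsersDelete, AdminFull }

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role { Viewer, Admin, SuperAdmin }

impl Role {
    fn permissions(self) -> &'static [Permission] {
        match self {
            Role::Viewer => &[Permission::UsersRead],
            Role::Admin => &[Permission::UsersRead, Permission::UsersWrite],
            Role::SuperAdmin => &[Permission::AdminFull],
        }
    }

    fn has_permission(self, p: Permission) -> bool {
        let perms = self.permissions();
        perms.contains(&Permission::AdminFull) || perms.contains(&p)
    }
}

/// Each marker type names one required permission at the type level.
pub trait PermissionMarker {
    const PERMISSION: Permission;
}

pub struct CanWriteUsers;
impl PermissionMarker for CanWriteUsers {
    const PERMISSION: Permission = Permission::UsersWrite;
}

pub struct CanDeleteUsers;
impl PermissionMarker for CanDeleteUsers {
    const PERMISSION: Permission = Permission::UsersDelete;
}

/// Zero-sized proof that the check passed; PhantomData carries the marker.
pub struct RequirePermission<P: PermissionMarker>(PhantomData<P>);

impl<P: PermissionMarker> RequirePermission<P> {
    /// The check an extractor would run after parsing the user's role.
    pub fn check(role: Role) -> Result<Self, String> {
        if role.has_permission(P::PERMISSION) {
            Ok(RequirePermission(PhantomData))
        } else {
            Err(format!("missing permission: {:?}", P::PERMISSION))
        }
    }
}
```

In Axum 0.8, a generic `impl<P: PermissionMarker> FromRequestParts<SharedState> for RequirePermission<P>` could extract `AuthUser`, parse the role and call `check`, letting a handler declare `_: RequirePermission<CanDeleteUsers>`; extra bounds on `P` (such as `Send + Sync`) may be needed depending on usage.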

Pattern 4: Rate Limiting Middleware with Tower Layer

use axum::extract::Request;
use axum::middleware::Next;
use axum::response::Response;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;

#[derive(Clone)]
pub struct RateLimitConfig {
    pub max_requests: u32,
    pub window_secs: u64,
}

#[derive(Debug)]
struct RateLimitEntry {
    count: u32,
    window_start: Instant,
}

#[derive(Clone)]
pub struct RateLimitState {
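    // Private because RateLimitEntry is private. Entries are never evicted, so a
    // long-running service needs periodic cleanup of idle keys.
    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,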
    pub config: RateLimitConfig,
}

impl RateLimitState {
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            entries: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    pub async fn check_rate_limit(&self, key: &str) -> bool {
        let mut entries = self.entries.write().await;
        let now = Instant::now();

        let entry = entries.entry(key.to_string()).or_insert(RateLimitEntry {
            count: 0,
            window_start: now,
        });

        // Fixed window: start a new window once window_secs have fully elapsed.
        if now.duration_since(entry.window_start).as_secs() >= self.config.window_secs {
            entry.count = 0;
            entry.window_start = now;
        }

        entry.count += 1;
        entry.count <= self.config.max_requests
    }
}

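pub async fn rate_limit_middleware(
    State(state): State<RateLimitState>,
    request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    // Key by API key when present, otherwise by the first x-forwarded-for hop.
    // x-forwarded-for is client-controlled unless a trusted proxy rewrites it,
    // and all requests without either header share the "anonymous" bucket.
    let key = request
        .headers()
        .get("x-api-key")
        .and_then(|v| v.to_str().ok())
        .or_else(|| {
            request
                .headers()
                .get("x-forwarded-for")
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.split(',').next())
                .map(str::trim)
        })
        .unwrap_or("anonymous")
        .to_string();

    if !state.check_rate_limit(&key).await {
        return Err(AuthError::RateLimited);
    }

    Ok(next.run(request).await)
}
use axum::error_handling::HandleErrorLayer;
use axum::extract::State;
use axum::http::StatusCode;
use axum::routing::get;
use axum::{middleware, BoxError, Router};
use std::time::Duration;
use tower::buffer::BufferLayer;
use tower::limit::RateLimitLayer;
use tower::ServiceBuilder;

pub fn create_rate_limited_router(state: SharedState) -> Router {
    let rate_limit_state = RateLimitState::new(RateLimitConfig {
        max_requests: 100,
        window_secs: 60,
    });

    // Router::layer wraps everything added before it, so the last layer runs
    // first: JWT auth (jwt_auth_middleware from Pattern 6), then rate limiting.
    Router::new()
        .route("/api/data", get(data_handler))
        .layer(middleware::from_fn_with_state(rate_limit_state, rate_limit_middleware))
        .layer(middleware::from_fn_with_state(state.clone(), jwt_auth_middleware))
        .with_state(state)
}

// tower's RateLimitLayer enforces one global limit (not per client) and makes
// excess requests wait instead of rejecting them. Its service is not Clone, so
// Axum needs it behind BufferLayer, with HandleErrorLayer converting errors.
pub fn create_tower_rate_limited_router(state: SharedState) -> Router {
    Router::new()
        .route("/api/data", get(data_handler))
        .layer(
            ServiceBuilder::new()
                .layer(HandleErrorLayer::new(|err: BoxError| async move {
                    (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled error: {err}"))
                }))
                .layer(BufferLayer::new(1024))
                .layer(RateLimitLayer::new(100, Duration::from_secs(60))),
        )
        .with_state(state)
}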

async fn data_handler() -> &'static str {
    "OK"
}
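The fixed-window counter in `RateLimitState` lets a client send up to twice `max_requests` in a short burst straddling a window boundary, and its map never shrinks. A token bucket avoids the burst; the sketch below is a std-only alternative (not the article's implementation) with a cleanup method. Time is passed in as `now` so the logic can be tested without sleeping; in a service you would pass `Instant::now()` and keep the limiter behind a `std::sync::Mutex`, which is fine because the lock is never held across an `.await`.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Per-key token bucket: each key holds up to `capacity` tokens, refilled
/// continuously at `refill_per_sec`; one request consumes one token.
pub struct TokenBucketLimiter {
    capacity: f64,
    refill_per_sec: f64,
    buckets: HashMap<String, (f64, Instant)>, // key -> (tokens left, last update)
}

impl TokenBucketLimiter {
    pub fn new(capacity: u32, refill_per_sec: f64) -> Self {
        assert!(refill_per_sec > 0.0, "refill rate must be positive");
        Self { capacity: f64::from(capacity), refill_per_sec, buckets: HashMap::new() }
    }

    /// Returns true if the request is allowed.
    pub fn try_acquire(&mut self, key: &str, now: Instant) -> bool {
        let (capacity, rate) = (self.capacity, self.refill_per_sec);
        // New keys start with a full bucket.
        let (tokens, last) = self.buckets.entry(key.to_string()).or_insert((capacity, now));
        // Refill in proportion to elapsed time, never above capacity.
        let elapsed = now.saturating_duration_since(*last).as_secs_f64();
        *tokens = (*tokens + elapsed * rate).min(capacity);
        *last = now;
        if *tokens >= 1.0 {
            *tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Drops buckets idle long enough to be full again; a fresh bucket behaves
    /// identically, so this only frees memory.
    pub fn evict_idle(&mut self, now: Instant) {
        let refill_time = Duration::from_secs_f64(self.capacity / self.refill_per_sec);
        self.buckets
            .retain(|_, entry| now.saturating_duration_since(entry.1) < refill_time);
    }

    pub fn tracked_keys(&self) -> usize {
        self.buckets.len()
    }
}
```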

Pattern 5: Session Management with Redis Backend

use chrono::Utc;
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    pub session_id: String,
    pub user_id: String,
    pub role: String,
    pub created_at: i64,
    pub expires_at: i64,
    pub ip_address: Option<String>,
    pub user_agent: Option<String>,
}

impl Session {
    pub fn new(user_id: &str, role: &str, ttl_secs: i64, ip: Option<&str>, ua: Option<&str>) -> Self {
        let now = Utc::now().timestamp();
        Self {
            session_id: Uuid::new_v4().to_string(),
            user_id: user_id.to_string(),
            role: role.to_string(),
            created_at: now,
            expires_at: now + ttl_secs,
            ip_address: ip.map(|s| s.to_string()),
            user_agent: ua.map(|s| s.to_string()),
        }
    }

    pub fn is_expired(&self) -> bool {
        Utc::now().timestamp() > self.expires_at
    }
}

#[derive(Clone)]
pub struct SessionStore {
    // One multiplexed connection opened at startup; clones share it, so each
    // call below clones it instead of opening a new TCP connection.
    conn: redis::aio::MultiplexedConnection,
    pub ttl_secs: i64,
    pub key_prefix: String,
}

impl SessionStore {
    pub async fn new(redis_url: &str, ttl_secs: i64) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(redis_url)?;
        Ok(Self {
            conn: client.get_multiplexed_async_connection().await?,
            ttl_secs,
            key_prefix: "session:".to_string(),
        })
    }

    pub async fn create_session(&self, session: &Session) -> Result<(), redis::RedisError> {
        let mut conn = self.conn.clone();
        let key = format!("{}{}", self.key_prefix, session.session_id);
        // Serializing this plain struct cannot fail in practice.
        let value = serde_json::to_string(session).expect("Session serializes to JSON");
        // Redis commands are generic over the reply type, so annotate it as ().
        let _: () = conn.set_ex(&key, value, self.ttl_secs as u64).await?;
        Ok(())
    }

    pub async fn get_session(&self, session_id: &str) -> Result<Option<Session>, redis::RedisError> {
        let mut conn = self.conn.clone();
        let key = format!("{}{}", self.key_prefix, session_id);
        let value: Option<String> = conn.get(&key).await?;
        // Treat unparseable data as "no session" instead of panicking.
        Ok(value.and_then(|v| serde_json::from_str(&v).ok()))
    }

    pub async fn delete_session(&self, session_id: &str) -> Result<(), redis::RedisError> {
        let mut conn = self.conn.clone();
        let key = format!("{}{}", self.key_prefix, session_id);
        let _: () = conn.del(&key).await?;
        Ok(())
    }

    pub async fn refresh_session(&self, session_id: &str) -> Result<(), redis::RedisError> {
        let mut conn = self.conn.clone();
        let key = format!("{}{}", self.key_prefix, session_id);
        let _: () = conn.expire(&key, self.ttl_secs).await?;
        Ok(())
    }
}
use axum::extract::{FromRequestParts, State};
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::Json;

pub struct SessionUser {
    pub session: Session,
}

impl FromRequestParts<SharedState> for SessionUser {
    type Rejection = AuthError;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &SharedState,
    ) -> Result<Self, Self::Rejection> {
        let cookie_header = parts
            .headers
            .get("cookie")
            .and_then(|v| v.to_str().ok())
            .ok_or(AuthError::MissingSession)?;

        let session_id = extract_session_id(cookie_header)
            .ok_or(AuthError::MissingSession)?;

        let session = state
            .session_store
            .get_session(&session_id)
            .await
            .map_err(|_| AuthError::InternalServerError("Redis error".into()))?
            .ok_or(AuthError::SessionExpired)?;

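        // The Redis TTL acts as the idle timeout (renewed below); expires_at is an
        // absolute cap that refresh_session does not move. With both set from the
        // same ttl_secs, a session ends ttl_secs after login regardless of activity.
        if session.is_expired() {
            state.session_store.delete_session(&session_id).await.ok();
            return Err(AuthError::SessionExpired);
        }

        state.session_store.refresh_session(&session_id).await.ok();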

        Ok(SessionUser { session })
    }
}

fn extract_session_id(cookie_header: &str) -> Option<String> {
    cookie_header
        .split(';')
        .find_map(|cookie| {
            let mut parts = cookie.trim().splitn(2, '=');
            let name = parts.next()?.trim();
            if name == "session_id" {
                Some(parts.next()?.trim().to_string())
            } else {
                None
            }
        })
}

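// LoginRequest and state.db.verify_user (returning Option<User>) stand in for your user store.
async fn login_handler(
    State(state): State<SharedState>,
    Json(payload): Json<LoginRequest>,
) -> Result<impl axum::response::IntoResponse, AuthError> {
    let user = state.db.verify_user(&payload.username, &payload.password).await
        .ok_or(AuthError::InvalidCredentials)?;

    let session = Session::new(
        &user.id,
        &user.role,
        state.session_store.ttl_secs,
        None,
        None,
    );

    state.session_store.create_session(&session).await
        .map_err(|_| AuthError::InternalServerError("Failed to create session".into()))?;

    // SessionUser reads the id from the Cookie header, so deliver it as a cookie.
    // HttpOnly keeps it away from JavaScript; Secure restricts it to HTTPS.
    let cookie = format!(
        "session_id={}; Max-Age={}; Path=/; HttpOnly; Secure; SameSite=Lax",
        session.session_id, state.session_store.ttl_secs
    );

    Ok((
        [(axum::http::header::SET_COOKIE, cookie)],
        Json(serde_json::json!({ "expires_at": session.expires_at })),
    ))
}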

async fn logout_handler(
    session_user: SessionUser,
    State(state): State<SharedState>,
) -> Result<StatusCode, AuthError> {
    state.session_store
        .delete_session(&session_user.session.session_id)
        .await
        .map_err(|_| AuthError::InternalServerError("Failed to delete session".into()))?;
    Ok(StatusCode::NO_CONTENT)
}
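`SessionUser` reads `session_id` from the `Cookie` header, so the login response has to deliver it as a cookie rather than only in a JSON body. A small std-only helper for the `Set-Cookie` value follows; the attribute set (`HttpOnly`, `Secure`, `SameSite=Lax`, `Path=/`) is a common hardening default assumed here, not something the article prescribes, and the function names are illustrative.

```rust
/// Set-Cookie value carrying the session id for `max_age_secs` seconds.
fn session_cookie(session_id: &str, max_age_secs: i64) -> String {
    format!("session_id={session_id}; Max-Age={max_age_secs}; Path=/; HttpOnly; Secure; SameSite=Lax")
}

/// Logout counterpart: an empty value with Max-Age=0 tells the browser to delete the cookie.
/// Path must match the original cookie for the deletion to apply.
fn clear_session_cookie() -> String {
    "session_id=; Max-Age=0; Path=/; HttpOnly; Secure; SameSite=Lax".to_string()
}
```

`HttpOnly` keeps the id away from JavaScript, `Secure` restricts it to HTTPS, and `SameSite=Lax` withholds it from cross-site subrequests and cross-site POSTs. A logout handler can return the cleared cookie so the browser forgets the id as well as the server.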

Pattern 6: Production Auth Service Combining All Patterns

use axum::extract::{Request, State};
use axum::middleware::{self, Next};
use axum::response::Response;
use axum::routing::{get, post};
use axum::Router;
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;

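// Full application state. The cached middleware under Advanced Optimization
// additionally expects a `pub auth_cache: AuthCache` field.
pub struct AppState {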
    pub auth_config: AuthConfig,
    pub api_key_state: ApiKeyState,
    pub rate_limit_state: RateLimitState,
    pub session_store: SessionStore,
    pub db: DbPool,
}

pub type SharedState = Arc<AppState>;

pub fn create_production_router(state: SharedState) -> Router {
    let public_routes = Router::new()
        .route("/auth/login", post(login_handler))
        .route("/auth/register", post(register_handler))
        .route("/health", get(health_check));

    let jwt_protected = Router::new()
        .route("/users/me", get(get_profile).post(update_profile))
        .route("/auth/refresh", get(refresh_token))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            jwt_auth_middleware,
        ));

    let api_key_routes = Router::new()
        .route("/api/v1/data", get(data_handler))
        .route("/api/v1/reports", get(reports_handler))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            api_key_auth_middleware,
        ));

    let admin_routes = Router::new()
        .route("/admin/users", get(list_users_handler))
        .route("/admin/users/{id}", axum::routing::delete(delete_user_handler))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            admin_auth_middleware,
        ));

    let session_routes = Router::new()
        .route("/dashboard", get(dashboard_handler))
        .route("/auth/logout", post(logout_handler))
        .layer(middleware::from_fn_with_state(
            state.clone(),
            session_auth_middleware,
        ));

    Router::new()
        .merge(public_routes)
        .nest("/api", jwt_protected)
        .nest("/external", api_key_routes)
        .nest("/manage", admin_routes)
        .nest("/web", session_routes)
        .layer(
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http())
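                // Permissive CORS for development only. In production, list explicit origins;
                // tower-http panics if allow_credentials(true) is combined with Any.
                .layer(CorsLayer::new().allow_origin(Any).allow_methods(Any).allow_headers(Any))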
                .into_inner(),
        )
        .with_state(state)
}

async fn jwt_auth_middleware(
    State(state): State<SharedState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    let auth_header = request.headers()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .ok_or(AuthError::MissingToken)?;

    let token = auth_header
        .strip_prefix("Bearer ")
        .ok_or(AuthError::InvalidTokenFormat)?;

    let claims = Claims::decode(token, &state.auth_config.jwt_secret)
        .map_err(|_| AuthError::InvalidToken)?;

    request.extensions_mut().insert(AuthUser {
        user_id: claims.sub.clone(),
        role: claims.role.clone(),
    });

    Ok(next.run(request).await)
}

async fn api_key_auth_middleware(
    State(state): State<SharedState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    let api_key = request.headers()
        .get("x-api-key")
        .and_then(|v| v.to_str().ok())
        .ok_or(AuthError::MissingApiKey)?;

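    // Clone the entry inside a block so the read lock is released before the
    // rate-limit check and before the rest of the request runs.
    let key_info = {
        let keys = state.api_key_state.keys.read().await;
        keys.get(api_key).cloned().ok_or(AuthError::InvalidApiKey)?
    };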

    if !state.rate_limit_state.check_rate_limit(&key_info.key_id).await {
        return Err(AuthError::RateLimited);
    }

    request.extensions_mut().insert(key_info);
    Ok(next.run(request).await)
}

async fn admin_auth_middleware(
    State(state): State<SharedState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    let auth_header = request.headers()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .ok_or(AuthError::MissingToken)?;

    let token = auth_header.strip_prefix("Bearer ").ok_or(AuthError::InvalidTokenFormat)?;
    let claims = Claims::decode(token, &state.auth_config.jwt_secret)
        .map_err(|_| AuthError::InvalidToken)?;

    let role: Role = claims.role.parse().map_err(|_| AuthError::InvalidRole)?;
    if !role.has_permission(&Permission::UsersWrite) {
        return Err(AuthError::Forbidden("Admin access required".into()));
    }

    request.extensions_mut().insert(AuthUser {
        user_id: claims.sub,
        role: claims.role,
    });

    Ok(next.run(request).await)
}

async fn session_auth_middleware(
    State(state): State<SharedState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    let cookie_header = request.headers()
        .get("cookie")
        .and_then(|v| v.to_str().ok())
        .ok_or(AuthError::MissingSession)?;

    let session_id = extract_session_id(cookie_header).ok_or(AuthError::MissingSession)?;
    let session = state.session_store.get_session(&session_id).await
        .map_err(|_| AuthError::InternalServerError("Redis error".into()))?
        .ok_or(AuthError::SessionExpired)?;

    if session.is_expired() {
        state.session_store.delete_session(&session_id).await.ok();
        return Err(AuthError::SessionExpired);
    }

    state.session_store.refresh_session(&session_id).await.ok();
    request.extensions_mut().insert(session);
    Ok(next.run(request).await)
}
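Pattern 6 relies on layer order: `Router::layer` wraps everything added before it, so the layer added last sees the request first, while a `ServiceBuilder` runs its layers top to bottom. The std-only model below (an illustration with plain closures, not Axum code) demonstrates the first rule using the `.layer(rate_limit).layer(auth)` order from Pattern 4; `Service`, `layer` and `build_router_like` are hypothetical names.

```rust
// A "service" records what it does into a shared log.
type Service = Box<dyn Fn(&mut Vec<String>)>;

fn handler() -> Service {
    Box::new(|log: &mut Vec<String>| log.push("handler".to_string()))
}

// Wrap an inner service: act on the way in, call inner, act on the way out.
fn layer(name: &'static str, inner: Service) -> Service {
    Box::new(move |log: &mut Vec<String>| {
        log.push(format!("{name}: request"));
        inner(log);
        log.push(format!("{name}: response"));
    })
}

/// Mirrors Router::new().route(..).layer(rate_limit).layer(auth):
/// each .layer() call wraps everything built so far.
fn build_router_like() -> Service {
    let svc = handler();
    let svc = layer("rate_limit", svc);
    layer("auth", svc)
}
```

Because `auth` wraps `rate_limit`, unauthenticated requests are rejected before they touch the rate-limit counters; reversing the two calls would make every request, authenticated or not, consume quota first.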

Pitfall Guide

Pitfall 1: JWT Secret Hardcoded

// ❌ Wrong: Secret hardcoded in source code
let secret = "my-super-secret-key-123";

// ✅ Correct: Read from environment variable, validate at startup
let secret = std::env::var("JWT_SECRET")
    .expect("JWT_SECRET must be set");
if secret.len() < 32 {
    // len() counts bytes; RFC 7518 requires HS256 keys of at least 256 bits (32 bytes)
    panic!("JWT_SECRET must be at least 32 bytes");
}
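Panicking works at startup, but returning an error lets `main` log a clean message and exit. This sketch separates the rule from the environment lookup so the rule can be tested without setting process-wide variables; `validate_jwt_secret` and `load_jwt_secret` are illustrative names.

```rust
/// The secret rule on its own: present and at least 32 bytes long.
fn validate_jwt_secret(raw: Option<String>) -> Result<String, String> {
    let secret = raw.ok_or_else(|| "JWT_SECRET must be set".to_string())?;
    // RFC 7518 requires HS256 keys of at least 256 bits.
    if secret.len() < 32 {
        return Err(format!("JWT_SECRET must be at least 32 bytes, got {}", secret.len()));
    }
    Ok(secret)
}

/// Startup helper: read the variable and apply the rule.
fn load_jwt_secret() -> Result<String, String> {
    validate_jwt_secret(std::env::var("JWT_SECRET").ok())
}
```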

Pitfall 2: State Type Mismatch in Middleware

// ❌ Wrong: from_fn gives the middleware no app state, so State<SharedState> fails to compile
.layer(middleware::from_fn(my_auth_middleware))

async fn my_auth_middleware(
    State(state): State<SharedState>, // not available through from_fn
    request: Request,
    next: Next,
) -> Result<Response, AuthError> { ... }

// ✅ Correct: Use from_fn_with_state
.layer(middleware::from_fn_with_state(state.clone(), my_auth_middleware))

async fn my_auth_middleware(
    State(state): State<SharedState>,
    request: Request,
    next: Next,
) -> Result<Response, AuthError> { ... }

Pitfall 3: JWT Claims Missing Expiration Validation

// ❌ Wrong: switching off expiry checks (e.g. to silence ExpiredSignature in development)
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
validation.validate_exp = false; // expired tokens are now accepted forever

// ✅ Correct: keep the settings explicit (Validation::new already checks exp with 60s leeway)
let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
validation.leeway = 60; // Allow 60s clock skew
validation.validate_exp = true;
validation.validate_nbf = true; // also reject tokens used before their nbf time
let token_data = decode::<Claims>(token, &key, &validation)?;
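The effect of `leeway` is easiest to see in isolation. This std-only sketch mirrors the comparisons a JWT library typically makes when `validate_exp` and `validate_nbf` are on: a token counts as expired once `exp < now - leeway` and as not yet valid while `nbf > now + leeway`. It is for illustration only; `jsonwebtoken` should still perform the real check, and `TimeCheck` and `check_times` are hypothetical names.

```rust
#[derive(Debug, PartialEq)]
enum TimeCheck {
    Valid,
    Expired,
    NotYetValid,
}

/// All arguments are Unix seconds; `leeway` absorbs clock skew between servers.
fn check_times(exp: u64, nbf: Option<u64>, now: u64, leeway: u64) -> TimeCheck {
    if exp < now.saturating_sub(leeway) {
        return TimeCheck::Expired;
    }
    if let Some(nbf) = nbf {
        if nbf > now + leeway {
            return TimeCheck::NotYetValid;
        }
    }
    TimeCheck::Valid
}
```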

Pitfall 4: RBAC with String Matching

// ❌ Wrong: String hardcoding, typos not caught at compile time
if user.role == "admin" || user.permissions.contains(&"users:write".to_string()) {
    // Typos won't be caught by the compiler
}

// ✅ Correct: Use enum + strum, compile-time permission name safety
#[derive(EnumString)]
pub enum Permission {
    #[strum(to_string = "users:write")]
    UsersWrite,
}

let perm = Permission::UsersWrite;
if role.has_permission(&perm) {
    // Compile-time safe
}

Pitfall 5: Redis Connection Not Reused

// ❌ Wrong: Creating new connection per request
async fn get_session(session_id: &str) -> Option<Session> {
    let client = redis::Client::open("redis://localhost").unwrap();
    let mut conn = client.get_async_connection().await.unwrap(); // New connection each time!
    conn.get(session_id).await.ok()
}

// ✅ Correct: open one multiplexed connection at startup and clone it per call
pub struct SessionStore {
    // Cloning a MultiplexedConnection is cheap; all clones share one connection.
    conn: redis::aio::MultiplexedConnection,
}

impl SessionStore {
    pub async fn get_session(&self, session_id: &str) -> Result<Option<Session>, redis::RedisError> {
        let mut conn = self.conn.clone();
        let value: Option<String> = conn.get(session_id).await?;
        // ...
    }
}

Error Troubleshooting

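1. the trait FromRequestParts is not implemented for AuthUser
   Cause: The impl is missing, uses a different state type than the Router, or its Rejection type does not implement IntoResponse
   Solution: In Axum 0.8, implement the trait with a plain async fn (the #[axum::async_trait] re-export was removed), use the Router's state type, and implement IntoResponse for the Rejection
2. mismatched types: expected State<X>, found State<Y>
   Cause: Middleware and Router state types differ
   Solution: Use the type SharedState = Arc<AppState> alias consistently, or implement FromRef to extract a sub-state
3. JWT decode error: InvalidToken, InvalidSignature or InvalidAlgorithm
   Cause: InvalidToken means the token is malformed (often a missing or doubled Bearer prefix); InvalidSignature means the secret differs from the one used to sign; InvalidAlgorithm means the header's alg is not in the allowed list
   Solution: Check the Bearer prefix, secret consistency between issuer and verifier, and the algorithm configured in Validation
4. JWT decode error: ExpiredSignature
   Cause: The token has expired
   Solution: Implement a refresh token mechanism with automatic renewal on the frontend
5. future cannot be sent between threads safely
   Cause: A non-Send value (e.g. a std::sync::MutexGuard or an Rc) is held across an .await point
   Solution: Drop the value before awaiting (scope it in a block), or use tokio::sync::Mutex when the lock must span an await
6. missing field 'exp' in Claims
   Cause: The token payload or the Claims struct lacks exp
   Solution: Issue tokens with exp and iat; jsonwebtoken's Validation requires exp by default
7. Rate limit state not found in extensions
   Cause: The middleware reads its state from request extensions, but nothing inserted it
   Solution: Pass the state with middleware::from_fn_with_state and a State extractor, or add .layer(Extension(state)) after the middleware layer so it runs first
8. Redis: Connection refused
   Cause: The Redis service is not running or the URL is wrong
   Solution: Check the Redis service status and the connection URL config
9. the trait bound fn(...): Handler<_, _> is not satisfied
   Cause: The handler takes more than 16 arguments, an argument is not an extractor, or a body-consuming extractor such as Json is not the last argument
   Solution: Add #[axum::debug_handler] (macros feature) for a precise error message; group related data into structs or pass auth info via Extension
10. Cannot start a runtime from within a runtime / Cannot drop a runtime in a context where blocking is not allowed
   Cause: A blocking client that owns its own Tokio runtime (e.g. reqwest::blocking) or a block_on call is used inside async code
   Solution: Use the async API (for Redis, get_multiplexed_async_connection().await) or move blocking work into tokio::task::spawn_blocking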

Advanced Optimization

1. JWT Blacklist and Active Revocation

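use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;

// Revocations are keyed by the JWT `jti` claim, which the Pattern 1 Claims struct
// does not have yet: issue each token with a unique jti. The in-memory set is a
// single-process fallback that is never pruned, and each call below opens a new
// Redis connection (see Pitfall 5 for reusing one).
#[derive(Clone)]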
pub struct JwtBlacklist {
    pub revoked_tokens: Arc<RwLock<HashSet<String>>>,
    pub redis_client: Option<redis::Client>,
}

impl JwtBlacklist {
    pub fn new(redis_url: Option<&str>) -> Result<Self, redis::RedisError> {
        let client = redis_url
            .map(redis::Client::open)
            .transpose()?;
        Ok(Self {
            revoked_tokens: Arc::new(RwLock::new(HashSet::new())),
            redis_client: client,
        })
    }

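    // exp_secs: seconds until the token's own expiry (must be > 0), so the Redis
    // entry disappears once the token could no longer be used anyway.
    pub async fn revoke_token(&self, jti: &str, exp_secs: i64) -> Result<(), AuthError> {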
        if let Some(client) = &self.redis_client {
            let mut conn = client.get_multiplexed_async_connection().await
                .map_err(|_| AuthError::InternalServerError("Redis connection failed".into()))?;
            let key = format!("jwt:blacklist:{}", jti);
            redis::cmd("SET")
                .arg(&key)
                .arg("1")
                .arg("EX")
                .arg(exp_secs)
                .exec_async(&mut conn)
                .await
                .map_err(|_| AuthError::InternalServerError("Redis SET failed".into()))?;
        } else {
            let mut tokens = self.revoked_tokens.write().await;
            tokens.insert(jti.to_string());
        }
        Ok(())
    }

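    // If Redis is configured but unreachable, this falls back to the in-memory
    // set and effectively fails open; rejecting the request may be safer.
    pub async fn is_revoked(&self, jti: &str) -> bool {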
        if let Some(client) = &self.redis_client {
            if let Ok(mut conn) = client.get_multiplexed_async_connection().await {
                let key = format!("jwt:blacklist:{}", jti);
                if let Ok(exists) = redis::cmd("EXISTS")
                    .arg(&key)
                    .query_async::<i32>(&mut conn)
                    .await
                {
                    return exists > 0;
                }
            }
        }
        let tokens = self.revoked_tokens.read().await;
        tokens.contains(jti)
    }
}
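The in-memory fallback above never forgets a `jti`. A minimal std-only sketch of an expiring alternative, assuming the caller passes the token's absolute `exp` claim and the current time, both as Unix seconds. `ExpiringBlacklist` and its methods are hypothetical names, not part of any crate. Expired entries are pruned on every write, so memory stays bounded by the number of revoked tokens that are still otherwise valid.

```rust
use std::collections::HashMap;

/// Hypothetical in-memory revocation list whose entries expire with the token.
/// Times are Unix seconds; passing `now` explicitly keeps the logic testable.
#[derive(Default)]
pub struct ExpiringBlacklist {
    // jti -> the revoked token's `exp` claim
    entries: HashMap<String, i64>,
}

impl ExpiringBlacklist {
    /// Records a revoked token until its own expiry. Tokens that are already
    /// expired are ignored, because exp validation rejects them anyway.
    pub fn revoke(&mut self, jti: &str, exp: i64, now: i64) {
        // Prune entries whose tokens can no longer be presented successfully.
        self.entries.retain(|_, e| *e > now);
        if exp > now {
            self.entries.insert(jti.to_string(), exp);
        }
    }

    /// A token counts as revoked only while its entry is unexpired.
    pub fn is_revoked(&self, jti: &str, now: i64) -> bool {
        self.entries.get(jti).map_or(false, |&exp| exp > now)
    }

    /// Remaining lifetime in seconds: the TTL to use for Redis `SET ... EX`.
    /// Returns `None` when the token has already expired.
    pub fn remaining_ttl(exp: i64, now: i64) -> Option<i64> {
        let ttl = exp - now;
        (ttl > 0).then_some(ttl)
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
```

In a middleware, `is_revoked` would run after signature validation and before the user is inserted into request extensions, and `remaining_ttl` yields the TTL argument for `revoke_token`'s Redis path.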

2. Auth Middleware Performance Optimization

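use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::Response;
use moka::sync::Cache;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// AuthUser, ApiKeyInfo, Claims, SharedState and AuthError are defined elsewhere in this article.
// Each token entry stores the user plus the token's `exp` (Unix seconds), because the
// cache TTL alone could keep honoring a token after it has expired.
#[derive(Clone)]
pub struct AuthCache {
    pub token_cache: Cache<String, (AuthUser, i64)>,
    pub api_key_cache: Cache<String, ApiKeyInfo>,
}

impl AuthCache {
    pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
        Self {
            token_cache: Cache::builder()
                .max_capacity(max_entries as u64)
                .time_to_live(Duration::from_secs(ttl_secs))
                .build(),
            api_key_cache: Cache::builder()
                .max_capacity(max_entries as u64)
                .time_to_live(Duration::from_secs(ttl_secs))
                .build(),
        }
    }
}

fn now_unix() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .unwrap_or(0)
}

pub async fn cached_jwt_auth_middleware(
    State(state): State<SharedState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AuthError> {
    let auth_header = request.headers()
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .ok_or(AuthError::MissingToken)?;

    let token = auth_header.strip_prefix("Bearer ").ok_or(AuthError::InvalidTokenFormat)?;

    // Cache hit: skip signature verification, but still enforce expiry.
    if let Some((cached_user, exp)) = state.auth_cache.token_cache.get(token) {
        if exp > now_unix() {
            request.extensions_mut().insert(cached_user);
            return Ok(next.run(request).await);
        }
        // Expired token: evict it and fall through so decoding rejects it.
        state.auth_cache.token_cache.invalidate(token);
    }

    let claims = Claims::decode(token, &state.auth_config.jwt_secret)
        .map_err(|_| AuthError::InvalidToken)?;

    let auth_user = AuthUser {
        user_id: claims.sub.clone(),
        role: claims.role.clone(),
    };

    // If a JwtBlacklist is in use, also call token_cache.invalidate(token) on
    // revocation; otherwise a revoked token keeps working until its entry expires.
    state.auth_cache.token_cache.insert(token.to_string(), (auth_user.clone(), claims.exp as i64));
    request.extensions_mut().insert(auth_user);
    Ok(next.run(request).await)
}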
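A cached token should never be honored past its own `exp`, even when the cache-wide TTL is longer. The rule fits in a small std-only helper, sketched here under the assumption that timestamps are Unix seconds; `entry_ttl` is a hypothetical name. moka can apply such per-entry lifetimes through its per-entry expiration policy (the `Expiry` trait, set via the cache builder's `expire_after`); check the moka documentation for the exact signatures before relying on it.

```rust
use std::time::Duration;

/// Hypothetical helper: how long a decoded token may stay in the auth cache.
/// Returns `None` when the token is already expired and must not be cached.
/// `exp` and `now` are Unix seconds; `max_ttl` is the cache-wide TTL.
pub fn entry_ttl(exp: i64, now: i64, max_ttl: Duration) -> Option<Duration> {
    let remaining = exp.checked_sub(now)?;
    if remaining <= 0 {
        return None;
    }
    // Never keep the entry longer than the token itself remains valid.
    Some(max_ttl.min(Duration::from_secs(remaining as u64)))
}
```

A token with 100 seconds left gets a 100-second entry under a 300-second cache TTL, while a long-lived token is still capped at the cache TTL.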

3. OpenTelemetry Observability Integration

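use axum::extract::{MatchedPath, Request};
use axum::middleware::Next;
use axum::response::Response;
use opentelemetry::trace::{Span, Tracer};
use opentelemetry::KeyValue;

pub async fn observability_middleware(
    request: Request,
    next: Next,
) -> Response {
    let method = request.method().clone();
    let path = request.uri().path().to_string();
    // Route template such as "/users/{id}"; present once routing has happened,
    // e.g. for middleware added with Router::layer.
    let route = request
        .extensions()
        .get::<MatchedPath>()
        .map(|p| p.as_str().to_owned());
    // Records which credential the request *presents*, not whether it was valid.
    let auth_method = if request.headers().contains_key("x-api-key") {
        "api_key"
    } else if request.headers().contains_key("authorization") {
        "jwt"
    } else if request.headers().contains_key("cookie") {
        "session"
    } else {
        "none"
    };

    let tracer = opentelemetry::global::tracer("auth-service");
    // Name the span by route template, not raw path: raw paths embed IDs and
    // produce unbounded span-name cardinality.
    let span_name = match &route {
        Some(r) => format!("{} {}", method, r),
        None => method.to_string(),
    };
    // This span is not installed as the current context, so spans created
    // inside handlers are not parented to it.
    let mut span = tracer.start(span_name);
    span.set_attribute(KeyValue::new("auth.method", auth_method.to_string()));
    span.set_attribute(KeyValue::new("http.method", method.to_string()));
    span.set_attribute(KeyValue::new("http.path", path));
    if let Some(r) = route {
        span.set_attribute(KeyValue::new("http.route", r));
    }

    let response = next.run(request).await;

    span.set_attribute(KeyValue::new("http.status_code", response.status().as_u16() as i64));
    span.end();

    response
}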
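Raw paths such as `/users/42` turn every ID into a distinct span name. The OpenTelemetry HTTP semantic conventions recommend naming server spans `{method} {http.route}` and falling back to just `{method}` when no low-cardinality route is known. A std-only sketch of that naming rule plus the header-based credential classification, with hypothetical helper names; in Axum the route template would come from the `MatchedPath` request extension.

```rust
/// Hypothetical helper: "{method} {route}" when a route template is known,
/// otherwise just "{method}" (never the raw path, which embeds IDs).
pub fn server_span_name(method: &str, route: Option<&str>) -> String {
    match route {
        Some(r) => format!("{method} {r}"),
        None => method.to_string(),
    }
}

/// Hypothetical classifier for the credential a request *presents*, using the
/// same precedence as the middleware: API key, then bearer token, then cookie.
/// Header names are compared case-insensitively, as HTTP requires.
pub fn presented_auth_method(header_names: &[&str]) -> &'static str {
    let has = |name: &str| header_names.iter().any(|h| h.eq_ignore_ascii_case(name));
    if has("x-api-key") {
        "api_key"
    } else if has("authorization") {
        "jwt"
    } else if has("cookie") {
        "session"
    } else {
        "none"
    }
}
```

Keeping both rules as pure functions makes them easy to unit-test without spinning up a tracer or a router.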

Comparison

| Dimension | Axum+Tower | Actix-web Guard | Go Gin | Java Spring Security |
|---|---|---|---|---|
| Auth model | FromRequestParts + Layer | Guard trait + Extractor | Middleware function | Filter chain + SecurityContext |
| Type safety | ✅ Compile-time | ⚠️ Partially runtime | ❌ Runtime | ❌ Runtime |
| Middleware composition | ✅ Tower ServiceBuilder | ⚠️ Manual nesting | ✅ Middleware chain | ✅ Filter chain |
| RBAC support | Must implement | Must implement | casbin etc. | ✅ Built-in @PreAuthorize |
| JWT ecosystem | jsonwebtoken | jsonwebtoken | golang-jwt | jjwt/nimbus |
| Session management | Must implement | Must implement | gorilla/sessions | ✅ Built-in Session |
| Performance | ⭐ Very high | ⭐ Very high | ⭐ High | ⭐ Medium |
| Learning curve | ⭐ Steep | ⭐ Steep | ⭐ Gentle | ⭐ Very steep |
| Rate limiting | tower RateLimitLayer / tower_governor | actix-governor etc. | tollbooth etc. | bucket4j etc. |
| Production readiness | ⭐ High | ⭐ High | ⭐ Very high | ⭐ Very high |

Summary: The core advantage of Axum auth lies in type-safe extractors: FromRequestParts injects auth info into handlers like regular parameters, so a handler that declares an auth type cannot run unless extraction succeeded, and type mismatches are caught at compile time. Recommended 2026 production practice: FromRequestParts for JWT/API Key extraction → enum + strum for the RBAC permission model → from_fn_with_state for stateful middleware → Redis for sessions and the JWT blacklist → moka cache to accelerate token validation → Tower ServiceBuilder to compose the middleware pipeline. The key insight: in Axum, auth is not an "interceptor" but "type extraction".

