diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 12e9e0b..f1f695a 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -604,6 +604,7 @@ dependencies = [ "serde_json", "sql-adaptor", "tokio", + "tower-http", "tracing", "tracing-subscriber", ] @@ -1080,6 +1081,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -2642,6 +2649,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.2" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 86fce04..d123fac 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -21,3 +21,4 @@ tracing = "0.1.37" tracing-subscriber = "0.3.17" chrono = "0.4.24" bcrypt = "0.14.0" +tower-http = { version = "0.4.0", features = ["cors"] } diff --git a/backend/src/main.rs b/backend/src/main.rs index 3ec7bf7..644a879 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,24 +1,20 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{env, net::SocketAddr, sync::Arc}; use axum::{ extract, + http::{HeaderValue, Method}, routing::{get, patch, post}, Router, Server, }; use routes::*; use sql_adaptor::SqlAdaptor; use tokio::sync::Mutex; +use tower_http::cors::CorsLayer; mod errors; mod payloads; mod routes; -#[cfg(debug_assertions)] -const MODE: &str = "debug"; - -#[cfg(not(debug_assertions))] -const MODE: &str = "release"; - pub struct ApiState { adaptor: A, } @@ -36,6 +32,18 @@ async fn main() { adaptor: SqlAdaptor::new().await, })); + let cors = CorsLayer::new() + .allow_methods([Method::GET, Method::POST, Method::PATCH]) + .allow_origin( + if cfg!(debug_assertions) { + "http://localhost:1234".to_owned() + } else { + env::var("FRONTEND_URL").expect("Missing FRONTEND_URL environment variable") + } + .parse::() + .unwrap(), + ); + let app = Router::new() .route("/", get(get_root)) .route("/stats", get(get_stats)) @@ -44,13 +52,19 @@ async fn main() { .route("/event/:event_id/people", get(get_people)) .route("/event/:event_id/people/:person_name", get(get_person)) .route("/event/:event_id/people/:person_name", patch(update_person)) - .with_state(shared_state); + .with_state(shared_state) + .layer(cors); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); println!( "🦀 Crab Fit API listening at http://{} in {} mode", - addr, MODE + addr, + if cfg!(debug_assertions) { + "debug" + } else { + "release" + } ); Server::bind(&addr) .serve(app.into_make_service())