From 01004e237b644ab51ff6183a868a971a74e7d806 Mon Sep 17 00:00:00 2001 From: "D. Scott Boggs" Date: Tue, 27 Jun 2023 14:20:43 -0400 Subject: [PATCH] Fix authenticated track insertion --- server/src/api/auth.rs | 32 ++++++++++-------- server/src/api/tracks.rs | 64 ++++++++++++++++++++---------------- server/src/entities/users.rs | 3 ++ 3 files changed, 58 insertions(+), 41 deletions(-) diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs index edf0d28..e42fe66 100644 --- a/server/src/api/auth.rs +++ b/server/src/api/auth.rs @@ -1,5 +1,8 @@ + + use derive_deref::Deref; -use log::warn; +use either::Either::{self, Right}; + use rocket::{ http::{Cookie, CookieJar, Status}, outcome::IntoOutcome, @@ -16,6 +19,8 @@ use crate::{ error::Error, }; +use super::ErrorResponder; + #[derive(Clone, Deserialize)] pub(super) struct LoginData { name: String, @@ -27,22 +32,22 @@ pub(super) async fn login( db: &State, user_data: Json, cookies: &CookieJar<'_>, -) -> ApiResult { - let users = User::find() - .filter(user::Column::Name.eq(&user_data.name)) - .all(db as &DatabaseConnection) +) -> Result> { + let user = Users::find() + .filter(users::Column::Name.eq(&user_data.name)) + .one(db as &DatabaseConnection) .await - .map_err(Error::from)?; - if users.len() > 1 { - warn!(count = users.len(), name = &user_data.name; "multiple entries found in database for user"); - } - let Some(user) = users.get(0) else { + .map_err(|err| Right(Error::from(err).into()))?; + let Some(user) = user else { + info!(name = user_data.name; "no user found with the given name"); return Ok(Status::Unauthorized); }; + let user = user.check_password(&user_data.password)?; cookies.add_private(Cookie::new( "user", - serde_json::to_string(&user).map_err(Error::from)?, + serde_json::to_string(&user).map_err(|err| Right(Error::from(err).into()))?, )); + cookies.add(Cookie::new("name", user.name)); Ok(Status::Ok) } @@ -60,12 +65,13 @@ pub(super) async fn sign_up( "user", serde_json::to_string(&user_data).map_err(Error::from)?, )); + cookies.add(Cookie::new("name", user_data.name)); Ok(()) } /// Authentication guard #[derive(Deref)] -pub(super) struct Auth(user::Model); +pub(super) struct Auth(users::Model); #[rocket::async_trait] impl<'r> FromRequest<'r> for Auth { @@ -77,7 +83,7 @@ impl<'r> FromRequest<'r> for Auth { }; serde_json::from_str(user.value()) .ok() - .map(|user| Auth(user)) + .map(Auth) .into_outcome(unauthorized) } } diff --git a/server/src/api/tracks.rs b/server/src/api/tracks.rs index e78904f..5060686 100644 --- a/server/src/api/tracks.rs +++ b/server/src/api/tracks.rs @@ -1,5 +1,5 @@ -use std::convert::Infallible; -use std::default::default; + + use crate::api::auth::Auth; use crate::api::{self, error::ApiResult}; @@ -8,7 +8,7 @@ use crate::error::Error; use either::Either::{self, Left, Right}; use rocket::http::Status; use rocket::{serde::json::Json, State}; -use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement, TryIntoModel}; +use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement}; use tokio::sync::broadcast::Sender; use super::update::Update; @@ -53,7 +53,7 @@ pub(super) async fn track( id: i32, auth: Auth, ) -> Result, Either> { - get_track_check_user(db, id, &*auth).await + get_track_check_user(db, id, &auth).await } #[get("//ticks")] @@ -63,7 +63,7 @@ pub(super) async fn ticks_for_track( auth: Auth, ) -> Result>, Either> { let db = db as &DatabaseConnection; - let track = get_track_check_user(db, id, &*auth).await?; + let track = get_track_check_user(db, id, &auth).await?; let result = track.find_related(Ticks).all(db).await; match result { Ok(ticks) => Ok(Json(ticks)), @@ -81,40 +81,48 @@ pub(super) async fn insert_track( fn bad() -> Either { Left(Status::BadRequest) } - let track = track.0.as_object().ok_or_else(bad)?; + fn bad_value_for(key: &'static str) -> impl Fn() -> Either { + move || { + warn!(key = key; "bad value"); + bad() + } + } + let track = track.0.as_object().ok_or_else(|| { + warn!("received value was not an object"); + bad() + })?; let Some(track_id) = db .query_one(Statement::from_sql_and_values( sea_orm::DatabaseBackend::Postgres, - "insert into $1 (user_id, track_id) values ( - $2, ( - insert into $3 ( - name, description, icon, enabled, multiple_entries_per_day, - color, order + r#"with track_insertion as ( + insert into tracks (name, description, icon, enabled, + multiple_entries_per_day, color, "order" ) values ( - $4, $5, $6, $7, $8, $9, $10, + $2, $3, $4, $5, $6, $7, $8 ) returning id ) - ) returning track_id;", + insert into user_tracks ( + user_id, track_id + ) select $1, ti.id + from track_insertion ti + join track_insertion using (id);"#, [ - user_tracks::Entity::default().table_name().into(), auth.id.into(), - tracks::Entity::default().table_name().into(), - track.get("name").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(), + track.get("name").ok_or_else(bad_value_for("name"))?.as_str().ok_or_else(bad_value_for("name"))?.into(), track .get("description") - .ok_or_else(bad)? + .ok_or_else(bad_value_for("description"))? .as_str() - .ok_or_else(bad)? + .ok_or_else(bad_value_for("description"))? .into(), - track.get("icon").ok_or_else(bad)?.as_str().ok_or_else(bad)?.into(), - track.get("enabled").ok_or_else(bad)?.as_i64().into(), + track.get("icon").ok_or_else(bad_value_for("icon"))?.as_str().ok_or_else(bad_value_for("icon"))?.into(), + track.get("enabled").and_then(|it| it.as_i64()).into(), track .get("multiple_entries_per_day") - .ok_or_else(bad)? - .as_i64() + .and_then(|it| it.as_i64()) .into(), - track.get("color").ok_or_else(bad)?.as_i64().into(), - track.get("order").ok_or_else(bad)?.as_i64().into(), + track.get("color").and_then(|it| it.as_i64()).into(), + track.get("order").and_then(|it| it.as_i64()).into(), ], )) .await @@ -186,7 +194,7 @@ pub(super) async fn ticked( .insert(db as &DatabaseConnection) .await .map_err(|err| Right(Error::from(err).into()))? - .to_owned(); + ; tx.send(Update::tick_added(tick.clone())) .map_err(|err| Right(Error::from(err).into()))?; Ok(Json(tick)) @@ -214,7 +222,7 @@ pub(super) async fn ticked_on_date( .insert(db as &DatabaseConnection) .await .map_err(Error::from)? - .to_owned(); + ; tx.send(Update::tick_added(tick.clone())) .map_err(Error::from)?; Ok(Left(Json(tick))) @@ -242,7 +250,7 @@ pub(super) async fn clear_all_ticks( .map_err(Error::from)?; for tick in ticks.clone() { tick.clone().delete(db).await.map_err(Error::from)?; - Update::tick_cancelled(tick).send(&tx)?; + Update::tick_cancelled(tick).send(tx)?; } Ok(Right(Json(ticks))) } @@ -271,7 +279,7 @@ pub(super) async fn clear_all_ticks_on_day( .map_err(Error::from)?; for tick in ticks.clone() { tick.clone().delete(db).await.map_err(Error::from)?; - Update::tick_cancelled(tick).send(&tx)?; + Update::tick_cancelled(tick).send(tx)?; } Ok(Right(Json(ticks))) } diff --git a/server/src/entities/users.rs b/server/src/entities/users.rs index c79c5de..645f862 100644 --- a/server/src/entities/users.rs +++ b/server/src/entities/users.rs @@ -65,8 +65,11 @@ impl ActiveModel { impl Model { pub fn check_password( self, + password: impl AsRef<[u8]>, + ) -> std::result::Result> { match verify(password, &self.password_hash) { Ok(true) => Ok(self), + Ok(false) => Err(Left(Status::Unauthorized)), Err(err) => Err(Right(Error::from(err).into())), } }