diff --git a/Makefile b/Makefile index 274c3d3..f97671d 100644 --- a/Makefile +++ b/Makefile @@ -7,8 +7,10 @@ client/dist/index.html: build-client: client/dist/index.html start-server: build-client + -mkdir db.mount docker compose up --build -d clean: docker compose down - rm -r server/public/ client/dist/ + -rm -r server/public/ client/dist/ + diff --git a/client/src/components/NavBar.vue b/client/src/components/NavBar.vue index e11af91..145a817 100644 --- a/client/src/components/NavBar.vue +++ b/client/src/components/NavBar.vue @@ -1,5 +1,7 @@ @@ -9,20 +11,27 @@ import { RouterLink } from 'vue-router'; - - - - Go Back - - + + + Go Back + + - - - Add Track - - - + + + Add Track + + + + + + Log Out + + + + + - \ No newline at end of file + diff --git a/client/src/components/TickComponent.vue b/client/src/components/TickComponent.vue index 8b1d863..47d57e0 100644 --- a/client/src/components/TickComponent.vue +++ b/client/src/components/TickComponent.vue @@ -15,8 +15,8 @@ const className = computed(() => isSet.value ? "button is-rounded is-info" : "bu async function toggle() { if (isSet.value) { - await state.taskMarkedIncomplete(props.track, props.date) + await props.track.markIncomplete(props.date) } else - await state.taskCompleted(props.track, props.date) + await props.track.markComplete(props.date) } \ No newline at end of file diff --git a/client/src/components/TrackIcon.vue b/client/src/components/TrackIcon.vue index 273987e..e5aeabf 100644 --- a/client/src/components/TrackIcon.vue +++ b/client/src/components/TrackIcon.vue @@ -1,12 +1,12 @@ diff --git a/client/src/router.ts b/client/src/router.ts index 806e60d..8d94d4c 100644 --- a/client/src/router.ts +++ b/client/src/router.ts @@ -6,7 +6,8 @@ const router = createRouter({ history: createWebHistory(), routes: [ { path: '/', component: TableView }, - { path: '/new-track', component: NewTrackView } + { path: '/new-track', component: NewTrackView }, + { path: '/login', component: import('./views/Login.vue') } // for other pages: // {path: '/', component: import('./views/TableView.vue')} ] diff --git a/client/src/state.ts b/client/src/state.ts index d449b60..028ae4b 100644 --- a/client/src/state.ts +++ b/client/src/state.ts @@ -2,6 +2,8 @@ import { reactive } from "vue" import { Track } from "./track" import { Tick } from './ticks' import { error } from "./error" +import { getCookie } from "./util"; +import router from './router' enum State { Unfetched, @@ -9,18 +11,22 @@ enum State { Fetched, } -function dateQuery(date: Date): URLSearchParams { - let query = new URLSearchParams() - query.set("year", date.getUTCFullYear().toString()) - query.set("month", (date.getUTCMonth() + 1).toString()) - // good thing I still had this ^^^^^^^^^^^^^^ in mind when I wrote this 😬 - query.set("day", date.getUTCDate().toString()) - return query +interface LoggedInUser { + name: string } -export const state = reactive({ - tracks: new Array, - state: State.Unfetched, +class AppState { + tracks: Array + state: State + user?: LoggedInUser + source?: EventSource + + constructor() { + this.tracks = new Array + this.state = State.Unfetched + const name = getCookie("name") + if (name) this.user = { name } + } streamUpdatesFromServer() { const source = new EventSource("/api/v1/updates") source.addEventListener("open", () => console.debug("opened event source")) @@ -69,45 +75,30 @@ export const state = reactive({ window.location = window.location }) window.addEventListener('beforeunload', () => source.close()) - }, + this.source = source + } async repopulate() { + if (!this.user) { + this.tracks = [] + return + } this.state = State.Fetching this.tracks = await Track.fetchAll() - }, + this.source?.close() + this.streamUpdatesFromServer() + this.state = State.Fetched + } async populate() { if (this.state != State.Unfetched) return await this.repopulate() - this.streamUpdatesFromServer() - this.state = State.Fetched - }, - async taskCompleted(track: Track, date: Date): Promise { - const query = dateQuery(date) - const response: Response = await fetch(`/api/v1/tracks/${track.id}/ticked?${query.toString()}`, { method: "PATCH" }) - const body = await response.text() - if (!response.ok) { - error(body) - throw new Error(`error setting tick for track ${track.id} ("${track.name}"): ${response.status} ${response.statusText}`) - } - return JSON.parse(body) - }, - async taskMarkedIncomplete(track: Track, date: Date) { - const query = dateQuery(date) - const { ok, status, statusText } = await fetch(`/api/v1/tracks/${track.id}/all-ticks?${query.toString()}`, { method: 'DELETE' }) - if (!ok) - error(`error deleting ticks for ${track.id}: ${statusText} (${status})`) - }, - async addTrack(track: Track): Promise { - const response = await fetch('/api/v1/tracks', { - method: "POST", - body: JSON.stringify(track), - headers: { "Content-Type": "application/json" } - }) - if (!response.ok) - error(`error submitting track: ${track}: ${response.statusText} (${response.status})`) - return response.ok - }, - async removeTrack(trackID: number) { - const response = await fetch(`/api/v1/tracks/${trackID}`, { method: "DELETE" }) - if (!response.ok) error(`error deleting track with ID ${trackID}: ${response.statusText} (${response.status})`) } -}) + async logOut() { + const result = await fetch('/api/v1/auth', {method: 'DELETE'}) + if(!result.ok) return error('failed to log out') + this.user = undefined + router.push('/login') + } +} + + +export const state = reactive(new AppState) diff --git a/client/src/track.ts b/client/src/track.ts index 83947a3..f7525ad 100644 --- a/client/src/track.ts +++ b/client/src/track.ts @@ -1,4 +1,6 @@ import { error } from "./error" +import { Tick, ITick } from './ticks' +import { dateQuery } from "./util" export interface ITrack { id?: number @@ -47,6 +49,34 @@ export class Track implements ITrack { this.fetchTicks = this.fetchTicks.bind(this) } + /** + * Add this track to the database. A `TrackAdded` event should have been + * received from the server on the event stream by the time this returns. + * + * @returns whether or not the query succeeded + */ + async create(): Promise { + // note that this.id is expected to be `undefined` here. + const response = await fetch('/api/v1/tracks', { + method: "POST", + body: JSON.stringify(this), + headers: { "Content-Type": "application/json" } + }) + if (!response.ok) + error(`error submitting track ${this.name}: ${response.statusText} (${response.status})`) + return response.ok + } + + async delete() { + const id = this.id + if (id) await Track.deleteById(id) + } + + static async deleteById(id: number) { + const response = await fetch(`/api/v1/tracks/${id}`, { method: "DELETE" }) + if (!response.ok) error(`error deleting track with ID ${id}: ${response.statusText} (${response.status})`) + } + static fromJSON(track: ITrack): Track { return new Track(track.id, track.name, track.description, track.icon, track.enabled, track.multiple_entries_per_day, track.color, track.order) } @@ -97,4 +127,36 @@ export class Track implements ITrack { } return [] } -} \ No newline at end of file + /** + * Mark this track as being completed on the given date. A `TickAdded` event + * should have been received from the server on the event stream by the time + * this returns. + * + * @param date the date the task was completed + * @returns the decoded server API response + */ + async markComplete(date: Date) { + const query = dateQuery(date) + const response: Response = await fetch(`/api/v1/tracks/${this.id}/ticked?${query.toString()}`, { method: "PATCH" }) + const body = await response.text() + if (!response.ok) { + error(body) + throw new Error(`error setting tick for track ${this.id} ("${this.name}"): ${response.status} ${response.statusText}`) + } + return JSON.parse(body) + } + /** + * Mark this track as being incomplete on the given date. A `TickAdded` event + * should have been received from the server on the event stream by the time + * this returns. + * + * @param date the date the task was completed + * @returns the decoded server API response + */ + async markIncomplete(date: Date) { + const query = dateQuery(date) + const { ok, status, statusText } = await fetch(`/api/v1/tracks/${this.id}/all-ticks?${query.toString()}`, { method: 'DELETE' }) + if (!ok) + error(`error deleting ticks for ${this.id}: ${statusText} (${status})`) + } +} diff --git a/client/src/util.ts b/client/src/util.ts new file mode 100644 index 0000000..023079c --- /dev/null +++ b/client/src/util.ts @@ -0,0 +1,17 @@ +export function getCookie(key: string): string | null { + const start = document.cookie.indexOf(key + '=') + if(start === -1) return null + let end: number | undefined = document.cookie.indexOf(';', start) + if(end === -1) + end = undefined + return document.cookie.substring(start + key.length + 1, end) +} + +export function dateQuery(date: Date): URLSearchParams { + let query = new URLSearchParams() + query.set("year", date.getUTCFullYear().toString()) + query.set("month", (date.getUTCMonth() + 1).toString()) + // good thing I still had this ^^^^^^^^^^^^^^ in mind when I wrote this 😬 + query.set("day", date.getUTCDate().toString()) + return query +} diff --git a/client/src/views/Login.vue b/client/src/views/Login.vue new file mode 100644 index 0000000..84d09b7 --- /dev/null +++ b/client/src/views/Login.vue @@ -0,0 +1,84 @@ + + + + + + + Log in or Sign up + + + + Name + + + + + + + Password + + + + + + + + + + + diff --git a/client/src/views/NewTrackView.vue b/client/src/views/NewTrackView.vue index 8b2ed62..a15e26a 100644 --- a/client/src/views/NewTrackView.vue +++ b/client/src/views/NewTrackView.vue @@ -2,7 +2,6 @@ import { RouterLink, useRouter } from 'vue-router'; import { Track } from '../track'; import { computed, ref } from 'vue'; -import { state } from '../state'; const props = defineProps<{ initialState?: Track }>() const router = useRouter() @@ -26,7 +25,7 @@ const submit = async () => { const track = new Track(undefined, name.value, description.value, icon.value, Number(enabled.value), Number(multipleEntriesPerDay.value), color.value, order.value) - if (await state.addTrack(track)) + if (await track.create()) router.push('/') } diff --git a/client/src/views/TableView.vue b/client/src/views/TableView.vue index 053df35..f5afd39 100644 --- a/client/src/views/TableView.vue +++ b/client/src/views/TableView.vue @@ -1,5 +1,10 @@ diff --git a/docker-compose_dev.yml b/docker-compose_dev.yml index 93d4bf1..ac3c228 100644 --- a/docker-compose_dev.yml +++ b/docker-compose_dev.yml @@ -14,7 +14,7 @@ services: POSTGRES_USER: kalkutago POSTGRES_DB: kalkutago POSTGRES_HOST: database - secrets: [ postgres-password ] + secrets: [ postgres-password, cookie-secret ] depends_on: [ database ] expose: [ 8000 ] # ports: @@ -25,6 +25,7 @@ services: labels: traefik.enable: true traefik.http.routers.kalkutago_server.rule: 'Host(`kalkutago`) && PathPrefix(`/api`)' + database: image: postgres environment: @@ -65,6 +66,8 @@ services: secrets: postgres-password: file: ./server/postgres.pw + cookie-secret: + file: ./server/cookie-secret.pw networks: internal: diff --git a/docker-compose_prod.yml b/docker-compose_prod.yml index cb57656..09483f7 100644 --- a/docker-compose_prod.yml +++ b/docker-compose_prod.yml @@ -14,7 +14,7 @@ services: POSTGRES_USER: kalkutago POSTGRES_DB: kalkutago POSTGRES_HOST: database - secrets: [ postgres-password ] + secrets: [ postgres-password, cookie-secret ] depends_on: [ database ] volumes: - ./client/dist:/src/public:ro @@ -32,6 +32,8 @@ services: secrets: postgres-password: file: ./server/postgres.pw + cookie-secret: + file: ./server/cookie-secret.pw networks: internal: diff --git a/docker-compose_test.yml b/docker-compose_test.yml new file mode 100644 index 0000000..2fce6d8 --- /dev/null +++ b/docker-compose_test.yml @@ -0,0 +1,45 @@ +version: "3.5" + +services: + server: + build: + context: ./server + dockerfile: Dockerfile.test + networks: + - web + - internal + environment: + POSTGRES_PASSWORD_FILE: /run/secrets/postgres-password + POSTGRES_USER: kalkutago + POSTGRES_DB: kalkutago_TEST + POSTGRES_HOST: database + secrets: [ postgres-password, cookie-secret ] + depends_on: [ database ] + expose: [ 8000 ] + volumes: + - ./client/dist:/src/public:ro + labels: + traefik.enable: false + + database: + image: postgres + environment: + POSTGRES_PASSWORD_FILE: /run/secrets/postgres-password + POSTGRES_USER: kalkutago + POSTGRES_DB: kalkutago_TEST + secrets: [ postgres-password ] + networks: [ internal ] + labels: + traefik.enable: false + +secrets: + postgres-password: + file: ${PG_PW_FILE} + cookie-secret: + file: ${COOKIE_SECRET_FILE} + +networks: + internal: + internal: true + web: + external: true diff --git a/server/Cargo.lock b/server/Cargo.lock index 59054fd..0622b98 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -8,6 +8,41 @@ version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "209b47e8954a928e1d72e86eca7000ebb6655fe1436d33eefc2201cad027e237" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.7.6" @@ -142,6 +177,19 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +[[package]] +name = "bcrypt" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9df288bec72232f78c1ec5fe4e8f1d108aa0265476e93097593c803c8c02062a" +dependencies = [ + "base64 0.21.2", + "blowfish", + "getrandom", + "subtle", + "zeroize", +] + [[package]] name = "bigdecimal" version = "0.3.1" @@ -192,6 +240,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blowfish" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e412e2cd0f2b2d93e02543ceae7917b3c70331573df19ee046bcbc35e45e87d7" +dependencies = [ + "byteorder", + "cipher", +] + [[package]] name = "borsh" version = "0.10.3" @@ -305,6 +363,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "3.2.25" @@ -347,7 +415,13 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7efb37c3e1ccb1ff97164ad95ac1606e8ccd35b3fa0a7d99a304c7f4a428cc24" dependencies = [ + "aes-gcm", + "base64 0.21.2", + "hkdf", "percent-encoding", + "rand", + "sha2", + "subtle", "time 0.3.22", "version_check", ] @@ -393,9 +467,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.14.4" @@ -462,6 +546,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_deref" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcdbcee2d9941369faba772587a565f4f534e42cb8d17e5295871de730163b2b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "devise" version = "0.4.1" @@ -756,6 +851,16 @@ dependencies = [ "wasi 0.11.0+wasi-snapshot-preview1", ] +[[package]] +name = "ghash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d930750de5717d2dd0b8c0d42c076c0e884c81a73e6cab859bbd2339c71e3e40" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "glob" version = "0.3.1" @@ -989,6 +1094,15 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.12" @@ -1049,8 +1163,10 @@ dependencies = [ name = "kalkutago-server" version = "0.1.0" dependencies = [ + "bcrypt", "chrono", "derive_builder", + "derive_deref", "either", "femme", "log", @@ -1061,6 +1177,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-test", ] [[package]] @@ -1249,6 +1366,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "os_str_bytes" version = "6.5.1" @@ -1379,6 +1502,18 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "polyval" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef234e08c11dfcb2e56f79fd70f6f2eb7f025c0ce2333e82f4f0518ecad30c6" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -2503,6 +2638,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -2699,6 +2847,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.7.1" @@ -3010,3 +3168,9 @@ name = "yansi" version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" diff --git a/server/Cargo.toml b/server/Cargo.toml index eddfb63..d2d627a 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -13,12 +13,15 @@ path = "src/main.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bcrypt = "0.14.0" chrono = "0.4.26" +derive_deref = "1.1.1" femme = "2.2.1" log = { version = "0.4.19", features = ["kv_unstable", "kv_unstable_serde"] } sea-orm-migration = "0.11.3" serde_json = "1.0.96" thiserror = "1.0.40" +tokio-test = "0.4.2" [dependencies.derive_builder] version = "0.12.0" @@ -40,7 +43,7 @@ features = [ [dependencies.rocket] git = "https://github.com/SergioBenitez/Rocket" rev = "v0.5.0-rc.3" -features = ["json"] +features = ["json", "secrets"] [dependencies.serde] version = "1.0.163" diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs new file mode 100644 index 0000000..415a122 --- /dev/null +++ b/server/src/api/auth.rs @@ -0,0 +1,116 @@ +use derive_deref::Deref; +use either::Either::{self, Right}; + +use log::{as_debug, as_serde, debug}; +use rocket::{ + http::{Cookie, CookieJar, Status}, + outcome::IntoOutcome, + request::{self, FromRequest}, + serde::json::Json, + Request, State, +}; +use sea_orm::{prelude::*, DatabaseConnection}; +use serde::{Deserialize, Serialize}; + +use crate::{ + api::error::ApiResult, + entities::{prelude::*, *}, + error::Error, +}; + +use super::ErrorResponder; + +#[derive(Clone, Deserialize, Serialize)] +pub struct LoginData { + pub name: String, + pub password: String, +} + +#[put("/", data = "", format = "application/json")] +pub(super) async fn login( + db: &State, + user_data: Json, + cookies: &CookieJar<'_>, +) -> Result> { + let user = Users::find() + .filter(users::Column::Name.eq(&user_data.name)) + .one(db as &DatabaseConnection) + .await + .map_err(|err| Right(Error::from(err).into()))?; + let Some(user) = user else { + info!(name = user_data.name; "no user found with the given name"); + return Ok(Status::Unauthorized); + }; + let user = user.check_password(&user_data.password)?; + cookies.add_private(Cookie::new( + "user", + serde_json::to_string(&user).map_err(|err| Right(Error::from(err).into()))?, + )); + cookies.add(Cookie::new("name", user.name)); + Ok(Status::Ok) +} + +#[post("/", data = "", format = "application/json")] +pub(super) async fn sign_up( + db: &State, + user_data: Json, + cookies: &CookieJar<'_>, +) -> ApiResult<()> { + let user_data = users::ActiveModel::new(&user_data.name, &user_data.password)? + .insert(db as &DatabaseConnection) + .await + .map_err(Error::from)?; + debug!(user = as_serde!(user_data); "user added"); + cookies.add_private(Cookie::new( + "user", + serde_json::to_string(&user_data).map_err(Error::from)?, + )); + cookies.add(Cookie::new("name", user_data.name)); + Ok(()) +} + +#[delete("/")] +pub(super) async fn sign_out(cookies: &CookieJar<'_>) { + cookies.remove_private(Cookie::named("user")); + cookies.remove(Cookie::named("name")); +} + +/// Authentication guard +#[derive(Deref)] +pub(super) struct Auth(users::Model); + +#[derive(Deserialize)] +struct AuthData { + id: i32, + name: String, + password_hash: String, +} + +impl From for Auth { + fn from(value: AuthData) -> Self { + Auth(users::Model { + id: value.id, + name: value.name, + password_hash: value.password_hash, + }) + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Auth { + type Error = (); + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + let unauthorized = (Status::Unauthorized, ()); + let Some(user) = request.cookies().get_private("user") else { + return request::Outcome::Failure(unauthorized); + }; + let user = user.value(); + debug!(user = user; "user retreived from private cookie"); + let result = serde_json::from_str(user) + .ok() + .map(|model: AuthData| model.into()) + .into_outcome(unauthorized); + debug!(result = as_debug!(result); "auth FromRequest return value"); + result + } +} diff --git a/server/src/api/error.rs b/server/src/api/error.rs index 602ca1d..5689ed4 100644 --- a/server/src/api/error.rs +++ b/server/src/api/error.rs @@ -2,11 +2,11 @@ use crate::error::Error; #[derive(Responder)] #[response(status = 500, content_type = "json")] -pub(crate) struct ErrorResponder { +pub struct ErrorResponder { message: String, } -pub(crate) type ApiResult = Result; +pub type ApiResult = Result; // The following impl's are for easy conversion of error types. diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index bec41f1..eaae998 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -1,3 +1,4 @@ +mod auth; mod error; mod groups; #[cfg(feature = "unsafe_import")] @@ -7,12 +8,14 @@ mod tracks; pub(crate) mod update; use std::{ - default::default, + default::Default, + env, fs, net::{IpAddr, Ipv4Addr}, }; use crate::error::Error; use rocket::{ + config::SecretKey, fs::{FileServer, NamedFile}, response::stream::EventStream, routes, Build, Config, Rocket, State, @@ -25,6 +28,8 @@ use tokio::sync::broadcast::{self, error::RecvError, Sender}; use self::{error::ApiResult, update::Update}; use log::{as_debug, as_serde, debug, trace}; +pub use auth::LoginData; + #[get("/status")] fn status() -> &'static str { "Ok" @@ -61,7 +66,17 @@ async fn spa_index_redirect() -> ApiResult { .map_err(Error::from)?) } -pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { +fn get_secret() -> [u8; 32] { + let path = + env::var("COOKIE_SECRET_FILE").unwrap_or_else(|_| "/run/secrets/cookie-secret".into()); + let file_contents = + fs::read(&path).unwrap_or_else(|err| panic!("failed to read from {path:?}: {err:?}")); + let mut data = [0u8; 32]; + data.copy_from_slice(&file_contents); + data +} + +pub fn start_server(db: DatabaseConnection) -> Rocket { use groups::*; use ticks::*; use tracks::*; @@ -69,7 +84,8 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { let it = rocket::build() .configure(Config { address: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), - ..default() + secret_key: SecretKey::derive_from(&get_secret()), + ..Config::default() }) .register("/", catchers![spa_index_redirect]) .manage(db) @@ -98,6 +114,7 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { "/api/v1/groups", routes![all_groups, group, insert_group, update_group, delete_group], ) + .mount("/api/v1/auth", routes![auth::login, auth::sign_up, auth::sign_out]) .mount("/", FileServer::from("/src/public")); #[cfg(feature = "unsafe_import")] diff --git a/server/src/api/tracks.rs b/server/src/api/tracks.rs index edfa616..5d01ea5 100644 --- a/server/src/api/tracks.rs +++ b/server/src/api/tracks.rs @@ -1,51 +1,71 @@ +use crate::api::auth::Auth; use crate::api::{self, error::ApiResult}; use crate::entities::{prelude::*, *}; use crate::error::Error; use either::Either::{self, Left, Right}; +use log::as_debug; +use log::{as_serde, debug, warn}; use rocket::http::Status; use rocket::{serde::json::Json, State}; -use sea_orm::{prelude::*, DatabaseConnection}; +use sea_orm::{prelude::*, DatabaseConnection, IntoActiveModel, Statement}; use tokio::sync::broadcast::Sender; use super::update::Update; +use super::ErrorResponder; #[get("/")] pub(super) async fn all_tracks( db: &State, + authorized_user: Auth, ) -> ApiResult>> { let db = db as &DatabaseConnection; - let tracks = Tracks::find().all(db).await.unwrap(); + let tracks = authorized_user + .find_related(Tracks) + .all(db) + .await + .map_err(Error::from)?; Ok(Json(tracks)) } +async fn get_track_check_user( + db: &DatabaseConnection, + track_id: i32, + user: &users::Model, +) -> Result, Either> { + if let Some(Some(track)) = user + .find_related(Tracks) + .filter(tracks::Column::Id.eq(track_id)) + .one(db) + .await + .transpose() + .map(|it| it.ok()) + { + Ok(Json(track)) + } else { + Err(Left(Status::NotFound)) + } +} + #[get("/")] pub(super) async fn track( db: &State, id: i32, + auth: Auth, ) -> Result, Either> { - let db = db as &DatabaseConnection; - match Tracks::find_by_id(id).one(db).await { - Ok(Some(track)) => Ok(Json(track)), - Ok(None) => Err(Left(Status::NotFound)), - Err(err) => Err(Right(Error::from(err).into())), - } + get_track_check_user(db, id, &auth).await } #[get("//ticks")] pub(super) async fn ticks_for_track( db: &State, id: i32, + auth: Auth, ) -> Result>, Either> { let db = db as &DatabaseConnection; - match Tracks::find_by_id(id).one(db).await { - Ok(Some(track)) => { - let result = track.find_related(Ticks).all(db).await; - match result { - Ok(ticks) => Ok(Json(ticks)), - Err(err) => Err(Right(Error::from(err).into())), - } - } - Ok(None) => Err(Left(Status::NotFound)), + let track = get_track_check_user(db, id, &auth).await?; + let result = track.find_related(Ticks).all(db).await; + match result { + Ok(ticks) => Ok(Json(ticks)), Err(err) => Err(Right(Error::from(err).into())), } } @@ -55,13 +75,87 @@ pub(super) async fn insert_track( db: &State, tx: &State>, track: Json, -) -> ApiResult> { - let track = track.0; - let db = db as &DatabaseConnection; - let model = tracks::ActiveModel::from_json(track).map_err(Error::from)?; - let track = model.insert(db).await.map_err(Error::from)?; - tx.send(Update::track_added(track.clone())) - .map_err(Error::from)?; + auth: Auth, +) -> Result, Either> { + debug!( + user=as_serde!(*auth), + track=as_serde!(track.0); + "authenticated user making track insertion request" + ); + fn bad() -> Either { + Left(Status::BadRequest) + } + fn bad_value_for(key: &'static str) -> impl Fn() -> Either { + move || { + warn!(key = key; "bad value"); + bad() + } + } + let track = track.0.as_object().ok_or_else(|| { + warn!("received value was not an object"); + bad() + })?; + let Some(track_id) = db + .query_one(Statement::from_sql_and_values( + sea_orm::DatabaseBackend::Postgres, + r#"with track_insertion as ( + insert into tracks (name, description, icon, enabled, + multiple_entries_per_day, color, "order" + ) values ( + $2, $3, $4, $5, $6, $7, $8 + ) returning id + ) + insert into user_tracks ( + user_id, track_id + ) select $1, ti.id + from track_insertion ti + join track_insertion using (id) + returning id;"#, + [ + auth.id.into(), + track + .get("name") + .ok_or_else(bad_value_for("name"))? + .as_str() + .ok_or_else(bad_value_for("name"))? + .into(), + track + .get("description") + .ok_or_else(bad_value_for("description"))? + .as_str() + .ok_or_else(bad_value_for("description"))? + .into(), + track + .get("icon") + .ok_or_else(bad_value_for("icon"))? + .as_str() + .ok_or_else(bad_value_for("icon"))? + .into(), + track.get("enabled").and_then(|it| it.as_i64()).into(), + track + .get("multiple_entries_per_day") + .and_then(|it| it.as_i64()) + .into(), + track.get("color").and_then(|it| it.as_i64()).into(), + track.get("order").and_then(|it| it.as_i64()).into(), + ], + )) + .await + .map_err(|err| Right(Error::from(err).into()))? + else { + return Err(Right("no value returned from track insertion query".into())); + }; + trace!("query completed"); + let track_id = track_id + .try_get_by_index(0) + .map_err(|err| Right(Error::from(err).into()))?; + trace!(track_id = track_id; "freshly inserted track ID"); + let track = auth.authorized_track(track_id, db).await.ok_or_else(|| { + Right(format!("failed to fetch freshly inserted track with id {track_id}").into()) + })?; + if let Err(err) = tx.send(Update::track_added(track.clone())) { + warn!(err = as_debug!(err); "error sending updates to subscribed channels"); + } Ok(Json(track)) } @@ -69,16 +163,22 @@ pub(super) async fn insert_track( pub(super) async fn update_track( db: &State, tx: &State>, - track: Json, -) -> ApiResult> { + track: Json, + authorized_user: Auth, +) -> Result, Either> { let db = db as &DatabaseConnection; - let track = tracks::ActiveModel::from_json(track.0) - .map_err(Error::from)? + let track = track.0; + if !authorized_user.is_authorized_for(track.id, db).await { + return Err(Left(Status::Forbidden)); + } + let track = track + .into_active_model() .update(db) .await - .map_err(Error::from)?; - tx.send(Update::track_changed(track.clone())) - .map_err(Error::from)?; + .map_err(|err| Right(Error::from(err).into()))?; + if let Err(err) = tx.send(Update::track_changed(track.clone())) { + warn!(err = as_debug!(err); "error sending updates to subscribed channels"); + } Ok(Json(track)) } @@ -87,11 +187,13 @@ pub(super) async fn delete_track( db: &State, tx: &State>, id: i32, + authorized_user: Auth, ) -> ApiResult { let db = db as &DatabaseConnection; - let Some(track) = Tracks::find_by_id(id).one(db).await.map_err(Error::from)? else { + let Some(track) = authorized_user.authorized_track(id, db).await else { return Ok(Status::NotFound); }; + track.clone().delete(db).await.map_err(Error::from)?; tx.send(Update::track_removed(track)).map_err(Error::from)?; Ok(Status::Ok) } @@ -101,15 +203,20 @@ pub(super) async fn ticked( db: &State, tx: &State>, id: i32, -) -> ApiResult> { + authorized_user: Auth, +) -> Result, Either> { + if !authorized_user.is_authorized_for(id, db).await { + return Err(Left(Status::Forbidden)); + } + let tick = ticks::ActiveModel::now(id); let tick = tick .insert(db as &DatabaseConnection) .await - .map_err(Error::from)? - .to_owned(); - tx.send(Update::tick_added(tick.clone())) - .map_err(Error::from)?; + .map_err(|err| Right(Error::from(err).into()))?; + if let Err(err) = tx.send(Update::tick_added(tick.clone())) { + warn!(err = as_debug!(err); "error sending updates to subscribed channels"); + } Ok(Json(tick)) } @@ -121,7 +228,12 @@ pub(super) async fn ticked_on_date( year: i32, month: u32, day: u32, + authorized_user: Auth, ) -> ApiResult, Status>> { + if !authorized_user.is_authorized_for(id, db).await { + return Ok(Right(Status::Forbidden)); + } + let Some(date) = Date::from_ymd_opt(year, month, day) else { return Ok(Right(Status::BadRequest)); }; @@ -129,10 +241,10 @@ pub(super) async fn ticked_on_date( let tick = tick .insert(db as &DatabaseConnection) .await - .map_err(Error::from)? - .to_owned(); - tx.send(Update::tick_added(tick.clone())) .map_err(Error::from)?; + if let Err(err) = tx.send(Update::tick_added(tick.clone())) { + warn!(err = as_debug!(err); "error sending updates to subscribed channels"); + } Ok(Left(Json(tick))) } @@ -141,10 +253,14 @@ pub(super) async fn clear_all_ticks( db: &State, tx: &State>, id: i32, + authorized_user: Auth, ) -> ApiResult>>> { let db = db as &DatabaseConnection; - let Some(track) = Tracks::find_by_id(id).one(db).await.map_err(Error::from)? else { - info!(track_id = id; "couldn't drop all ticks for track; track not found"); + let Some(track) = authorized_user.authorized_track(id, db).await else { + info!( + track_id = id, user_id = authorized_user.id; + "couldn't drop all ticks for track; track not found or user not authorized" + ); return Ok(Left(Status::NotFound)); }; let ticks = track @@ -154,7 +270,9 @@ pub(super) async fn clear_all_ticks( .map_err(Error::from)?; for tick in ticks.clone() { tick.clone().delete(db).await.map_err(Error::from)?; - Update::tick_cancelled(tick).send(&tx)?; + if let Err(err) = Update::tick_cancelled(tick).send(tx) { + warn!(err = as_debug!(err); "error sending updates to subscribed channels"); + } } Ok(Right(Json(ticks))) } @@ -167,8 +285,12 @@ pub(super) async fn clear_all_ticks_on_day( year: i32, month: u32, day: u32, -) -> ApiResult>> { + authorized_user: Auth, +) -> ApiResult>>> { let db = db as &DatabaseConnection; + if !authorized_user.is_authorized_for(id, db).await { + return Ok(Left(Status::Forbidden)); + } let ticks = Ticks::find() .filter(ticks::Column::TrackId.eq(id)) .filter(ticks::Column::Year.eq(year)) @@ -179,7 +301,9 @@ pub(super) async fn clear_all_ticks_on_day( .map_err(Error::from)?; for tick in ticks.clone() { tick.clone().delete(db).await.map_err(Error::from)?; - Update::tick_cancelled(tick).send(&tx)?; + if let Err(err) = Update::tick_cancelled(tick).send(tx) { + warn!(err = as_debug!(err); "error sending updates to subscribed channels"); + } } - Ok(Json(ticks)) + Ok(Right(Json(ticks))) } diff --git a/server/src/api/update.rs b/server/src/api/update.rs index 1135c18..7cacb7b 100644 --- a/server/src/api/update.rs +++ b/server/src/api/update.rs @@ -91,8 +91,14 @@ impl Update { } pub fn send(self, tx: &Sender) -> Result<()> { - let count = tx.send(self.clone())?; - trace!(sent_to = count, update = as_serde!(self); "sent update to SSE channel"); + let receiver_count = tx.receiver_count(); + if receiver_count > 0 { + trace!(receiver_count = receiver_count, update = as_serde!(self); "sending update"); + let count = tx.send(self.clone())?; + trace!(count = count; "update sent"); + } else { + trace!("no update receivers, skipping message"); + } Ok(()) } } diff --git a/server/src/db/mod.rs b/server/src/db/mod.rs index b141c4f..fda857d 100644 --- a/server/src/db/mod.rs +++ b/server/src/db/mod.rs @@ -1,11 +1,15 @@ +use crate::migrator::Migrator; +use sea_orm_migration::MigratorTrait; +use sea_orm_migration::SchemaManager; use std::{ - default::default, env, ffi::{OsStr, OsString}, fs::File, io::Read, }; +use sea_orm::{Database, DatabaseConnection}; + // from https://doc.rust-lang.org/std/ffi/struct.OsString.html fn concat_os_strings(a: &OsStr, b: &OsStr) -> OsString { let mut ret = OsString::with_capacity(a.len() + b.len()); // This will allocate @@ -30,7 +34,7 @@ fn get_env_var_or_file>(key: A) -> Option { if let Some(path) = env::var_os(file_key) { // open the file and read it let mut file = File::open(&path).unwrap_or_else(|_| panic!("no such file at {path:?}")); - let mut val: String = default(); + let mut val = String::new(); file.read_to_string(&mut val) .unwrap_or_else(|_| panic!("reading file at {path:?}")); Some(val) @@ -58,3 +62,31 @@ pub fn connection_url() -> String { .unwrap_or(5432_u16); format!("postgres://{user}:{password}@{host}:{port}/{db}") } + +pub async fn connection() -> DatabaseConnection { + Database::connect(connection_url()) + .await + .expect("db connection") +} +pub async fn migrated() -> DatabaseConnection { + let db = connection().await; + let schema_manager = SchemaManager::new(&db); + Migrator::refresh(&db).await.expect("migration"); + assert!(schema_manager + .has_table("tracks") + .await + .expect("fetch tracks table")); + assert!(schema_manager + .has_table("ticks") + .await + .expect("fetch ticks table")); + assert!(schema_manager + .has_table("groups") + .await + .expect("fetch groups table")); + assert!(schema_manager + .has_table("track2_groups") + .await + .expect("fetch track2groups table")); + db +} diff --git a/server/src/entities/mod.rs b/server/src/entities/mod.rs index 7a309af..c9107ae 100644 --- a/server/src/entities/mod.rs +++ b/server/src/entities/mod.rs @@ -6,3 +6,5 @@ pub mod groups; pub mod ticks; pub mod track2_groups; pub mod tracks; +pub mod user_tracks; +pub mod users; diff --git a/server/src/entities/prelude.rs b/server/src/entities/prelude.rs index 796df22..5b17141 100644 --- a/server/src/entities/prelude.rs +++ b/server/src/entities/prelude.rs @@ -4,3 +4,5 @@ pub use super::groups::Entity as Groups; pub use super::ticks::Entity as Ticks; pub use super::track2_groups::Entity as Track2Groups; pub use super::tracks::Entity as Tracks; +pub use super::user_tracks::Entity as UserTracks; +pub use super::users::Entity as Users; diff --git a/server/src/entities/ticks.rs b/server/src/entities/ticks.rs index 41d4107..7734cf8 100644 --- a/server/src/entities/ticks.rs +++ b/server/src/entities/ticks.rs @@ -1,6 +1,6 @@ //! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.3 -use std::default::default; +use std::default::Default; use chrono::{Datelike, Timelike, Utc}; use sea_orm::entity::prelude::*; @@ -60,7 +60,7 @@ impl ActiveModel { minute: Set(now.minute().try_into().ok()), second: Set(now.second().try_into().ok()), has_time_info: Set(Some(1)), - ..default() + ..Default::default() } } pub(crate) fn on(date: Date, track_id: i32) -> Self { @@ -80,7 +80,7 @@ impl ActiveModel { minute: Set(now.minute().try_into().ok()), second: Set(now.second().try_into().ok()), has_time_info: Set(Some(1)), - ..default() + ..Default::default() } } } diff --git a/server/src/entities/tracks.rs b/server/src/entities/tracks.rs index de4cef4..b3dfde0 100644 --- a/server/src/entities/tracks.rs +++ b/server/src/entities/tracks.rs @@ -24,6 +24,8 @@ pub enum Relation { Ticks, #[sea_orm(has_many = "super::track2_groups::Entity")] Track2Groups, + #[sea_orm(has_many = "super::user_tracks::Entity")] + UserTracks, } impl Related for Entity { @@ -38,4 +40,19 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::UserTracks.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + super::user_tracks::Relation::Users.def() + } + fn via() -> Option { + Some(super::user_tracks::Relation::Tracks.def().rev()) + } +} + impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/entities/user_tracks.rs b/server/src/entities/user_tracks.rs new file mode 100644 index 0000000..a244d53 --- /dev/null +++ b/server/src/entities/user_tracks.rs @@ -0,0 +1,46 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.3 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "user_tracks")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub user_id: i32, + pub track_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::tracks::Entity", + from = "Column::TrackId", + to = "super::tracks::Column::Id", + on_update = "NoAction", + on_delete = "NoAction" + )] + Tracks, + #[sea_orm( + belongs_to = "super::users::Entity", + from = "Column::UserId", + to = "super::users::Column::Id", + on_update = "NoAction", + on_delete = "NoAction" + )] + Users, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Tracks.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Users.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/server/src/entities/users.rs b/server/src/entities/users.rs new file mode 100644 index 0000000..57e0d6a --- /dev/null +++ b/server/src/entities/users.rs @@ -0,0 +1,99 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.11.3 + +use std::default::Default; + +use bcrypt::*; +// TODO Add option for argon2 https://docs.rs/argon2/latest/argon2/ +use either::Either::{self, Left, Right}; +use rocket::http::Status; +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +use crate::{ + api::ErrorResponder, + error::{self, Error}, +}; + +use super::tracks; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key)] + #[serde(skip_deserializing)] + pub id: i32, + pub name: String, + pub password_hash: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::user_tracks::Entity")] + UserTracks, +} +impl Related for Entity { + fn to() -> RelationDef { + Relation::UserTracks.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + super::user_tracks::Relation::Tracks.def() + } + + fn via() -> Option { + Some(super::user_tracks::Relation::Users.def().rev()) + } +} + +impl ActiveModelBehavior for ActiveModel {} + +impl ActiveModel { + pub fn new(name: impl AsRef, password: impl AsRef) -> error::Result { + use sea_orm::ActiveValue::Set; + let name = Set(name.as_ref().to_string()); + let password_hash = Set(hash(password.as_ref(), DEFAULT_COST)?); + Ok(Self { + name, + password_hash, + ..Default::default() + }) + } +} + +impl Model { + pub fn check_password( + self, + password: impl AsRef<[u8]>, + ) -> std::result::Result> { + match verify(password, &self.password_hash) { + Ok(true) => Ok(self), + Ok(false) => Err(Left(Status::Unauthorized)), + Err(err) => Err(Right(Error::from(err).into())), + } + } + + pub async fn authorized_track( + &self, + track_id: i32, + db: &DatabaseConnection, + ) -> Option { + self.find_related(super::prelude::Tracks) + .filter(tracks::Column::Id.eq(track_id)) + .one(db) + .await + .ok() + .flatten() + } + pub async fn is_authorized_for(&self, track_id: i32, db: &DatabaseConnection) -> bool { + self.authorized_track(track_id, db).await.is_some() + } + + pub async fn authorized_tracks(&self, db: &DatabaseConnection) -> Vec { + self.find_related(super::prelude::Tracks) + .all(db) + .await + .unwrap_or_default() + } +} diff --git a/server/src/error.rs b/server/src/error.rs index fbf5c10..a908650 100644 --- a/server/src/error.rs +++ b/server/src/error.rs @@ -1,5 +1,6 @@ use std::string; +use bcrypt::BcryptError; use derive_builder::UninitializedFieldError; #[derive(Debug, thiserror::Error)] @@ -18,6 +19,10 @@ pub enum Error { Utf8(#[from] string::FromUtf8Error), #[error(transparent)] ChannelSendError(#[from] tokio::sync::broadcast::error::SendError), + #[error(transparent)] + Bcrypt(#[from] BcryptError), + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), } pub type Result = std::result::Result; diff --git a/server/src/lib.rs b/server/src/lib.rs new file mode 100644 index 0000000..f18ba5a --- /dev/null +++ b/server/src/lib.rs @@ -0,0 +1,8 @@ +#![feature(proc_macro_hygiene, decl_macro, never_type)] +#[macro_use] +extern crate rocket; +pub mod api; +pub mod db; +pub mod entities; +pub mod error; +mod migrator; diff --git a/server/src/main.rs b/server/src/main.rs index 8959895..eaf236f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,4 @@ -#![feature(default_free_fn, proc_macro_hygiene, decl_macro)] +#![feature(proc_macro_hygiene, decl_macro, never_type)] #[macro_use] extern crate rocket; mod api; @@ -6,32 +6,9 @@ mod db; mod entities; mod error; mod migrator; -use crate::migrator::Migrator; -use sea_orm::Database; -use sea_orm_migration::prelude::*; #[launch] async fn rocket_defines_the_main_fn() -> _ { - femme::with_level(femme::LevelFilter::Debug); - let url = db::connection_url(); - let db = Database::connect(url).await.expect("db connection"); - let schema_manager = SchemaManager::new(&db); - Migrator::refresh(&db).await.expect("migration"); - assert!(schema_manager - .has_table("tracks") - .await - .expect("fetch tracks table")); - assert!(schema_manager - .has_table("ticks") - .await - .expect("fetch ticks table")); - assert!(schema_manager - .has_table("groups") - .await - .expect("fetch groups table")); - assert!(schema_manager - .has_table("track2_groups") - .await - .expect("fetch track2groups table")); - api::start_server(db) + femme::with_level(femme::LevelFilter::Trace); + api::start_server(db::migrated().await) } diff --git a/server/src/migrator/m20230626_083036_create_users_table.rs b/server/src/migrator/m20230626_083036_create_users_table.rs new file mode 100644 index 0000000..7e77a64 --- /dev/null +++ b/server/src/migrator/m20230626_083036_create_users_table.rs @@ -0,0 +1,42 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(Users::Table) + .if_not_exists() + .col( + ColumnDef::new(Users::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col(ColumnDef::new(Users::Name).string().unique_key().not_null()) + .col(ColumnDef::new(Users::PasswordHash).string().not_null()) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(Users::Table).to_owned()) + .await + } +} + +/// Learn more at https://docs.rs/sea-query#iden +#[derive(Iden)] +pub(crate) enum Users { + Table, + Id, + Name, + PasswordHash, +} diff --git a/server/src/migrator/m20230626_150551_associate_users_and_tracks.rs b/server/src/migrator/m20230626_150551_associate_users_and_tracks.rs new file mode 100644 index 0000000..b2bfcfd --- /dev/null +++ b/server/src/migrator/m20230626_150551_associate_users_and_tracks.rs @@ -0,0 +1,57 @@ +use super::{ + m20230606_000001_create_tracks_table::Tracks, m20230626_083036_create_users_table::Users, +}; +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(UserTracks::Table) + .if_not_exists() + .col( + ColumnDef::new(UserTracks::Id) + .integer() + .not_null() + .primary_key() + .auto_increment(), + ) + .col(ColumnDef::new(UserTracks::UserId).integer().not_null()) + .col(ColumnDef::new(UserTracks::TrackId).integer().not_null()) + .foreign_key( + ForeignKey::create() + .name("fk-user_tracks-user_id") + .from(UserTracks::Table, UserTracks::UserId) + .to(Users::Table, Users::Id), + ) + .foreign_key( + ForeignKey::create() + .name("fk-user_tracks-track_id") + .from(UserTracks::Table, UserTracks::TrackId) + .to(Tracks::Table, Tracks::Id), + ) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(UserTracks::Table).to_owned()) + .await + } +} + +/// Learn more at https://docs.rs/sea-query#iden +#[derive(Iden)] +enum UserTracks { + Table, + Id, + UserId, + TrackId, +} diff --git a/server/src/migrator/mod.rs b/server/src/migrator/mod.rs index 473e061..edf6a1f 100644 --- a/server/src/migrator/mod.rs +++ b/server/src/migrator/mod.rs @@ -2,6 +2,8 @@ mod m20230606_000001_create_tracks_table; mod m20230606_000002_create_ticks_table; mod m20230606_000003_create_groups_table; mod m20230606_000004_create_track2groups_table; +mod m20230626_083036_create_users_table; +mod m20230626_150551_associate_users_and_tracks; use sea_orm_migration::prelude::*; @@ -15,6 +17,8 @@ impl MigratorTrait for Migrator { Box::new(m20230606_000002_create_ticks_table::Migration), Box::new(m20230606_000003_create_groups_table::Migration), Box::new(m20230606_000004_create_track2groups_table::Migration), + Box::new(m20230626_083036_create_users_table::Migration), + Box::new(m20230626_150551_associate_users_and_tracks::Migration), ] } } diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..f7556e2 --- /dev/null +++ b/shell.nix @@ -0,0 +1,17 @@ +# DEVELOPMENT shell environment +{ pkgs ? import {} }: + +pkgs.mkShell { + name = "kalkutago"; + nativeBuildInputs = with pkgs.buildPackages; [ + clang + yarn nodejs + openssl + python3 + python3Packages.requests + python3Packages.ipython + rustup + docker + gnumake + ]; +} diff --git a/test.py b/test.py new file mode 100644 index 0000000..fe9db4f --- /dev/null +++ b/test.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# +# Quick script to test endpoints of kalkutago + +from requests import get, post, put, patch +from time import gmtime as utc + +credentials = {"name": "testuser", "password": "testpass"} +track = {"name": "test", "description": "test track", "icon": "❓", "enabled": 1} + +def test_auth(method): + res = method(f'http://kalkutago/api/v1/auth', json=credentials) + assert 'user' in res.cookies.iterkeys(), \ + f'no user cookie found. Cookies: {res.cookies.get_dict()}; body: ' + \ + res.text + return res.cookies['user'] + +def test_create_user(): + return test_auth(post) + +def test_login(): + return test_auth(put) + +def test_track_creation(auth_cookie): + res = post('http://kalkutago/api/v1/tracks', json=track, + cookies={'user': auth_cookie}) + print(res.text) + res.raise_for_status() + return res.json() + +def test_get_track(auth_cookie, track): + res = get(f'http://kalkutago/api/v1/tracks/{track["id"]}', + cookies={'user': auth_cookie}) + print(res.text) + res.raise_for_status() + retrieved = res.json() + assert track == retrieved, f'expected {track!r} to equal {retrieved!r}' + return retrieved + +def test_tick(auth_cookie, track): + res = patch(f'http://kalkutago/api/v1/tracks/{track["id"]}/ticked', + cookies={'user': auth_cookie}) + print(res.text) + res.raise_for_status() + retrieved = res.json() + # result: + # {"id":1,"track_id":6,"year":2023,"month":8,"day":10,"hour":13,"minute":7,"second":41,"has_time_info":1} + now = utc() + assert retrieved['track_id'] == track['id'] + assert retrieved['year'] == now.tm_year + assert retrieved['month'] == now.tm_mon + assert retrieved['day'] == now.tm_mday + return retrieved + + +if __name__ == "__main__": + login_cookie = test_create_user() + test_login() + track = test_track_creation(login_cookie) + retrieved = test_get_track(login_cookie, track) + tick = test_tick(login_cookie, track)