From 526990704bfa583363a78b7b0a34fe41c82796f4 Mon Sep 17 00:00:00 2001 From: "D. Scott Boggs" Date: Tue, 13 Jun 2023 12:26:39 -0400 Subject: [PATCH] add route to import the database file --- server/src/api/import.rs | 41 +++++++++++++++++++++++++++++++++++----- server/src/api/mod.rs | 5 ++--- server/src/error.rs | 8 ++++++++ 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/server/src/api/import.rs b/server/src/api/import.rs index f6dc8db..956de57 100644 --- a/server/src/api/import.rs +++ b/server/src/api/import.rs @@ -1,3 +1,5 @@ +//! This is behind a feature gate for a reason: it's wildly unsafe and +//! insecure. It absolutely enables arbitrary sql injection. use rocket::{http::Status, State}; use sea_orm::{ConnectionTrait, DatabaseBackend, DatabaseConnection, Statement}; @@ -5,11 +7,8 @@ use crate::error::Error; use super::error::ApiResult; -/// This is behind a feature gate for a reason: it's wildly unsafe and -/// insecure. It absolutely enables arbitrary sql injection. -#[cfg(feature = "unsafe_import")] -#[post("/import", data = "")] -pub(crate) async fn import_sql( +#[post("/dump", data = "")] +pub(crate) async fn sql_dump( db: &State, sql_dump: &str, ) -> ApiResult { @@ -25,3 +24,35 @@ pub(crate) async fn import_sql( } Ok(Status::Ok) } + +#[post("/", data="")] +pub(crate) async fn db_file( + db: &State, + sqlite_db: &[u8], +) -> ApiResult { + use std::{ + io::Write, + process::{Command, Stdio}, + }; + + let mut proc = Command::new("sqlite3") + .args(["-", ".dump"]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(Error::from)?; + proc.stdin + .take() + .ok_or(Error::Unreachable)? + .write_all(sqlite_db) + .map_err(Error::from)?; + let result = proc.wait_with_output().map_err(Error::from)?; + if result.status.success() { + sql_dump(db, &String::from_utf8(result.stdout).map_err(Error::from)?).await + } else { + Err(Error::SqliteCommandError(String::from_utf8_lossy( + &result.stderr, + ).to_string()).into()) + } +} diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 8a8930a..1392b52 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -1,5 +1,6 @@ mod error; mod groups; +#[cfg(feature = "unsafe_import")] mod import; mod ticks; mod tracks; @@ -10,8 +11,6 @@ use std::net::{IpAddr, Ipv4Addr}; use rocket::fs::{FileServer, NamedFile}; use rocket::{routes, Config}; use sea_orm::DatabaseConnection; - -use crate::api::import::import_sql; use crate::error::Error; use crate::rocket::{Build, Rocket}; @@ -58,7 +57,7 @@ pub(crate) fn start_server(db: DatabaseConnection) -> Rocket { .mount("/", FileServer::from("/src/public")); #[cfg(feature = "unsafe_import")] - let it = it.mount("/api/v1", routes![import_sql]); + let it = it.mount("/api/v1/import", routes![import::sql_dump, import::db_file]); it } diff --git a/server/src/error.rs b/server/src/error.rs index f838c05..19c4f1a 100644 --- a/server/src/error.rs +++ b/server/src/error.rs @@ -1,3 +1,5 @@ +use std::string; + use derive_builder::UninitializedFieldError; #[derive(Debug, thiserror::Error)] @@ -8,6 +10,12 @@ pub enum Error { SeaOrm(#[from] sea_orm::DbErr), #[error(transparent)] Io(#[from] std::io::Error), + #[error("error running sqlite command: {0}")] + SqliteCommandError(String), + #[error("BUG: this case should have been unreachable")] + Unreachable, + #[error(transparent)] + Utf8(#[from] string::FromUtf8Error), } pub type Result = std::result::Result;