diff --git a/roc_fnb/util/env_file.py b/roc_fnb/util/env_file.py new file mode 100644 index 0000000..c07baae --- /dev/null +++ b/roc_fnb/util/env_file.py @@ -0,0 +1,29 @@ +from os import environ + + +def env_file(key, default_file=KeyError, default=KeyError, default_fn=KeyError): + """ + Return a value from an environment variable or file specified by one. + + Checks first for the value specified by key with "_FILE" appended. If that + is found, read from the file there. Otherwise return the value of the + environment variable, the contents of the specified default file, the default + value, or raises KeyError. + """ + if fp := environ.get(f'{key}_FILE'): + with open(fp) as file: + return file.read() + if var := environ.get(key): + return var + if default_file is not KeyError: + try: + with open(default_file) as file: + return file.read() + except FileNotFoundError: + ... # fallthrough + if default is not KeyError: + return default + if default_fn is not KeyError: + return default_fn() + raise KeyError(f'no environment variable found ${key} nor {key}_FILE and default was not specified') + \ No newline at end of file diff --git a/roc_fnb/website/models/test_user.py b/roc_fnb/website/models/test_user.py index bc16b03..74a4b37 100644 --- a/roc_fnb/website/models/test_user.py +++ b/roc_fnb/website/models/test_user.py @@ -26,31 +26,6 @@ def test_user_and_check_password(user): assert user.check_password('monkey') -def test_jwt(user): - user._id = (_id := ObjectId(randbytes(12))) - token = user.jwt - header, payload, sig = (base64_decode(part.replace('.', '')) - for part in token.split('.')) - header = json.loads(header) - payload = json.loads(payload) - assert header['alg'] == 'RS256' - assert header['typ'] == 'JWT' - assert set(header.keys()) == {'alg', 'typ'} - # Note that JWT contents are visible to the user: this can be useful but - # must be done with caution - assert payload['email'] == user.email - assert payload['name'] == user.name - assert ObjectId(base64_decode(payload['_id'])) == user._id == _id - assert set(payload.keys()) == {'email', 'name', '_id', 'admin', 'moderator'} - - result = user.verify_jwt(token) - assert result.email == user.email - assert result.name == user.name - assert result._id == user._id == _id - assert not result.admin - assert not result.moderator - - def test_store_and_retreive(user: User, database: Database): try: database.store_user(user) @@ -73,13 +48,3 @@ def test_store_and_retreive_by_id(user: User, database: Database): finally: if id := user._id: database.delete_user(id) - -def test_store_and_retreive_by_jwt(user: User, database: Database): - try: - token = database.store_user(user).jwt - assert user._id is not None - retreived = database.get_user_from_token(token) - assert retreived == user - finally: - if id := user._id: - database.delete_user(id) \ No newline at end of file diff --git a/roc_fnb/website/models/user.py b/roc_fnb/website/models/user.py index 47f4479..6c52ee7 100644 --- a/roc_fnb/website/models/user.py +++ b/roc_fnb/website/models/user.py @@ -2,7 +2,7 @@ from base64 import b64decode, b64encode from dataclasses import dataclass import json from random import randbytes -from typing import Optional, Any +from typing import Optional, Any, Self from bson.objectid import ObjectId import scrypt @@ -24,6 +24,12 @@ class JwtUser: moderator: bool admin: bool + @classmethod + def from_json(cls, data: dict) -> Self: + _id = ObjectId(base64_decode(data.pop('_id'))) + return cls(_id=_id, **data) + + @dataclass class User: _id: Optional[ObjectId] @@ -63,6 +69,11 @@ class User: @property def public_fields(self): + """ + Session data is visible to client scripts. + + This is a feature, not a bug; client scripts may need to gather login info. + """ return { '_id': base64_encode(self._id.binary), "email": self.email, @@ -73,18 +84,3 @@ class User: def check_password(self, password: str) -> bool: return self.password_hash == scrypt.hash(password, self.salt) - - @property - def jwt(self) -> str: - return jwt.encode(self.public_fields, PRIVATE_KEY, algorithm='RS256') - - @staticmethod - def verify_jwt(token: str) -> JwtUser: - verified = jwt.decode(token, PUBLIC_KEY, verify=True, algorithms=['RS256']) - return JwtUser( - _id=ObjectId(base64_decode(verified['_id'])), - name=verified['name'], - email=verified['email'], - moderator=verified['moderator'], - admin=verified['admin'], - ) diff --git a/roc_fnb/website/server.py b/roc_fnb/website/server.py index e6facd1..714f96d 100644 --- a/roc_fnb/website/server.py +++ b/roc_fnb/website/server.py @@ -1,11 +1,15 @@ +from functools import wraps +import json from pathlib import Path +from random import randbytes from sys import stderr from flask import (Flask, redirect, url_for, request, send_file, make_response, - abort, render_template, g) + abort, render_template, session, g) +from roc_fnb.util.env_file import env_file from roc_fnb.website.database import Database -from roc_fnb.website.models.user import User +from roc_fnb.website.models.user import JwtUser db = Database.from_env() @@ -16,13 +20,31 @@ app = Flask( static_folder=Path(__file__).absolute().parent / 'static', ) +app.secret_key = env_file('FLASK_SECRET', default_file='./flask.secret', default_fn=lambda: randbytes(12)) @app.before_request def decode_user(): - if token := request.cookies.get('auth-token'): - g.user = User.verify_jwt(token) + if user := session.get('user'): + g.user = JwtUser.from_json(data=json.loads(user)) +def require_user(admin = False, moderator = False): + """ + A decorator for any routes which require authentication. + + https://stackoverflow.com/a/51820573 + """ + def _require_user(handler): + @wraps(handler) + def __require_user(): + if getattr(g, 'user', None) is None \ + or (admin and not user.admin) \ + or (moderator and not user.moderator): + abort(401) + return handler() + return __require_user + return _require_user + @app.route('/ig') def ig_redir(): return redirect('https://instagram.com/RocFNB') @@ -44,16 +66,16 @@ def submit_login(): user = db.get_user_by_name(form['name']) if not user.check_password(form['password']): abort(401) # unauthorized - response = make_response(redirect('/me')) - response.set_cookie('auth-token', user.jwt) - return response + session['user'] = json.dumps(user.public_fields) + return redirect('/me') @app.get('/login') def render_login_page(): + if getattr(g, 'user', None): + return redirect('/me') return render_template('login.html') @app.get('/me') +@require_user() def get_profile(): - if g.user is not None: - return render_template('profile.html', user=g.user) - abort(401) + return render_template('profile.html', user=g.user)