Use flask session store instead of DIYing
This commit is contained in:
parent
41f36b0fd7
commit
467f6a77ae
29
roc_fnb/util/env_file.py
Normal file
29
roc_fnb/util/env_file.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
from os import environ
|
||||||
|
|
||||||
|
|
||||||
|
def env_file(key, default_file=KeyError, default=KeyError, default_fn=KeyError):
|
||||||
|
"""
|
||||||
|
Return a value from an environment variable or file specified by one.
|
||||||
|
|
||||||
|
Checks first for the value specified by key with "_FILE" appended. If that
|
||||||
|
is found, read from the file there. Otherwise return the value of the
|
||||||
|
environment variable, the contents of the specified default file, the default
|
||||||
|
value, or raises KeyError.
|
||||||
|
"""
|
||||||
|
if fp := environ.get(f'{key}_FILE'):
|
||||||
|
with open(fp) as file:
|
||||||
|
return file.read()
|
||||||
|
if var := environ.get(key):
|
||||||
|
return var
|
||||||
|
if default_file is not KeyError:
|
||||||
|
try:
|
||||||
|
with open(default_file) as file:
|
||||||
|
return file.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
... # fallthrough
|
||||||
|
if default is not KeyError:
|
||||||
|
return default
|
||||||
|
if default_fn is not KeyError:
|
||||||
|
return default_fn()
|
||||||
|
raise KeyError(f'no environment variable found ${key} nor {key}_FILE and default was not specified')
|
||||||
|
|
||||||
|
|
@ -26,31 +26,6 @@ def test_user_and_check_password(user):
|
||||||
assert user.check_password('monkey')
|
assert user.check_password('monkey')
|
||||||
|
|
||||||
|
|
||||||
def test_jwt(user):
|
|
||||||
user._id = (_id := ObjectId(randbytes(12)))
|
|
||||||
token = user.jwt
|
|
||||||
header, payload, sig = (base64_decode(part.replace('.', ''))
|
|
||||||
for part in token.split('.'))
|
|
||||||
header = json.loads(header)
|
|
||||||
payload = json.loads(payload)
|
|
||||||
assert header['alg'] == 'RS256'
|
|
||||||
assert header['typ'] == 'JWT'
|
|
||||||
assert set(header.keys()) == {'alg', 'typ'}
|
|
||||||
# Note that JWT contents are visible to the user: this can be useful but
|
|
||||||
# must be done with caution
|
|
||||||
assert payload['email'] == user.email
|
|
||||||
assert payload['name'] == user.name
|
|
||||||
assert ObjectId(base64_decode(payload['_id'])) == user._id == _id
|
|
||||||
assert set(payload.keys()) == {'email', 'name', '_id', 'admin', 'moderator'}
|
|
||||||
|
|
||||||
result = user.verify_jwt(token)
|
|
||||||
assert result.email == user.email
|
|
||||||
assert result.name == user.name
|
|
||||||
assert result._id == user._id == _id
|
|
||||||
assert not result.admin
|
|
||||||
assert not result.moderator
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_and_retreive(user: User, database: Database):
|
def test_store_and_retreive(user: User, database: Database):
|
||||||
try:
|
try:
|
||||||
database.store_user(user)
|
database.store_user(user)
|
||||||
|
|
@ -73,13 +48,3 @@ def test_store_and_retreive_by_id(user: User, database: Database):
|
||||||
finally:
|
finally:
|
||||||
if id := user._id:
|
if id := user._id:
|
||||||
database.delete_user(id)
|
database.delete_user(id)
|
||||||
|
|
||||||
def test_store_and_retreive_by_jwt(user: User, database: Database):
|
|
||||||
try:
|
|
||||||
token = database.store_user(user).jwt
|
|
||||||
assert user._id is not None
|
|
||||||
retreived = database.get_user_from_token(token)
|
|
||||||
assert retreived == user
|
|
||||||
finally:
|
|
||||||
if id := user._id:
|
|
||||||
database.delete_user(id)
|
|
||||||
|
|
@ -2,7 +2,7 @@ from base64 import b64decode, b64encode
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
from random import randbytes
|
from random import randbytes
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, Self
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
import scrypt
|
import scrypt
|
||||||
|
|
@ -24,6 +24,12 @@ class JwtUser:
|
||||||
moderator: bool
|
moderator: bool
|
||||||
admin: bool
|
admin: bool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, data: dict) -> Self:
|
||||||
|
_id = ObjectId(base64_decode(data.pop('_id')))
|
||||||
|
return cls(_id=_id, **data)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class User:
|
class User:
|
||||||
_id: Optional[ObjectId]
|
_id: Optional[ObjectId]
|
||||||
|
|
@ -63,6 +69,11 @@ class User:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def public_fields(self):
|
def public_fields(self):
|
||||||
|
"""
|
||||||
|
Session data is visible to client scripts.
|
||||||
|
|
||||||
|
This is a feature, not a bug; client scripts may need to gather login info.
|
||||||
|
"""
|
||||||
return {
|
return {
|
||||||
'_id': base64_encode(self._id.binary),
|
'_id': base64_encode(self._id.binary),
|
||||||
"email": self.email,
|
"email": self.email,
|
||||||
|
|
@ -73,18 +84,3 @@ class User:
|
||||||
|
|
||||||
def check_password(self, password: str) -> bool:
|
def check_password(self, password: str) -> bool:
|
||||||
return self.password_hash == scrypt.hash(password, self.salt)
|
return self.password_hash == scrypt.hash(password, self.salt)
|
||||||
|
|
||||||
@property
|
|
||||||
def jwt(self) -> str:
|
|
||||||
return jwt.encode(self.public_fields, PRIVATE_KEY, algorithm='RS256')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def verify_jwt(token: str) -> JwtUser:
|
|
||||||
verified = jwt.decode(token, PUBLIC_KEY, verify=True, algorithms=['RS256'])
|
|
||||||
return JwtUser(
|
|
||||||
_id=ObjectId(base64_decode(verified['_id'])),
|
|
||||||
name=verified['name'],
|
|
||||||
email=verified['email'],
|
|
||||||
moderator=verified['moderator'],
|
|
||||||
admin=verified['admin'],
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
|
from functools import wraps
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from random import randbytes
|
||||||
from sys import stderr
|
from sys import stderr
|
||||||
|
|
||||||
from flask import (Flask, redirect, url_for, request, send_file, make_response,
|
from flask import (Flask, redirect, url_for, request, send_file, make_response,
|
||||||
abort, render_template, g)
|
abort, render_template, session, g)
|
||||||
|
|
||||||
|
from roc_fnb.util.env_file import env_file
|
||||||
from roc_fnb.website.database import Database
|
from roc_fnb.website.database import Database
|
||||||
from roc_fnb.website.models.user import User
|
from roc_fnb.website.models.user import JwtUser
|
||||||
|
|
||||||
db = Database.from_env()
|
db = Database.from_env()
|
||||||
|
|
||||||
|
|
@ -16,13 +20,31 @@ app = Flask(
|
||||||
static_folder=Path(__file__).absolute().parent / 'static',
|
static_folder=Path(__file__).absolute().parent / 'static',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.secret_key = env_file('FLASK_SECRET', default_file='./flask.secret', default_fn=lambda: randbytes(12))
|
||||||
|
|
||||||
@app.before_request
|
@app.before_request
|
||||||
def decode_user():
|
def decode_user():
|
||||||
if token := request.cookies.get('auth-token'):
|
if user := session.get('user'):
|
||||||
g.user = User.verify_jwt(token)
|
g.user = JwtUser.from_json(data=json.loads(user))
|
||||||
|
|
||||||
|
|
||||||
|
def require_user(admin = False, moderator = False):
|
||||||
|
"""
|
||||||
|
A decorator for any routes which require authentication.
|
||||||
|
|
||||||
|
https://stackoverflow.com/a/51820573
|
||||||
|
"""
|
||||||
|
def _require_user(handler):
|
||||||
|
@wraps(handler)
|
||||||
|
def __require_user():
|
||||||
|
if getattr(g, 'user', None) is None \
|
||||||
|
or (admin and not user.admin) \
|
||||||
|
or (moderator and not user.moderator):
|
||||||
|
abort(401)
|
||||||
|
return handler()
|
||||||
|
return __require_user
|
||||||
|
return _require_user
|
||||||
|
|
||||||
@app.route('/ig')
|
@app.route('/ig')
|
||||||
def ig_redir():
|
def ig_redir():
|
||||||
return redirect('https://instagram.com/RocFNB')
|
return redirect('https://instagram.com/RocFNB')
|
||||||
|
|
@ -44,16 +66,16 @@ def submit_login():
|
||||||
user = db.get_user_by_name(form['name'])
|
user = db.get_user_by_name(form['name'])
|
||||||
if not user.check_password(form['password']):
|
if not user.check_password(form['password']):
|
||||||
abort(401) # unauthorized
|
abort(401) # unauthorized
|
||||||
response = make_response(redirect('/me'))
|
session['user'] = json.dumps(user.public_fields)
|
||||||
response.set_cookie('auth-token', user.jwt)
|
return redirect('/me')
|
||||||
return response
|
|
||||||
|
|
||||||
@app.get('/login')
|
@app.get('/login')
|
||||||
def render_login_page():
|
def render_login_page():
|
||||||
|
if getattr(g, 'user', None):
|
||||||
|
return redirect('/me')
|
||||||
return render_template('login.html')
|
return render_template('login.html')
|
||||||
|
|
||||||
@app.get('/me')
|
@app.get('/me')
|
||||||
|
@require_user()
|
||||||
def get_profile():
|
def get_profile():
|
||||||
if g.user is not None:
|
|
||||||
return render_template('profile.html', user=g.user)
|
return render_template('profile.html', user=g.user)
|
||||||
abort(401)
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue