|
1 | 1 | import json |
| 2 | +import logging |
2 | 3 | import random |
3 | 4 | import re |
4 | 5 | import string |
5 | 6 | import subprocess |
| 7 | +import time |
6 | 8 | import uuid |
7 | 9 | from collections.abc import Generator |
8 | 10 | from datetime import datetime |
9 | 11 | from hashlib import sha256 |
10 | | -from typing import Union |
| 12 | +from typing import Any, Optional, Union |
11 | 13 | from zoneinfo import available_timezones |
12 | 14 |
|
13 | | -from flask import Response, stream_with_context |
| 15 | +from flask import Response, current_app, stream_with_context |
14 | 16 | from flask_restful import fields |
15 | 17 |
|
| 18 | +from extensions.ext_redis import redis_client |
| 19 | +from models.account import Account |
| 20 | + |
16 | 21 |
|
17 | 22 | def run(script): |
18 | 23 | return subprocess.getstatusoutput('source /root/.bashrc && ' + script) |
@@ -46,12 +51,12 @@ def uuid_value(value): |
46 | 51 | error = ('{value} is not a valid uuid.' |
47 | 52 | .format(value=value)) |
48 | 53 | raise ValueError(error) |
49 | | - |
| 54 | + |
50 | 55 | def alphanumeric(value: str): |
51 | 56 | # check if the value is alphanumeric and underlined |
52 | 57 | if re.match(r'^[a-zA-Z0-9_]+$', value): |
53 | 58 | return value |
54 | | - |
| 59 | + |
55 | 60 | raise ValueError(f'{value} is not a valid alphanumeric value') |
56 | 61 |
|
57 | 62 | def timestamp_value(timestamp): |
@@ -163,3 +168,97 @@ def generate() -> Generator: |
163 | 168 |
|
164 | 169 | return Response(stream_with_context(generate()), status=200, |
165 | 170 | mimetype='text/event-stream') |
| 171 | + |
| 172 | + |
| 173 | +class TokenManager: |
| 174 | + |
| 175 | + @classmethod |
| 176 | + def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: |
| 177 | + old_token = cls._get_current_token_for_account(account.id, token_type) |
| 178 | + if old_token: |
| 179 | + if isinstance(old_token, bytes): |
| 180 | + old_token = old_token.decode('utf-8') |
| 181 | + cls.revoke_token(old_token, token_type) |
| 182 | + |
| 183 | + token = str(uuid.uuid4()) |
| 184 | + token_data = { |
| 185 | + 'account_id': account.id, |
| 186 | + 'email': account.email, |
| 187 | + 'token_type': token_type |
| 188 | + } |
| 189 | + if additional_data: |
| 190 | + token_data.update(additional_data) |
| 191 | + |
| 192 | + expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] |
| 193 | + token_key = cls._get_token_key(token, token_type) |
| 194 | + redis_client.setex( |
| 195 | + token_key, |
| 196 | + expiry_hours * 60 * 60, |
| 197 | + json.dumps(token_data) |
| 198 | + ) |
| 199 | + |
| 200 | + cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) |
| 201 | + return token |
| 202 | + |
| 203 | + @classmethod |
| 204 | + def _get_token_key(cls, token: str, token_type: str) -> str: |
| 205 | + return f'{token_type}:token:{token}' |
| 206 | + |
| 207 | + @classmethod |
| 208 | + def revoke_token(cls, token: str, token_type: str): |
| 209 | + token_key = cls._get_token_key(token, token_type) |
| 210 | + redis_client.delete(token_key) |
| 211 | + |
| 212 | + @classmethod |
| 213 | + def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]: |
| 214 | + key = cls._get_token_key(token, token_type) |
| 215 | + token_data_json = redis_client.get(key) |
| 216 | + if token_data_json is None: |
| 217 | + logging.warning(f"{token_type} token {token} not found with key {key}") |
| 218 | + return None |
| 219 | + token_data = json.loads(token_data_json) |
| 220 | + return token_data |
| 221 | + |
| 222 | + @classmethod |
| 223 | + def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: |
| 224 | + key = cls._get_account_token_key(account_id, token_type) |
| 225 | + current_token = redis_client.get(key) |
| 226 | + return current_token |
| 227 | + |
| 228 | + @classmethod |
| 229 | + def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int): |
| 230 | + key = cls._get_account_token_key(account_id, token_type) |
| 231 | + redis_client.setex(key, expiry_hours * 60 * 60, token) |
| 232 | + |
| 233 | + @classmethod |
| 234 | + def _get_account_token_key(cls, account_id: str, token_type: str) -> str: |
| 235 | + return f'{token_type}:account:{account_id}' |
| 236 | + |
| 237 | + |
| 238 | +class RateLimiter: |
| 239 | + def __init__(self, prefix: str, max_attempts: int, time_window: int): |
| 240 | + self.prefix = prefix |
| 241 | + self.max_attempts = max_attempts |
| 242 | + self.time_window = time_window |
| 243 | + |
| 244 | + def _get_key(self, email: str) -> str: |
| 245 | + return f"{self.prefix}:{email}" |
| 246 | + |
| 247 | + def is_rate_limited(self, email: str) -> bool: |
| 248 | + key = self._get_key(email) |
| 249 | + current_time = int(time.time()) |
| 250 | + window_start_time = current_time - self.time_window |
| 251 | + |
| 252 | + redis_client.zremrangebyscore(key, '-inf', window_start_time) |
| 253 | + attempts = redis_client.zcard(key) |
| 254 | + |
| 255 | + if attempts and int(attempts) >= self.max_attempts: |
| 256 | + return True |
| 257 | + return False |
| 258 | + |
| 259 | + def increment_rate_limit(self, email: str): |
| 260 | + key = self._get_key(email) |
| 261 | + current_time = int(time.time()) |
| 262 | + |
| 263 | + redis_client.zadd(key, {current_time: current_time}) |
| 264 | + redis_client.expire(key, self.time_window * 2) |
0 commit comments