Skip to content

Commit 00b4cc3

Browse files
authored
feat: implement forgot password feature (langgenius#5534)
1 parent f546db5 commit 00b4cc3

File tree

33 files changed

+1000
-26
lines changed

33 files changed

+1000
-26
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,5 @@ sdks/python-client/dify_client.egg-info
174174
.vscode/*
175175
!.vscode/launch.json
176176
pyrightconfig.json
177+
178+
.idea/

api/configs/feature/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ class SecurityConfig(BaseModel):
1717
default=None,
1818
)
1919

20+
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
21+
description='Expiry time in hours for reset token',
22+
default=24,
23+
)
2024

2125
class AppExecutionConfig(BaseModel):
2226
"""

api/controllers/console/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131

3232
# Import auth controllers
33-
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
33+
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
3434

3535
# Import billing controllers
3636
from .billing import billing

api/controllers/console/auth/error.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,28 @@ class ApiKeyAuthFailedError(BaseHTTPException):
55
error_code = 'auth_failed'
66
description = "{message}"
77
code = 500
8+
9+
10+
class InvalidEmailError(BaseHTTPException):
11+
error_code = 'invalid_email'
12+
description = "The email address is not valid."
13+
code = 400
14+
15+
16+
class PasswordMismatchError(BaseHTTPException):
17+
error_code = 'password_mismatch'
18+
description = "The passwords do not match."
19+
code = 400
20+
21+
22+
class InvalidTokenError(BaseHTTPException):
23+
error_code = 'invalid_or_expired_token'
24+
description = "The token is invalid or has expired."
25+
code = 400
26+
27+
28+
class PasswordResetRateLimitExceededError(BaseHTTPException):
29+
error_code = 'password_reset_rate_limit_exceeded'
30+
description = "Password reset rate limit exceeded. Try again later."
31+
code = 429
32+
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import base64
2+
import logging
3+
import secrets
4+
5+
from flask_restful import Resource, reqparse
6+
7+
from controllers.console import api
8+
from controllers.console.auth.error import (
9+
InvalidEmailError,
10+
InvalidTokenError,
11+
PasswordMismatchError,
12+
PasswordResetRateLimitExceededError,
13+
)
14+
from controllers.console.setup import setup_required
15+
from extensions.ext_database import db
16+
from libs.helper import email as email_validate
17+
from libs.password import hash_password, valid_password
18+
from models.account import Account
19+
from services.account_service import AccountService
20+
from services.errors.account import RateLimitExceededError
21+
22+
23+
class ForgotPasswordSendEmailApi(Resource):
24+
25+
@setup_required
26+
def post(self):
27+
parser = reqparse.RequestParser()
28+
parser.add_argument('email', type=str, required=True, location='json')
29+
args = parser.parse_args()
30+
31+
email = args['email']
32+
33+
if not email_validate(email):
34+
raise InvalidEmailError()
35+
36+
account = Account.query.filter_by(email=email).first()
37+
38+
if account:
39+
try:
40+
AccountService.send_reset_password_email(account=account)
41+
except RateLimitExceededError:
42+
logging.warning(f"Rate limit exceeded for email: {account.email}")
43+
raise PasswordResetRateLimitExceededError()
44+
else:
45+
# Return success to avoid revealing email registration status
46+
logging.warning(f"Attempt to reset password for unregistered email: {email}")
47+
48+
return {"result": "success"}
49+
50+
51+
class ForgotPasswordCheckApi(Resource):
52+
53+
@setup_required
54+
def post(self):
55+
parser = reqparse.RequestParser()
56+
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
57+
args = parser.parse_args()
58+
token = args['token']
59+
60+
reset_data = AccountService.get_reset_password_data(token)
61+
62+
if reset_data is None:
63+
return {'is_valid': False, 'email': None}
64+
return {'is_valid': True, 'email': reset_data.get('email')}
65+
66+
67+
class ForgotPasswordResetApi(Resource):
68+
69+
@setup_required
70+
def post(self):
71+
parser = reqparse.RequestParser()
72+
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
73+
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
74+
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
75+
args = parser.parse_args()
76+
77+
new_password = args['new_password']
78+
password_confirm = args['password_confirm']
79+
80+
if str(new_password).strip() != str(password_confirm).strip():
81+
raise PasswordMismatchError()
82+
83+
token = args['token']
84+
reset_data = AccountService.get_reset_password_data(token)
85+
86+
if reset_data is None:
87+
raise InvalidTokenError()
88+
89+
AccountService.revoke_reset_password_token(token)
90+
91+
salt = secrets.token_bytes(16)
92+
base64_salt = base64.b64encode(salt).decode()
93+
94+
password_hashed = hash_password(new_password, salt)
95+
base64_password_hashed = base64.b64encode(password_hashed).decode()
96+
97+
account = Account.query.filter_by(email=reset_data.get('email')).first()
98+
account.password = base64_password_hashed
99+
account.password_salt = base64_salt
100+
db.session.commit()
101+
102+
return {'result': 'success'}
103+
104+
105+
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
106+
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
107+
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')

api/controllers/console/workspace/account.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ def get(self):
245245
return {'data': integrate_data}
246246

247247

248+
249+
248250
# Register API resources
249251
api.add_resource(AccountInitApi, '/account/init')
250252
api.add_resource(AccountProfileApi, '/account/profile')

api/libs/helper.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
import json
2+
import logging
23
import random
34
import re
45
import string
56
import subprocess
7+
import time
68
import uuid
79
from collections.abc import Generator
810
from datetime import datetime
911
from hashlib import sha256
10-
from typing import Union
12+
from typing import Any, Optional, Union
1113
from zoneinfo import available_timezones
1214

13-
from flask import Response, stream_with_context
15+
from flask import Response, current_app, stream_with_context
1416
from flask_restful import fields
1517

18+
from extensions.ext_redis import redis_client
19+
from models.account import Account
20+
1621

1722
def run(script):
1823
return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
@@ -46,12 +51,12 @@ def uuid_value(value):
4651
error = ('{value} is not a valid uuid.'
4752
.format(value=value))
4853
raise ValueError(error)
49-
54+
5055
def alphanumeric(value: str):
5156
# check if the value is alphanumeric and underlined
5257
if re.match(r'^[a-zA-Z0-9_]+$', value):
5358
return value
54-
59+
5560
raise ValueError(f'{value} is not a valid alphanumeric value')
5661

5762
def timestamp_value(timestamp):
@@ -163,3 +168,97 @@ def generate() -> Generator:
163168

164169
return Response(stream_with_context(generate()), status=200,
165170
mimetype='text/event-stream')
171+
172+
173+
class TokenManager:
174+
175+
@classmethod
176+
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
177+
old_token = cls._get_current_token_for_account(account.id, token_type)
178+
if old_token:
179+
if isinstance(old_token, bytes):
180+
old_token = old_token.decode('utf-8')
181+
cls.revoke_token(old_token, token_type)
182+
183+
token = str(uuid.uuid4())
184+
token_data = {
185+
'account_id': account.id,
186+
'email': account.email,
187+
'token_type': token_type
188+
}
189+
if additional_data:
190+
token_data.update(additional_data)
191+
192+
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
193+
token_key = cls._get_token_key(token, token_type)
194+
redis_client.setex(
195+
token_key,
196+
expiry_hours * 60 * 60,
197+
json.dumps(token_data)
198+
)
199+
200+
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
201+
return token
202+
203+
@classmethod
204+
def _get_token_key(cls, token: str, token_type: str) -> str:
205+
return f'{token_type}:token:{token}'
206+
207+
@classmethod
208+
def revoke_token(cls, token: str, token_type: str):
209+
token_key = cls._get_token_key(token, token_type)
210+
redis_client.delete(token_key)
211+
212+
@classmethod
213+
def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
214+
key = cls._get_token_key(token, token_type)
215+
token_data_json = redis_client.get(key)
216+
if token_data_json is None:
217+
logging.warning(f"{token_type} token {token} not found with key {key}")
218+
return None
219+
token_data = json.loads(token_data_json)
220+
return token_data
221+
222+
@classmethod
223+
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
224+
key = cls._get_account_token_key(account_id, token_type)
225+
current_token = redis_client.get(key)
226+
return current_token
227+
228+
@classmethod
229+
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
230+
key = cls._get_account_token_key(account_id, token_type)
231+
redis_client.setex(key, expiry_hours * 60 * 60, token)
232+
233+
@classmethod
234+
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
235+
return f'{token_type}:account:{account_id}'
236+
237+
238+
class RateLimiter:
239+
def __init__(self, prefix: str, max_attempts: int, time_window: int):
240+
self.prefix = prefix
241+
self.max_attempts = max_attempts
242+
self.time_window = time_window
243+
244+
def _get_key(self, email: str) -> str:
245+
return f"{self.prefix}:{email}"
246+
247+
def is_rate_limited(self, email: str) -> bool:
248+
key = self._get_key(email)
249+
current_time = int(time.time())
250+
window_start_time = current_time - self.time_window
251+
252+
redis_client.zremrangebyscore(key, '-inf', window_start_time)
253+
attempts = redis_client.zcard(key)
254+
255+
if attempts and int(attempts) >= self.max_attempts:
256+
return True
257+
return False
258+
259+
def increment_rate_limit(self, email: str):
260+
key = self._get_key(email)
261+
current_time = int(time.time())
262+
263+
redis_client.zadd(key, {current_time: current_time})
264+
redis_client.expire(key, self.time_window * 2)

api/services/account_service.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from constants.languages import language_timezone_mapping, languages
1414
from events.tenant_event import tenant_was_created
1515
from extensions.ext_redis import redis_client
16+
from libs.helper import RateLimiter, TokenManager
1617
from libs.passport import PassportService
1718
from libs.password import compare_password, hash_password, valid_password
1819
from libs.rsa import generate_key_pair
@@ -29,14 +30,22 @@
2930
LinkAccountIntegrateError,
3031
MemberNotInTenantError,
3132
NoPermissionError,
33+
RateLimitExceededError,
3234
RoleAlreadyAssignedError,
3335
TenantNotFound,
3436
)
3537
from tasks.mail_invite_member_task import send_invite_member_mail_task
38+
from tasks.mail_reset_password_task import send_reset_password_mail_task
3639

3740

3841
class AccountService:
3942

43+
reset_password_rate_limiter = RateLimiter(
44+
prefix="reset_password_rate_limit",
45+
max_attempts=5,
46+
time_window=60 * 60
47+
)
48+
4049
@staticmethod
4150
def load_user(user_id: str) -> Account:
4251
account = Account.query.filter_by(id=user_id).first()
@@ -222,9 +231,33 @@ def load_logged_in_account(*, account_id: str, token: str):
222231
return None
223232
return AccountService.load_user(account_id)
224233

234+
@classmethod
235+
def send_reset_password_email(cls, account):
236+
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
237+
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
238+
239+
token = TokenManager.generate_token(account, 'reset_password')
240+
send_reset_password_mail_task.delay(
241+
language=account.interface_language,
242+
to=account.email,
243+
token=token
244+
)
245+
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
246+
return token
247+
248+
@classmethod
249+
def revoke_reset_password_token(cls, token: str):
250+
TokenManager.revoke_token(token, 'reset_password')
251+
252+
@classmethod
253+
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
254+
return TokenManager.get_token_data(token, 'reset_password')
255+
256+
225257
def _get_login_cache_key(*, account_id: str, token: str):
226258
return f"account_login:{account_id}:{token}"
227259

260+
228261
class TenantService:
229262

230263
@staticmethod

api/services/errors/account.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,8 @@ class MemberNotInTenantError(BaseServiceError):
5151

5252
class RoleAlreadyAssignedError(BaseServiceError):
5353
pass
54+
55+
56+
class RateLimitExceededError(BaseServiceError):
57+
pass
58+

0 commit comments

Comments
 (0)