diff --git a/codecarbon/cli/main.py b/codecarbon/cli/main.py index 6f28c2309..8889f4509 100644 --- a/codecarbon/cli/main.py +++ b/codecarbon/cli/main.py @@ -1,15 +1,20 @@ +import json import os import signal import sys import time +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path from typing import Optional +from urllib.parse import parse_qs, urlparse import questionary import requests import typer -from fief_client import Fief -from fief_client.integrations.cli import FiefAuth +from authlib.common.security import generate_token +from authlib.integrations.requests_client import OAuth2Session +from authlib.oauth2.rfc7636 import create_s256_code_challenge from rich import print from rich.prompt import Confirm from typing_extensions import Annotated @@ -22,7 +27,6 @@ get_existing_local_exp_id, overwrite_local_config, ) -from codecarbon.cli.monitor import run_and_monitor from codecarbon.core.api_client import ApiClient, get_datetime_with_timezone from codecarbon.core.schemas import ExperimentCreate, OrganizationCreate, ProjectCreate from codecarbon.emissions_tracker import EmissionsTracker, OfflineEmissionsTracker @@ -31,8 +35,9 @@ "AUTH_CLIENT_ID", "jsUPWIcUECQFE_ouanUuVhXx52TTjEVcVNNtNGeyAtU", ) -AUTH_SERVER_URL = os.environ.get( - "AUTH_SERVER_URL", "https://auth.codecarbon.io/codecarbon" +AUTH_SERVER_WELL_KNOWN = os.environ.get( + "AUTH_SERVER_WELL_KNOWN", + "https://auth.codecarbon.io/codecarbon/.well-known/openid-configuration", ) API_URL = os.environ.get("API_URL", "https://dashboard.codecarbon.io/api") @@ -115,17 +120,115 @@ def show_config(path: Path = Path("./.codecarbon.config")) -> None: ) -def get_fief_auth(): - fief = Fief(AUTH_SERVER_URL, AUTH_CLIENT_ID) - fief_auth = FiefAuth(fief, "./credentials.json") - return fief_auth +_REDIRECT_PORT = 8090 +_REDIRECT_URI = f"http://localhost:{_REDIRECT_PORT}/callback" +_CREDENTIALS_FILE = Path("./credentials.json") + + +class _CallbackHandler(BaseHTTPRequestHandler): + """HTTP handler that captures the OAuth2 authorization callback.""" + + callback_url = None + error = None + + def do_GET(self): + _CallbackHandler.callback_url = f"http://localhost:{_REDIRECT_PORT}{self.path}" + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + + if "error" in params: + _CallbackHandler.error = params["error"][0] + self.send_response(400) + self.send_header("Content-Type", "text/html") + self.end_headers() + msg = params.get("error_description", [params["error"][0]])[0] + self.wfile.write( + f"
{msg}
".encode() + ) + else: + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"You can close this window.
" + ) + + def log_message(self, format, *args): + pass # Suppress server logs + + +def _discover_endpoints(): + """Fetch OpenID Connect discovery document.""" + resp = requests.get(AUTH_SERVER_WELL_KNOWN) + resp.raise_for_status() + return resp.json() + + +def _authorize(): + """Run the OAuth2 Authorization Code flow with PKCE.""" + discovery = _discover_endpoints() + + session = OAuth2Session( + client_id=AUTH_CLIENT_ID, + redirect_uri=_REDIRECT_URI, + scope="openid offline_access", + token_endpoint_auth_method="none", + ) + + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + + uri, state = session.create_authorization_url( + discovery["authorization_endpoint"], + code_challenge=code_challenge, + code_challenge_method="S256", + ) + + # Reset handler state + _CallbackHandler.callback_url = None + _CallbackHandler.error = None + + server = HTTPServer(("localhost", _REDIRECT_PORT), _CallbackHandler) + + print("Opening browser for authentication...") + webbrowser.open(uri) + + server.handle_request() + server.server_close() + + if _CallbackHandler.error: + raise ValueError(f"Authorization failed: {_CallbackHandler.error}") + + if not _CallbackHandler.callback_url: + raise ValueError("Authorization failed: no callback received") + + token = session.fetch_token( + discovery["token_endpoint"], + authorization_response=_CallbackHandler.callback_url, + code_verifier=code_verifier, + ) + + _save_credentials(token) + return token + + +def _save_credentials(tokens): + """Save OAuth tokens to credentials file.""" + with open(_CREDENTIALS_FILE, "w") as f: + json.dump(tokens, f) + + +def _load_credentials(): + """Load OAuth tokens from credentials file.""" + with open(_CREDENTIALS_FILE, "r") as f: + return json.load(f) def _get_access_token(): try: - access_token_info = get_fief_auth().access_token_info() - access_token = access_token_info["access_token"] - return access_token + creds = _load_credentials() + return creds["access_token"] except Exception as e: raise ValueError( f"Not able to retrieve the access token, please run `codecarbon login` first! (error: {e})" @@ -133,8 +236,8 @@ def _get_access_token(): def _get_id_token(): - id_token = get_fief_auth()._tokens["id_token"] - return id_token + creds = _load_credentials() + return creds["id_token"] @codecarbon.command( @@ -152,7 +255,7 @@ def api_get(): @codecarbon.command("login", short_help="Login to CodeCarbon") def login(): - get_fief_auth().authorize() + _authorize() api = ApiClient(endpoint_url=API_URL) # TODO: get endpoint from config access_token = _get_access_token() api.set_access_token(access_token) diff --git a/codecarbon/cli/oidc_auth.py b/codecarbon/cli/oidc_auth.py new file mode 100644 index 000000000..0c68779ef --- /dev/null +++ b/codecarbon/cli/oidc_auth.py @@ -0,0 +1,247 @@ +""" +OIDC Authentication module for CodeCarbon CLI. + +This module replaces the deprecated fief-client library with a standard +OIDC implementation using python-jose for JWT validation. +""" + +import hashlib +import json +import secrets +import webbrowser +from base64 import urlsafe_b64encode +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from threading import Thread +from typing import Dict, Optional +from urllib.parse import parse_qs, urlencode, urlparse + +import requests +from jose import jwt +from jose.exceptions import JWTError + + +class OIDCAuth: + """ + Uses Authorization Code flow with PKCE for secure authentication. + Stores tokens in a local credentials file. + """ + + def __init__( + self, + server_url: str, + client_id: str, + credentials_file: str = "./credentials.json", + ): + + self.server_url = server_url.rstrip("/") + self.client_id = client_id + self.credentials_file = Path(credentials_file) + self._tokens: Optional[Dict] = None + self._oidc_config: Optional[Dict] = None + self._jwks: Optional[Dict] = None + + # Load existing credentials + self._load_credentials() + + def _get_oidc_configuration(self) -> Dict: + if self._oidc_config is None: + config_url = f"{self.server_url}/.well-known/openid-configuration" + response = requests.get(config_url) + response.raise_for_status() + self._oidc_config = response.json() + return self._oidc_config + + def _get_jwks(self) -> Dict: + if self._jwks is None: + config = self._get_oidc_configuration() + jwks_uri = config["jwks_uri"] + response = requests.get(jwks_uri) + response.raise_for_status() + self._jwks = response.json() + return self._jwks + + def _generate_pkce_pair(self): + code_verifier = ( + urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=") + ) + code_challenge = ( + urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest()) + .decode("utf-8") + .rstrip("=") + ) + return code_verifier, code_challenge + + def _load_credentials(self): + if self.credentials_file.exists(): + try: + with open(self.credentials_file, "r") as f: + self._tokens = json.load(f) + except (json.JSONDecodeError, IOError): + self._tokens = None + + def _save_credentials(self): + if self._tokens: + self.credentials_file.parent.mkdir(parents=True, exist_ok=True) + with open(self.credentials_file, "w") as f: + json.dump(self._tokens, f, indent=2) + + def authorize(self, redirect_port: int = 51562): + config = self._get_oidc_configuration() + authorization_endpoint = config["authorization_endpoint"] + token_endpoint = config["token_endpoint"] + + code_verifier, code_challenge = self._generate_pkce_pair() + state = secrets.token_urlsafe(32) + + redirect_uri = f"http://localhost:{redirect_port}/callback" + + auth_params = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": redirect_uri, + "scope": "openid profile email", + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + auth_url = f"{authorization_endpoint}?{urlencode(auth_params)}" + + authorization_code = None + server_error = None + + class CallbackHandler(BaseHTTPRequestHandler): + def log_message(self, format, *args): + # Suppress server logs + pass + + def do_GET(self): + nonlocal authorization_code, server_error + + parsed = urlparse(self.path) + params = parse_qs(parsed.query) + + if "code" in params and "state" in params: + if params["state"][0] == state: + authorization_code = params["code"][0] + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + b"You can close this window.
" + ) + else: + server_error = "State mismatch" + self.send_response(400) + self.end_headers() + elif "error" in params: + server_error = params["error"][0] + self.send_response(400) + self.end_headers() + + server = HTTPServer(("localhost", redirect_port), CallbackHandler) + server_thread = Thread(target=server.handle_request, daemon=True) + server_thread.start() + print(f"Opening browser for authentication...") + print(auth_url) + webbrowser.open(auth_url) + server_thread.join(timeout=300) # 5 minute timeout + server.server_close() + + if server_error: + raise Exception(f"Authorization failed: {server_error}") + + if not authorization_code: + raise Exception("Authorization timed out or was cancelled") + + # Exchange code for tokens + token_params = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": redirect_uri, + "client_id": self.client_id, + "code_verifier": code_verifier, + } + + response = requests.post(token_endpoint, data=token_params) + response.raise_for_status() + self._tokens = response.json() + self._save_credentials() + + print("Authentication successful!") + + def _refresh_tokens(self): + """Refresh access token using refresh token.""" + if not self._tokens or "refresh_token" not in self._tokens: + raise Exception("No refresh token available") + + config = self._get_oidc_configuration() + token_endpoint = config["token_endpoint"] + + token_params = { + "grant_type": "refresh_token", + "refresh_token": self._tokens["refresh_token"], + "client_id": self.client_id, + } + + response = requests.post(token_endpoint, data=token_params) + response.raise_for_status() + self._tokens = response.json() + self._save_credentials() + + # def _validate_token(self, token: str) -> Dict: + # try: + # jwks = self._get_jwks() + # # Decode and validate + # claims = jwt.decode( + # token, + # jwks, + # algorithms=['RS256'], + # audience=self.client_id, + # issuer=self.server_url, + # ) + # return claims + # except JWTError as e: + # raise Exception(f"Token validation failed: {e}") + + def _validate_token(self, token: str) -> Dict: + try: + claims = jwt.get_unverified_claims(token) + import time + + if "exp" in claims and claims["exp"] < time.time(): + raise Exception("Token expired") + return claims + except JWTError as e: + raise Exception(f"Token validation failed: {e}") + + def access_token_info(self) -> Dict: + if not self._tokens or "access_token" not in self._tokens: + raise Exception("Not authenticated. Please run login first.") + + access_token = self._tokens["access_token"] + + try: + claims = self._validate_token(access_token) + return { + "access_token": access_token, + "claims": claims, + } + except Exception: + # Token might be expired, try to refresh + try: + self._refresh_tokens() + access_token = self._tokens["access_token"] + claims = self._validate_token(access_token) + return { + "access_token": access_token, + "claims": claims, + } + except Exception as e: + raise Exception(f"Failed to get valid access token: {e}") + + def get_id_token(self) -> str: + if not self._tokens or "id_token" not in self._tokens: + raise Exception("Not authenticated. Please run login first.") + + return self._tokens["id_token"]