diff options
author | Daniel Gustafsson <dgustafsson@postgresql.org> | 2025-02-20 16:25:17 +0100 |
---|---|---|
committer | Daniel Gustafsson <dgustafsson@postgresql.org> | 2025-02-20 16:25:17 +0100 |
commit | b3f0be788afc17d2206e1ae1c731d8aeda1f2f59 (patch) | |
tree | 4935e9d745787830d57941771dd2e63b49236ae5 /src/test/modules/oauth_validator/t/oauth_server.py | |
parent | 1fd1bd871012732e3c6c482667d2f2c56f1a9395 (diff) | |
download | postgresql-b3f0be788afc17d2206e1ae1c731d8aeda1f2f59.tar.gz postgresql-b3f0be788afc17d2206e1ae1c731d8aeda1f2f59.zip |
Add support for OAUTHBEARER SASL mechanism
This commit implements OAUTHBEARER, RFC 7628, and OAuth 2.0 Device
Authorization Grants, RFC 8628. In order to use this there is a
new pg_hba auth method called oauth. When speaking to a OAuth-
enabled server, it looks a bit like this:
$ psql 'host=example.org oauth_issuer=... oauth_client_id=...'
Visit https://oauth.example.org/login and enter the code: FPQ2-M4BG
Device authorization is currently the only supported flow so the
OAuth issuer must support that in order for users to authenticate.
Third-party clients may however extend this and provide their own
flows. The built-in device authorization flow is currently not
supported on Windows.
In order for validation to happen server side a new framework for
plugging in OAuth validation modules is added. As validation is
implementation specific, with no default specified in the standard,
PostgreSQL does not ship with one built-in. Each pg_hba entry can
specify a specific validator or be left blank for the validator
installed as default.
This adds a requirement on libcurl for the client side support,
which is optional to build, but the server side has no additional
build requirements. In order to run the tests, Python is required
as this adds a https server written in Python. Tests are gated
behind PG_TEST_EXTRA as they open ports.
This patch has been a multi-year project with many contributors
involved with reviews and in-depth discussions: Michael Paquier,
Heikki Linnakangas, Zhihong Yu, Mahendrakar Srinivasarao, Andrey
Chudnovsky and Stephen Frost to name a few. While Jacob Champion
is the main author there have been some levels of hacking by others.
Daniel Gustafsson contributed the validation module and various bits
and pieces; Thomas Munro wrote the client side support for kqueue.
Author: Jacob Champion <jacob.champion@enterprisedb.com>
Co-authored-by: Daniel Gustafsson <daniel@yesql.se>
Co-authored-by: Thomas Munro <thomas.munro@gmail.com>
Reviewed-by: Daniel Gustafsson <daniel@yesql.se>
Reviewed-by: Peter Eisentraut <peter@eisentraut.org>
Reviewed-by: Antonin Houska <ah@cybertec.at>
Reviewed-by: Kashif Zeeshan <kashi.zeeshan@gmail.com>
Discussion: https://postgr.es/m/d1b467a78e0e36ed85a09adf979d04cf124a9d4b.camel@vmware.com
Diffstat (limited to 'src/test/modules/oauth_validator/t/oauth_server.py')
-rwxr-xr-x | src/test/modules/oauth_validator/t/oauth_server.py | 391 |
1 files changed, 391 insertions, 0 deletions
diff --git a/src/test/modules/oauth_validator/t/oauth_server.py b/src/test/modules/oauth_validator/t/oauth_server.py new file mode 100755 index 00000000000..4faf3323d38 --- /dev/null +++ b/src/test/modules/oauth_validator/t/oauth_server.py @@ -0,0 +1,391 @@ +#! /usr/bin/env python3 +# +# A mock OAuth authorization server, designed to be invoked from +# OAuth/Server.pm. This listens on an ephemeral port number (printed to stdout +# so that the Perl tests can contact it) and runs as a daemon until it is +# signaled. +# + +import base64 +import http.server +import json +import os +import sys +import time +import urllib.parse +from collections import defaultdict + + +class OAuthHandler(http.server.BaseHTTPRequestHandler): + """ + Core implementation of the authorization server. The API is + inheritance-based, with entry points at do_GET() and do_POST(). See the + documentation for BaseHTTPRequestHandler. + """ + + JsonObject = dict[str, object] # TypeAlias is not available until 3.10 + + def _check_issuer(self): + """ + Switches the behavior of the provider depending on the issuer URI. + """ + self._alt_issuer = ( + self.path.startswith("/alternate/") + or self.path == "/.well-known/oauth-authorization-server/alternate" + ) + self._parameterized = self.path.startswith("/param/") + + if self._alt_issuer: + # The /alternate issuer uses IETF-style .well-known URIs. + if self.path.startswith("/.well-known/"): + self.path = self.path.removesuffix("/alternate") + else: + self.path = self.path.removeprefix("/alternate") + elif self._parameterized: + self.path = self.path.removeprefix("/param") + + def _check_authn(self): + """ + Checks the expected value of the Authorization header, if any. + """ + secret = self._get_param("expected_secret", None) + if secret is None: + return + + assert "Authorization" in self.headers + method, creds = self.headers["Authorization"].split() + + if method != "Basic": + raise RuntimeError(f"client used {method} auth; expected Basic") + + username = urllib.parse.quote_plus(self.client_id) + password = urllib.parse.quote_plus(secret) + expected_creds = f"{username}:{password}" + + if creds.encode() != base64.b64encode(expected_creds.encode()): + raise RuntimeError( + f"client sent '{creds}'; expected b64encode('{expected_creds}')" + ) + + def do_GET(self): + self._response_code = 200 + self._check_issuer() + + config_path = "/.well-known/openid-configuration" + if self._alt_issuer: + config_path = "/.well-known/oauth-authorization-server" + + if self.path == config_path: + resp = self.config() + else: + self.send_error(404, "Not Found") + return + + self._send_json(resp) + + def _parse_params(self) -> dict[str, str]: + """ + Parses apart the form-urlencoded request body and returns the resulting + dict. For use by do_POST(). + """ + size = int(self.headers["Content-Length"]) + form = self.rfile.read(size) + + assert self.headers["Content-Type"] == "application/x-www-form-urlencoded" + return urllib.parse.parse_qs( + form.decode("utf-8"), + strict_parsing=True, + keep_blank_values=True, + encoding="utf-8", + errors="strict", + ) + + @property + def client_id(self) -> str: + """ + Returns the client_id sent in the POST body or the Authorization header. + self._parse_params() must have been called first. + """ + if "client_id" in self._params: + return self._params["client_id"][0] + + if "Authorization" not in self.headers: + raise RuntimeError("client did not send any client_id") + + _, creds = self.headers["Authorization"].split() + + decoded = base64.b64decode(creds).decode("utf-8") + username, _ = decoded.split(":", 1) + + return urllib.parse.unquote_plus(username) + + def do_POST(self): + self._response_code = 200 + self._check_issuer() + + self._params = self._parse_params() + if self._parameterized: + # Pull encoded test parameters out of the peer's client_id field. + # This is expected to be Base64-encoded JSON. + js = base64.b64decode(self.client_id) + self._test_params = json.loads(js) + + self._check_authn() + + if self.path == "/authorize": + resp = self.authorization() + elif self.path == "/token": + resp = self.token() + else: + self.send_error(404) + return + + self._send_json(resp) + + def _should_modify(self) -> bool: + """ + Returns True if the client has requested a modification to this stage of + the exchange. + """ + if not hasattr(self, "_test_params"): + return False + + stage = self._test_params.get("stage") + + return ( + stage == "all" + or ( + stage == "discovery" + and self.path == "/.well-known/openid-configuration" + ) + or (stage == "device" and self.path == "/authorize") + or (stage == "token" and self.path == "/token") + ) + + def _get_param(self, name, default): + """ + If the client has requested a modification to this stage (see + _should_modify()), this method searches the provided test parameters for + a key of the given name, and returns it if found. Otherwise the provided + default is returned. + """ + if self._should_modify() and name in self._test_params: + return self._test_params[name] + + return default + + @property + def _content_type(self) -> str: + """ + Returns "application/json" unless the test has requested something + different. + """ + return self._get_param("content_type", "application/json") + + @property + def _interval(self) -> int: + """ + Returns 0 unless the test has requested something different. + """ + return self._get_param("interval", 0) + + @property + def _retry_code(self) -> str: + """ + Returns "authorization_pending" unless the test has requested something + different. + """ + return self._get_param("retry_code", "authorization_pending") + + @property + def _uri_spelling(self) -> str: + """ + Returns "verification_uri" unless the test has requested something + different. + """ + return self._get_param("uri_spelling", "verification_uri") + + @property + def _response_padding(self): + """ + If the huge_response test parameter is set to True, returns a dict + containing a gigantic string value, which can then be folded into a JSON + response. + """ + if not self._get_param("huge_response", False): + return dict() + + return {"_pad_": "x" * 1024 * 1024} + + @property + def _access_token(self): + """ + The actual Bearer token sent back to the client on success. Tests may + override this with the "token" test parameter. + """ + token = self._get_param("token", None) + if token is not None: + return token + + token = "9243959234" + if self._alt_issuer: + token += "-alt" + + return token + + def _send_json(self, js: JsonObject) -> None: + """ + Sends the provided JSON dict as an application/json response. + self._response_code can be modified to send JSON error responses. + """ + resp = json.dumps(js).encode("ascii") + self.log_message("sending JSON response: %s", resp) + + self.send_response(self._response_code) + self.send_header("Content-Type", self._content_type) + self.send_header("Content-Length", str(len(resp))) + self.end_headers() + + self.wfile.write(resp) + + def config(self) -> JsonObject: + port = self.server.socket.getsockname()[1] + + issuer = f"http://localhost:{port}" + if self._alt_issuer: + issuer += "/alternate" + elif self._parameterized: + issuer += "/param" + + return { + "issuer": issuer, + "token_endpoint": issuer + "/token", + "device_authorization_endpoint": issuer + "/authorize", + "response_types_supported": ["token"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "grant_types_supported": [ + "authorization_code", + "urn:ietf:params:oauth:grant-type:device_code", + ], + } + + @property + def _token_state(self): + """ + A cached _TokenState object for the connected client (as determined by + the request's client_id), or a new one if it doesn't already exist. + + This relies on the existence of a defaultdict attached to the server; + see main() below. + """ + return self.server.token_state[self.client_id] + + def _remove_token_state(self): + """ + Removes any cached _TokenState for the current client_id. Call this + after the token exchange ends to get rid of unnecessary state. + """ + if self.client_id in self.server.token_state: + del self.server.token_state[self.client_id] + + def authorization(self) -> JsonObject: + uri = "https://example.com/" + if self._alt_issuer: + uri = "https://example.org/" + + resp = { + "device_code": "postgres", + "user_code": "postgresuser", + self._uri_spelling: uri, + "expires_in": 5, + **self._response_padding, + } + + interval = self._interval + if interval is not None: + resp["interval"] = interval + self._token_state.min_delay = interval + else: + self._token_state.min_delay = 5 # default + + # Check the scope. + if "scope" in self._params: + assert self._params["scope"][0], "empty scopes should be omitted" + + return resp + + def token(self) -> JsonObject: + if err := self._get_param("error_code", None): + self._response_code = self._get_param("error_status", 400) + + resp = {"error": err} + if desc := self._get_param("error_desc", ""): + resp["error_description"] = desc + + return resp + + if self._should_modify() and "retries" in self._test_params: + retries = self._test_params["retries"] + + # Check to make sure the token interval is being respected. + now = time.monotonic() + if self._token_state.last_try is not None: + delay = now - self._token_state.last_try + assert ( + delay > self._token_state.min_delay + ), f"client waited only {delay} seconds between token requests (expected {self._token_state.min_delay})" + + self._token_state.last_try = now + + # If we haven't reached the required number of retries yet, return a + # "pending" response. + if self._token_state.retries < retries: + self._token_state.retries += 1 + + self._response_code = 400 + return {"error": self._retry_code} + + # Clean up any retry tracking state now that the exchange is ending. + self._remove_token_state() + + return { + "access_token": self._access_token, + "token_type": "bearer", + **self._response_padding, + } + + +def main(): + """ + Starts the authorization server on localhost. The ephemeral port in use will + be printed to stdout. + """ + + s = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler) + + # Attach a "cache" dictionary to the server to allow the OAuthHandlers to + # track state across token requests. The use of defaultdict ensures that new + # entries will be created automatically. + class _TokenState: + retries = 0 + min_delay = None + last_try = None + + s.token_state = defaultdict(_TokenState) + + # Give the parent the port number to contact (this is also the signal that + # we're ready to receive requests). + port = s.socket.getsockname()[1] + print(port) + + # stdout is closed to allow the parent to just "read to the end". + stdout = sys.stdout.fileno() + sys.stdout.close() + os.close(stdout) + + s.serve_forever() # we expect our parent to send a termination signal + + +if __name__ == "__main__": + main() |