aboutsummaryrefslogtreecommitdiff
path: root/src/test/modules/oauth_validator/t/oauth_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/modules/oauth_validator/t/oauth_server.py')
-rwxr-xr-xsrc/test/modules/oauth_validator/t/oauth_server.py391
1 files changed, 391 insertions, 0 deletions
diff --git a/src/test/modules/oauth_validator/t/oauth_server.py b/src/test/modules/oauth_validator/t/oauth_server.py
new file mode 100755
index 00000000000..4faf3323d38
--- /dev/null
+++ b/src/test/modules/oauth_validator/t/oauth_server.py
@@ -0,0 +1,391 @@
+#! /usr/bin/env python3
+#
+# A mock OAuth authorization server, designed to be invoked from
+# OAuth/Server.pm. This listens on an ephemeral port number (printed to stdout
+# so that the Perl tests can contact it) and runs as a daemon until it is
+# signaled.
+#
+
+import base64
+import http.server
+import json
+import os
+import sys
+import time
+import urllib.parse
+from collections import defaultdict
+
+
+class OAuthHandler(http.server.BaseHTTPRequestHandler):
+ """
+ Core implementation of the authorization server. The API is
+ inheritance-based, with entry points at do_GET() and do_POST(). See the
+ documentation for BaseHTTPRequestHandler.
+ """
+
+ JsonObject = dict[str, object] # TypeAlias is not available until 3.10
+
+ def _check_issuer(self):
+ """
+ Switches the behavior of the provider depending on the issuer URI.
+ """
+ self._alt_issuer = (
+ self.path.startswith("/alternate/")
+ or self.path == "/.well-known/oauth-authorization-server/alternate"
+ )
+ self._parameterized = self.path.startswith("/param/")
+
+ if self._alt_issuer:
+ # The /alternate issuer uses IETF-style .well-known URIs.
+ if self.path.startswith("/.well-known/"):
+ self.path = self.path.removesuffix("/alternate")
+ else:
+ self.path = self.path.removeprefix("/alternate")
+ elif self._parameterized:
+ self.path = self.path.removeprefix("/param")
+
+ def _check_authn(self):
+ """
+ Checks the expected value of the Authorization header, if any.
+ """
+ secret = self._get_param("expected_secret", None)
+ if secret is None:
+ return
+
+ assert "Authorization" in self.headers
+ method, creds = self.headers["Authorization"].split()
+
+ if method != "Basic":
+ raise RuntimeError(f"client used {method} auth; expected Basic")
+
+ username = urllib.parse.quote_plus(self.client_id)
+ password = urllib.parse.quote_plus(secret)
+ expected_creds = f"{username}:{password}"
+
+ if creds.encode() != base64.b64encode(expected_creds.encode()):
+ raise RuntimeError(
+ f"client sent '{creds}'; expected b64encode('{expected_creds}')"
+ )
+
+ def do_GET(self):
+ self._response_code = 200
+ self._check_issuer()
+
+ config_path = "/.well-known/openid-configuration"
+ if self._alt_issuer:
+ config_path = "/.well-known/oauth-authorization-server"
+
+ if self.path == config_path:
+ resp = self.config()
+ else:
+ self.send_error(404, "Not Found")
+ return
+
+ self._send_json(resp)
+
+ def _parse_params(self) -> dict[str, str]:
+ """
+ Parses apart the form-urlencoded request body and returns the resulting
+ dict. For use by do_POST().
+ """
+ size = int(self.headers["Content-Length"])
+ form = self.rfile.read(size)
+
+ assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
+ return urllib.parse.parse_qs(
+ form.decode("utf-8"),
+ strict_parsing=True,
+ keep_blank_values=True,
+ encoding="utf-8",
+ errors="strict",
+ )
+
+ @property
+ def client_id(self) -> str:
+ """
+ Returns the client_id sent in the POST body or the Authorization header.
+ self._parse_params() must have been called first.
+ """
+ if "client_id" in self._params:
+ return self._params["client_id"][0]
+
+ if "Authorization" not in self.headers:
+ raise RuntimeError("client did not send any client_id")
+
+ _, creds = self.headers["Authorization"].split()
+
+ decoded = base64.b64decode(creds).decode("utf-8")
+ username, _ = decoded.split(":", 1)
+
+ return urllib.parse.unquote_plus(username)
+
+ def do_POST(self):
+ self._response_code = 200
+ self._check_issuer()
+
+ self._params = self._parse_params()
+ if self._parameterized:
+ # Pull encoded test parameters out of the peer's client_id field.
+ # This is expected to be Base64-encoded JSON.
+ js = base64.b64decode(self.client_id)
+ self._test_params = json.loads(js)
+
+ self._check_authn()
+
+ if self.path == "/authorize":
+ resp = self.authorization()
+ elif self.path == "/token":
+ resp = self.token()
+ else:
+ self.send_error(404)
+ return
+
+ self._send_json(resp)
+
+ def _should_modify(self) -> bool:
+ """
+ Returns True if the client has requested a modification to this stage of
+ the exchange.
+ """
+ if not hasattr(self, "_test_params"):
+ return False
+
+ stage = self._test_params.get("stage")
+
+ return (
+ stage == "all"
+ or (
+ stage == "discovery"
+ and self.path == "/.well-known/openid-configuration"
+ )
+ or (stage == "device" and self.path == "/authorize")
+ or (stage == "token" and self.path == "/token")
+ )
+
+ def _get_param(self, name, default):
+ """
+ If the client has requested a modification to this stage (see
+ _should_modify()), this method searches the provided test parameters for
+ a key of the given name, and returns it if found. Otherwise the provided
+ default is returned.
+ """
+ if self._should_modify() and name in self._test_params:
+ return self._test_params[name]
+
+ return default
+
+ @property
+ def _content_type(self) -> str:
+ """
+ Returns "application/json" unless the test has requested something
+ different.
+ """
+ return self._get_param("content_type", "application/json")
+
+ @property
+ def _interval(self) -> int:
+ """
+ Returns 0 unless the test has requested something different.
+ """
+ return self._get_param("interval", 0)
+
+ @property
+ def _retry_code(self) -> str:
+ """
+ Returns "authorization_pending" unless the test has requested something
+ different.
+ """
+ return self._get_param("retry_code", "authorization_pending")
+
+ @property
+ def _uri_spelling(self) -> str:
+ """
+ Returns "verification_uri" unless the test has requested something
+ different.
+ """
+ return self._get_param("uri_spelling", "verification_uri")
+
+ @property
+ def _response_padding(self):
+ """
+ If the huge_response test parameter is set to True, returns a dict
+ containing a gigantic string value, which can then be folded into a JSON
+ response.
+ """
+ if not self._get_param("huge_response", False):
+ return dict()
+
+ return {"_pad_": "x" * 1024 * 1024}
+
+ @property
+ def _access_token(self):
+ """
+ The actual Bearer token sent back to the client on success. Tests may
+ override this with the "token" test parameter.
+ """
+ token = self._get_param("token", None)
+ if token is not None:
+ return token
+
+ token = "9243959234"
+ if self._alt_issuer:
+ token += "-alt"
+
+ return token
+
+ def _send_json(self, js: JsonObject) -> None:
+ """
+ Sends the provided JSON dict as an application/json response.
+ self._response_code can be modified to send JSON error responses.
+ """
+ resp = json.dumps(js).encode("ascii")
+ self.log_message("sending JSON response: %s", resp)
+
+ self.send_response(self._response_code)
+ self.send_header("Content-Type", self._content_type)
+ self.send_header("Content-Length", str(len(resp)))
+ self.end_headers()
+
+ self.wfile.write(resp)
+
+ def config(self) -> JsonObject:
+ port = self.server.socket.getsockname()[1]
+
+ issuer = f"http://localhost:{port}"
+ if self._alt_issuer:
+ issuer += "/alternate"
+ elif self._parameterized:
+ issuer += "/param"
+
+ return {
+ "issuer": issuer,
+ "token_endpoint": issuer + "/token",
+ "device_authorization_endpoint": issuer + "/authorize",
+ "response_types_supported": ["token"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ "grant_types_supported": [
+ "authorization_code",
+ "urn:ietf:params:oauth:grant-type:device_code",
+ ],
+ }
+
+ @property
+ def _token_state(self):
+ """
+ A cached _TokenState object for the connected client (as determined by
+ the request's client_id), or a new one if it doesn't already exist.
+
+ This relies on the existence of a defaultdict attached to the server;
+ see main() below.
+ """
+ return self.server.token_state[self.client_id]
+
+ def _remove_token_state(self):
+ """
+ Removes any cached _TokenState for the current client_id. Call this
+ after the token exchange ends to get rid of unnecessary state.
+ """
+ if self.client_id in self.server.token_state:
+ del self.server.token_state[self.client_id]
+
+ def authorization(self) -> JsonObject:
+ uri = "https://example.com/"
+ if self._alt_issuer:
+ uri = "https://example.org/"
+
+ resp = {
+ "device_code": "postgres",
+ "user_code": "postgresuser",
+ self._uri_spelling: uri,
+ "expires_in": 5,
+ **self._response_padding,
+ }
+
+ interval = self._interval
+ if interval is not None:
+ resp["interval"] = interval
+ self._token_state.min_delay = interval
+ else:
+ self._token_state.min_delay = 5 # default
+
+ # Check the scope.
+ if "scope" in self._params:
+ assert self._params["scope"][0], "empty scopes should be omitted"
+
+ return resp
+
+ def token(self) -> JsonObject:
+ if err := self._get_param("error_code", None):
+ self._response_code = self._get_param("error_status", 400)
+
+ resp = {"error": err}
+ if desc := self._get_param("error_desc", ""):
+ resp["error_description"] = desc
+
+ return resp
+
+ if self._should_modify() and "retries" in self._test_params:
+ retries = self._test_params["retries"]
+
+ # Check to make sure the token interval is being respected.
+ now = time.monotonic()
+ if self._token_state.last_try is not None:
+ delay = now - self._token_state.last_try
+ assert (
+ delay > self._token_state.min_delay
+ ), f"client waited only {delay} seconds between token requests (expected {self._token_state.min_delay})"
+
+ self._token_state.last_try = now
+
+ # If we haven't reached the required number of retries yet, return a
+ # "pending" response.
+ if self._token_state.retries < retries:
+ self._token_state.retries += 1
+
+ self._response_code = 400
+ return {"error": self._retry_code}
+
+ # Clean up any retry tracking state now that the exchange is ending.
+ self._remove_token_state()
+
+ return {
+ "access_token": self._access_token,
+ "token_type": "bearer",
+ **self._response_padding,
+ }
+
+
+def main():
+ """
+ Starts the authorization server on localhost. The ephemeral port in use will
+ be printed to stdout.
+ """
+
+ s = http.server.HTTPServer(("127.0.0.1", 0), OAuthHandler)
+
+ # Attach a "cache" dictionary to the server to allow the OAuthHandlers to
+ # track state across token requests. The use of defaultdict ensures that new
+ # entries will be created automatically.
+ class _TokenState:
+ retries = 0
+ min_delay = None
+ last_try = None
+
+ s.token_state = defaultdict(_TokenState)
+
+ # Give the parent the port number to contact (this is also the signal that
+ # we're ready to receive requests).
+ port = s.socket.getsockname()[1]
+ print(port)
+
+ # stdout is closed to allow the parent to just "read to the end".
+ stdout = sys.stdout.fileno()
+ sys.stdout.close()
+ os.close(stdout)
+
+ s.serve_forever() # we expect our parent to send a termination signal
+
+
+if __name__ == "__main__":
+ main()