Skip to content

Commit

Permalink
Merge pull request #92 from IdentityPython/entity_id2base_url
Browse files Browse the repository at this point in the history
In case base_url is not defined use entity_id.
  • Loading branch information
rohe authored Jan 10, 2024
2 parents 195e0c2 + 023f058 commit f20b465
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 17 deletions.
17 changes: 12 additions & 5 deletions src/idpyoidc/claims.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""):

return keyjar, _uri_path

def get_base_url(self, configuration: dict):
def get_base_url(self, configuration: dict, entity_id: Optional[str]=""):
raise NotImplementedError()

def get_id(self, configuration: dict):
Expand All @@ -134,7 +134,10 @@ def add_extra_keys(self, keyjar, id):
def get_jwks(self, keyjar):
return keyjar.export_jwks()

def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None):
def handle_keys(self,
configuration: dict,
keyjar: Optional[KeyJar] = None,
entity_id: Optional[str] = ""):
_jwks = _jwks_uri = None
_id = self.get_id(configuration)
keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id)
Expand All @@ -147,15 +150,19 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None):
if "jwks_uri" in configuration: # simple
_jwks_uri = configuration.get("jwks_uri")
elif uri_path:
_base_url = self.get_base_url(configuration)
_base_url = self.get_base_url(configuration, entity_id=entity_id)
_jwks_uri = add_path(_base_url, uri_path)
else: # jwks or nothing
_jwks = self.get_jwks(keyjar)

return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri}

def load_conf(
self, configuration: dict, supports: dict, keyjar: Optional[KeyJar] = None
self,
configuration: dict,
supports: dict,
keyjar: Optional[KeyJar] = None,
entity_id: Optional[str] = ""
) -> KeyJar:
for attr, val in configuration.items():
if attr in ["preference", "capabilities"]:
Expand All @@ -167,7 +174,7 @@ def load_conf(

self.locals(configuration)

for key, val in self.handle_keys(configuration, keyjar=keyjar).items():
for key, val in self.handle_keys(configuration, keyjar=keyjar, entity_id=entity_id).items():
if key == "keyjar":
keyjar = val
elif val:
Expand Down
9 changes: 7 additions & 2 deletions src/idpyoidc/client/claims/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from cryptojwt import KeyJar
from cryptojwt.exception import IssuerNotFound
from cryptojwt.jwk.hmac import SYMKey
Expand All @@ -11,10 +13,13 @@ def get_client_authn_methods():


class Claims(claims.Claims):
def get_base_url(self, configuration: dict):
def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""):
_base = configuration.get("base_url")
if not _base:
_base = configuration.get("client_id")
if entity_id:
_base = entity_id
else:
_base = configuration.get("client_id")

return _base

Expand Down
3 changes: 2 additions & 1 deletion src/idpyoidc/client/service_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def __init__(
for key, val in kwargs.items():
setattr(self, key, val)

self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar)
self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar,
entity_id=self.entity_id)

_jwks_uri = self.provider_info.get("jwks_uri")
if _jwks_uri:
Expand Down
17 changes: 10 additions & 7 deletions src/idpyoidc/metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from functools import cmp_to_key
import logging
from typing import Callable
from typing import Optional

Expand Down Expand Up @@ -128,7 +128,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""):
_uri_path = conf["key_conf"].get("uri_path")
return keyjar, _uri_path

def get_base_url(self, configuration: dict):
def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""):
raise NotImplementedError()

def get_id(self, configuration: dict):
Expand All @@ -140,9 +140,11 @@ def add_extra_keys(self, keyjar, id):
def get_jwks(self, keyjar):
return None

def handle_keys(
self, configuration: dict, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = ""
):
def handle_keys(self,
configuration: dict,
keyjar: Optional[KeyJar] = None,
base_url: Optional[str] = "",
entity_id: Optional[str] = ""):
_jwks = _jwks_uri = None
_id = self.get_id(configuration)
keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id)
Expand All @@ -154,15 +156,16 @@ def handle_keys(
_jwks_uri = configuration.get("jwks_uri")
elif uri_path:
if not base_url:
base_url = self.get_base_url(configuration)
base_url = self.get_base_url(configuration, entity_id=entity_id)
_jwks_uri = add_path(base_url, uri_path)
else: # jwks or nothing
_jwks = self.get_jwks(keyjar)

return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri}

def load_conf(
self, configuration, supports, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = ""
self, configuration, supports, keyjar: Optional[KeyJar] = None,
base_url: Optional[str] = ""
):
for attr, val in configuration.items():
if attr == "preference":
Expand Down
7 changes: 5 additions & 2 deletions src/idpyoidc/server/claims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@


class Claims(claims.Claims):
def get_base_url(self, configuration: dict):
def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""):
_base = configuration.get("base_url")
if not _base:
_base = configuration.get("issuer")
if entity_id:
_base = entity_id
else:
_base = configuration.get("issuer")

return _base

Expand Down

0 comments on commit f20b465

Please sign in to comment.