Skip to content

Commit f7d26df

Browse files
SK-2645: Merge branch 'release/26.1.0' into release/26.1.4
2 parents 87cf3ae + 3294b09 commit f7d26df

12 files changed

Lines changed: 450 additions & 48 deletions

File tree

skyflow/service_account/_utils.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import datetime
33
import time
44
import jwt
5+
from urllib.parse import urlparse
56
from skyflow.error import SkyflowError
67
from skyflow.service_account.client.auth_client import AuthClient
78
from skyflow.utils.logger import log_info, log_error_log
89
from skyflow.utils import get_base_url, format_scope, SkyflowMessages
910
from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField
11+
from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError
12+
from skyflow.utils import is_valid_url
1013

1114

1215
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value
@@ -79,7 +82,14 @@ def get_service_account_token(credentials, options, logger):
7982
except:
8083
log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger)
8184
raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code)
82-
85+
86+
if not isinstance(token_uri, str) or not is_valid_url(token_uri):
87+
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger)
88+
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)
89+
90+
if options and "token_uri" in options:
91+
token_uri = options["token_uri"]
92+
8393
signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger)
8494
base_url = get_base_url(token_uri)
8595
auth_client = AuthClient(base_url)
@@ -89,10 +99,17 @@ def get_service_account_token(credentials, options, logger):
8999
if options and OptionField.ROLE_IDS in options:
90100
formatted_scope = format_scope(options.get(OptionField.ROLE_IDS))
91101

92-
response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
93-
grant_type=JWT.GRANT_TYPE_JWT_BEARER,
102+
try:
103+
response = auth_api.authentication_service_get_auth_token(assertion = signed_token,
104+
grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer",
94105
scope=formatted_scope)
95-
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
106+
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger)
107+
except UnauthorizedError:
108+
log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger)
109+
raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code)
110+
except Exception:
111+
log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger)
112+
raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code)
96113
return response.access_token, response.token_type
97114

98115
def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
@@ -113,32 +130,41 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger):
113130

114131

115132
def get_signed_tokens(credentials_obj, options):
116-
try:
117-
expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60)
118-
prefix = JWT.SIGNED_TOKEN_PREFIX
119-
120-
if options and options.get(OptionField.DATA_TOKENS):
121-
for token in options[OptionField.DATA_TOKENS]:
122-
claims = {
123-
JwtField.ISS: JWT.ISSUER_SDK,
124-
JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID),
125-
JwtField.EXP: expiry_time,
126-
JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID),
127-
JwtField.TOK: token,
128-
JwtField.IAT: int(time.time()),
129-
}
130-
131-
if JwtField.CTX in options:
132-
claims[JwtField.CTX] = options[JwtField.CTX]
133-
134-
private_key = credentials_obj.get(CredentialField.PRIVATE_KEY)
133+
expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60)
134+
prefix = JWT.SIGNED_TOKEN_PREFIX
135+
136+
token_uri = credentials_obj.get("tokenURI")
137+
if not isinstance(token_uri, str) or not is_valid_url(token_uri):
138+
log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value)
139+
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)
140+
141+
if options and "token_uri" in options:
142+
token_uri = options["token_uri"]
143+
144+
145+
if options and options.get(OptionField.DATA_TOKENS):
146+
for token in options[OptionField.DATA_TOKENS]:
147+
claims = {
148+
JwtField.ISS: JWT.ISSUER_SDK,
149+
JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID),
150+
JwtField.EXP: expiry_time,
151+
JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID),
152+
JwtField.TOK: token,
153+
JwtField.IAT: int(time.time()),
154+
}
155+
156+
if JwtField.CTX in options:
157+
claims[JwtField.CTX] = options[JwtField.CTX]
158+
159+
private_key = credentials_obj.get(CredentialField.PRIVATE_KEY)
160+
try:
135161
signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256)
136-
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
137-
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
138-
return response_object
162+
except Exception:
163+
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
139164

140-
except Exception:
141-
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
165+
response_object = get_signed_data_token_response_object(prefix + signed_jwt, token)
166+
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value)
167+
return response_object
142168

143169

144170
def generate_signed_data_tokens(credentials_file_path, options):

skyflow/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ..utils.enums import LogLevel, Env, TokenType
22
from ._skyflow_messages import SkyflowMessages
33
from ._version import SDK_VERSION
4-
from ._helpers import get_base_url, format_scope
4+
from ._helpers import get_base_url, format_scope, is_valid_url
55
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info

skyflow/utils/_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,11 @@ def get_base_url(url):
88
def format_scope(scopes):
99
if not scopes:
1010
return None
11-
return " ".join([f"role:{scope}" for scope in scopes])
11+
return " ".join([f"role:{scope}" for scope in scopes])
12+
13+
def is_valid_url(url):
14+
try:
15+
result = urlparse(url)
16+
return all([result.scheme in ("http", "https"), result.netloc])
17+
except Exception:
18+
return False

skyflow/utils/_skyflow_messages.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,13 @@ class Error(Enum):
158158
MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID."
159159
MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID."
160160
MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI."
161+
INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL."
161162
JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials."
162163
JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials."
163164
FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents."
164165
INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable."
166+
FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token."
167+
UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token."
165168

166169
INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text."
167170
INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities."
@@ -336,6 +339,8 @@ class ErrorLogs(Enum):
336339
KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required."
337340
TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required."
338341
INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials."
342+
FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token."
343+
UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token."
339344

340345

341346
TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required."

skyflow/utils/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SDK_VERSION = '2.0.0.dev0+e33ae92'
1+
SDK_VERSION = '2.0.0.dev0+e33ae92'

skyflow/utils/validations/_validations.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \
1616
GetDetectRunRequest, Bleep, DeidentifyFileRequest
1717
from skyflow.vault.detect._file_input import FileInput
18+
from skyflow.utils._helpers import is_valid_url
1819

1920
valid_vault_config_keys = [
2021
ConfigField.VAULT_ID,
@@ -158,6 +159,15 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
158159
raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id)
159160
if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value,
160161
invalid_input_error_code)
162+
163+
if "token_uri" in credentials:
164+
token_uri = credentials.get("token_uri")
165+
if (
166+
token_uri is None
167+
or not isinstance(token_uri, str)
168+
or not is_valid_url(token_uri)
169+
):
170+
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code)
161171

162172
def validate_log_level(logger, log_level):
163173
if not isinstance(log_level, LogLevel):
@@ -222,10 +232,8 @@ def validate_update_vault_config(logger, config):
222232
if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env:
223233
raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code)
224234

225-
if ConfigField.CREDENTIALS not in config:
226-
raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code)
227-
228-
validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id)
235+
if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS):
236+
validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id)
229237

230238
return True
231239

skyflow/vault/client/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def get_bearer_token(self, credentials):
6464
OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES),
6565
OptionField.CTX: self.__config.get(OptionField.CTX)
6666
}
67+
if "token_uri" in credentials and credentials.get("token_uri"):
68+
options["token_uri"] = credentials.get("token_uri")
6769

6870
if self.__bearer_token is None or self.__is_config_updated:
6971
if CredentialField.PATH in credentials:

tests/service_account/test__utils.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,73 @@ def test_generate_signed_data_tokens_from_creds_with_invalid_string(self):
143143
credentials_string = '{'
144144
with self.assertRaises(SkyflowError) as context:
145145
result = generate_signed_data_tokens_from_creds(credentials_string, options)
146-
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)
146+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value)
147+
148+
@patch("skyflow.service_account._utils.AuthClient")
149+
@patch("skyflow.service_account._utils.get_signed_jwt")
150+
def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client):
151+
creds = {
152+
"privateKey": "private_key",
153+
"clientID": "client_id",
154+
"keyID": "key_id",
155+
"tokenURI": "https://valid-url.com"
156+
}
157+
options = {"role_ids": ["role1", "role2"]}
158+
mock_get_signed_jwt.return_value = "signed"
159+
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
160+
mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token",
161+
"token_type": "bearer"})
162+
access_token, token_type = get_service_account_token(creds, options, None)
163+
self.assertEqual(access_token, "token")
164+
self.assertEqual(token_type, "bearer")
165+
args, kwargs = mock_auth_api.authentication_service_get_auth_token.call_args
166+
self.assertIn("scope", kwargs)
167+
self.assertEqual(kwargs["scope"], "role:role1 role:role2")
168+
169+
@patch("skyflow.service_account._utils.AuthClient")
170+
@patch("skyflow.service_account._utils.get_signed_jwt")
171+
def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client):
172+
creds = {
173+
"privateKey": "private_key",
174+
"clientID": "client_id",
175+
"keyID": "key_id",
176+
"tokenURI": "https://valid-url.com"
177+
}
178+
mock_get_signed_jwt.return_value = "signed"
179+
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
180+
from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError
181+
mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized")
182+
with self.assertRaises(SkyflowError) as context:
183+
get_service_account_token(creds, {}, None)
184+
self.assertEqual(context.exception.message,
185+
SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value)
186+
187+
@patch("skyflow.service_account._utils.AuthClient")
188+
@patch("skyflow.service_account._utils.get_signed_jwt")
189+
def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client):
190+
creds = {
191+
"privateKey": "private_key",
192+
"clientID": "client_id",
193+
"keyID": "key_id",
194+
"tokenURI": "https://valid-url.com"
195+
}
196+
mock_get_signed_jwt.return_value = "signed"
197+
mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value
198+
mock_auth_api.authentication_service_get_auth_token.side_effect = Exception("some error")
199+
with self.assertRaises(SkyflowError) as context:
200+
get_service_account_token(creds, {}, None)
201+
self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value)
202+
203+
@patch("jwt.encode", side_effect=Exception("jwt error"))
204+
def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode):
205+
creds = {
206+
"privateKey": "private_key",
207+
"clientID": "client_id",
208+
"keyID": "key_id",
209+
"tokenURI": "https://valid-url.com"
210+
}
211+
options = {"data_tokens": ["token1"]}
212+
with self.assertRaises(SkyflowError) as context:
213+
from skyflow.service_account._utils import get_signed_tokens
214+
get_signed_tokens(creds, options)
215+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)

tests/utils/test__helpers.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from skyflow.utils import get_base_url, format_scope
2+
from skyflow.utils import get_base_url, format_scope, is_valid_url
33

44
VALID_URL = "https://example.com/path?query=1"
55
BASE_URL = "https://example.com"
@@ -35,4 +35,27 @@ def test_format_scope_single_scope(self):
3535
def test_format_scope_special_characters(self):
3636
scopes_with_special_chars = ["admin", "user:write", "read-only"]
3737
expected_result = "role:admin role:user:write role:read-only"
38-
self.assertEqual(format_scope(scopes_with_special_chars), expected_result)
38+
self.assertEqual(format_scope(scopes_with_special_chars), expected_result)
39+
40+
def test_is_valid_url_valid(self):
41+
self.assertTrue(is_valid_url("https://example.com"))
42+
self.assertTrue(is_valid_url("http://example.com/path"))
43+
44+
def test_is_valid_url_invalid(self):
45+
self.assertFalse(is_valid_url("ftp://example.com"))
46+
self.assertFalse(is_valid_url("example.com"))
47+
self.assertFalse(is_valid_url("invalid-url"))
48+
self.assertFalse(is_valid_url(""))
49+
50+
def test_is_valid_url_none(self):
51+
self.assertFalse(is_valid_url(None))
52+
53+
def test_is_valid_url_no_scheme(self):
54+
self.assertFalse(is_valid_url("www.example.com"))
55+
56+
def test_is_valid_url_exception(self):
57+
class BadStr:
58+
def __str__(self):
59+
raise Exception("bad str")
60+
61+
self.assertFalse(is_valid_url(BadStr()))

0 commit comments

Comments
 (0)