Skip to content

Commit 8bdd711

Browse files
SK-2270: add more unit tests scenario for validation methods
1 parent c12edd7 commit 8bdd711

File tree

1 file changed

+179
-21
lines changed

1 file changed

+179
-21
lines changed

tests/utils/validations/test__validations.py

Lines changed: 179 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import Mock, patch, MagicMock
33
import tempfile
44
import os
5+
56
from skyflow.error import SkyflowError
67
from skyflow.utils.validations._validations import (
78
validate_required_field, validate_api_key, validate_credentials,
@@ -14,10 +15,12 @@
1415
validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request
1516
)
1617
from skyflow.utils import SkyflowMessages
17-
from skyflow.utils.enums import DetectEntities
18-
from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, FileInput
19-
from skyflow.vault.detect._deidentify_file_request import DeidentifyFileRequest
20-
18+
from skyflow.utils.enums import DetectEntities, RedactionType
19+
from skyflow.vault.data import GetRequest, UpdateRequest
20+
from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, \
21+
FileInput, DeidentifyFileRequest
22+
from skyflow.vault.tokens import DetokenizeRequest
23+
from skyflow.vault.connection._invoke_connection_request import InvokeConnectionRequest
2124

2225
class TestValidations(unittest.TestCase):
2326
@classmethod
@@ -154,6 +157,10 @@ class InvalidEnum:
154157
validate_log_level(self.logger, invalid_log_level)
155158
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_LOG_LEVEL.value)
156159

160+
def test_validate_log_level_none(self):
161+
with self.assertRaises(SkyflowError) as context:
162+
validate_log_level(self.logger, None)
163+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_LOG_LEVEL.value)
157164

158165
def test_validate_keys_valid(self):
159166
config = {"vault_id": "test_id", "cluster_id": "test_cluster"}
@@ -389,7 +396,7 @@ def test_validate_query_request_empty_query(self):
389396

390397
def test_validate_query_request_invalid_query_type(self):
391398
request = MagicMock()
392-
request.query = 123 # Invalid type
399+
request.query = 123
393400
with self.assertRaises(SkyflowError) as context:
394401
validate_query_request(self.logger, request)
395402
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(str(type(123))))
@@ -438,7 +445,7 @@ def test_validate_get_request_valid(self):
438445

439446
def test_validate_get_request_invalid_table_type(self):
440447
request = MagicMock()
441-
request.table = 123 # Invalid type
448+
request.table = 123
442449
with self.assertRaises(SkyflowError) as context:
443450
validate_get_request(self.logger, request)
444451
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value)
@@ -450,6 +457,63 @@ def test_validate_get_request_empty_table(self):
450457
validate_get_request(self.logger, request)
451458
self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value)
452459

460+
def test_validate_get_request_invalid_redaction_type(self):
461+
request = GetRequest(
462+
table="test_table",
463+
fields="invalid",
464+
ids=["id1", "id2"],
465+
redaction_type="invalid"
466+
)
467+
468+
with self.assertRaises(SkyflowError) as context:
469+
validate_get_request(self.logger, request)
470+
self.assertEqual(context.exception.message,
471+
SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(request.redaction_type)))
472+
473+
def test_validate_get_request_invalid_fields_type(self):
474+
request= GetRequest(
475+
table="test_table",
476+
fields="invalid"
477+
)
478+
with self.assertRaises(SkyflowError) as context:
479+
validate_get_request(self.logger, request)
480+
self.assertEqual(context.exception.message,
481+
SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(request.fields)))
482+
483+
def test_validate_get_request_empty_fields(self):
484+
request = GetRequest(
485+
table="test_table",
486+
ids=[],
487+
fields=[]
488+
)
489+
with self.assertRaises(SkyflowError) as context:
490+
validate_get_request(self.logger, request)
491+
self.assertEqual(context.exception.message,
492+
SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(request.fields)))
493+
494+
def test_validate_get_request_invalid_column_values_type(self):
495+
request = GetRequest(
496+
table="test_table",
497+
column_name="test_column",
498+
column_values="invalid",
499+
)
500+
501+
with self.assertRaises(SkyflowError) as context:
502+
validate_get_request(self.logger, request)
503+
self.assertEqual(context.exception.message,
504+
SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(request.column_values)))
505+
506+
def test_validate_get_request_tokens_with_redaction(self):
507+
request = GetRequest(
508+
table="test_table",
509+
return_tokens=True,
510+
redaction_type = RedactionType.PLAIN_TEXT
511+
)
512+
513+
with self.assertRaises(SkyflowError) as context:
514+
validate_get_request(self.logger, request)
515+
self.assertEqual(context.exception.message,
516+
SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value)
453517

454518
def test_validate_query_request_valid_complex(self):
455519
request = MagicMock()
@@ -473,16 +537,25 @@ def test_validate_update_request_valid(self):
473537
request.token_mode = None
474538
request.tokens = None
475539
validate_update_request(self.logger, request)
476-
477540

478541
def test_validate_update_request_invalid_table_type(self):
479-
request = MagicMock()
480-
request.table = 123 # Invalid type
481-
request.data = {"skyflow_id": "id123"}
542+
request = UpdateRequest(
543+
table=123,
544+
data = {"skyflow_id": "id123"}
545+
)
482546
with self.assertRaises(SkyflowError) as context:
483547
validate_update_request(self.logger, request)
484548
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value)
485549

550+
def test_validate_update_request_invalid_token_mode(self):
551+
request = UpdateRequest(
552+
table="test_table",
553+
data = {"skyflow_id": "id123", "field1": "value1"},
554+
token_mode = "invalid"
555+
)
556+
with self.assertRaises(SkyflowError) as context:
557+
validate_update_request(self.logger, request)
558+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value)
486559

487560
def test_validate_detokenize_request_valid(self):
488561
request = MagicMock()
@@ -500,7 +573,7 @@ def test_validate_detokenize_request_empty_data(self):
500573

501574
def test_validate_detokenize_request_invalid_token(self):
502575
request = MagicMock()
503-
request.data = [{"token": 123}]
576+
request.data = [{"token": 123}] # Invalid token type
504577
request.continue_on_error = False
505578
with self.assertRaises(SkyflowError) as context:
506579
validate_detokenize_request(self.logger, request)
@@ -515,15 +588,15 @@ def test_validate_tokenize_request_valid(self):
515588

516589
def test_validate_tokenize_request_invalid_values_type(self):
517590
request = MagicMock()
518-
request.values = "invalid"
591+
request.values = "invalid" # Should be list
519592
with self.assertRaises(SkyflowError) as context:
520593
validate_tokenize_request(self.logger, request)
521594
self.assertEqual(context.exception.message,
522595
SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(request.values)))
523596

524597
def test_validate_tokenize_request_empty_values(self):
525598
request = MagicMock()
526-
request.values = []
599+
request.values = [] # Empty list
527600
with self.assertRaises(SkyflowError) as context:
528601
validate_tokenize_request(self.logger, request)
529602
self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value)
@@ -541,18 +614,56 @@ def test_validate_invoke_connection_params_valid(self):
541614
path_params = {"path1": "value1"}
542615
validate_invoke_connection_params(self.logger, query_params, path_params)
543616

544-
def test_validate_invoke_connection_params_invalid_path_params(self):
545-
query_params = {"param1": "value1"}
546-
path_params = "invalid" # Should be dict
617+
def test_validate_invoke_connection_params_invalid_path_params_type(self):
618+
request = InvokeConnectionRequest(
619+
method="GET",
620+
query_params={"param1": "value1"},
621+
path_params="invalid"
622+
)
547623
with self.assertRaises(SkyflowError) as context:
548-
validate_invoke_connection_params(self.logger, query_params, path_params)
624+
validate_invoke_connection_params(self.logger, request.query_params, request.path_params)
549625
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_PATH_PARAMS.value)
550626

551-
def test_validate_invoke_connection_params_invalid_query_params(self):
552-
query_params = "invalid" # Should be dict
553-
path_params = {"path1": "value1"}
627+
def test_validate_invoke_connection_params_invalid_query_params_type(self):
628+
request = InvokeConnectionRequest(
629+
method="GET",
630+
query_params="invalid",
631+
path_params={"path1": "value1"}
632+
)
554633
with self.assertRaises(SkyflowError) as context:
555-
validate_invoke_connection_params(self.logger, query_params, path_params)
634+
validate_invoke_connection_params(self.logger, request.query_params, request.path_params)
635+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_PARAMS.value)
636+
637+
def test_validate_invoke_connection_params_non_string_path_param(self):
638+
request = InvokeConnectionRequest(
639+
method="GET",
640+
query_params={"param1": "value1"},
641+
path_params={1: "value1"}
642+
)
643+
with self.assertRaises(SkyflowError) as context:
644+
validate_invoke_connection_params(self.logger, request.query_params, request.path_params)
645+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_PATH_PARAMS.value)
646+
647+
def test_validate_invoke_connection_params_non_string_query_param_key(self):
648+
request = InvokeConnectionRequest(
649+
method="GET",
650+
query_params={1: "value1"},
651+
path_params={"path1": "value1"}
652+
)
653+
with self.assertRaises(SkyflowError) as context:
654+
validate_invoke_connection_params(self.logger, request.query_params, request.path_params)
655+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_PARAMS.value)
656+
657+
def test_validate_invoke_connection_params_non_serializable_query_params(self):
658+
class NonSerializable:
659+
pass
660+
request = InvokeConnectionRequest(
661+
method="GET",
662+
query_params={"param1": NonSerializable()},
663+
path_params={"path1": "value1"}
664+
)
665+
with self.assertRaises(SkyflowError) as context:
666+
validate_invoke_connection_params(self.logger, request.query_params, request.path_params)
556667
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_PARAMS.value)
557668

558669
def test_validate_deidentify_text_request_valid(self):
@@ -717,6 +828,16 @@ def test_validate_reidentify_text_request_invalid_redacted_entities(self):
717828
self.assertEqual(context.exception.message,
718829
SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value)
719830

831+
def test_validate_reidentify_text_request_invalid_plain_text_entities(self):
832+
request = ReidentifyTextRequest(
833+
text="test text",
834+
plain_text_entities="invalid"
835+
)
836+
with self.assertRaises(SkyflowError) as context:
837+
validate_reidentify_text_request(self.logger, request)
838+
self.assertEqual(context.exception.message,
839+
SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value)
840+
720841
def test_validate_deidentify_file_request_valid(self):
721842
file_input = FileInput(file_path=self.temp_file_path)
722843
request = DeidentifyFileRequest(
@@ -886,3 +1007,40 @@ def test_validate_deidentify_file_request_invalid_wait_time(self):
8861007
with self.assertRaises(SkyflowError) as context:
8871008
validate_deidentify_file_request(self.logger, request)
8881009
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_WAIT_TIME.value)
1010+
1011+
def test_validate_detokenize_request_valid(self):
1012+
request = DetokenizeRequest(
1013+
data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT}],
1014+
continue_on_error=False
1015+
)
1016+
validate_detokenize_request(self.logger, request)
1017+
1018+
def test_validate_detokenize_request_empty_data(self):
1019+
request = DetokenizeRequest(data=[], continue_on_error=False)
1020+
with self.assertRaises(SkyflowError) as context:
1021+
validate_detokenize_request(self.logger, request)
1022+
self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value)
1023+
1024+
def test_validate_detokenize_request_invalid_token_type(self):
1025+
request = DetokenizeRequest(data=[{"token": 123}], continue_on_error=False)
1026+
with self.assertRaises(SkyflowError) as context:
1027+
validate_detokenize_request(self.logger, request)
1028+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE"))
1029+
1030+
def test_validate_detokenize_request_missing_token_key(self):
1031+
request = DetokenizeRequest(data=[{"not_token": "value"}], continue_on_error=False)
1032+
with self.assertRaises(SkyflowError) as context:
1033+
validate_detokenize_request(self.logger, request)
1034+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(str(type(request.data))))
1035+
1036+
def test_validate_detokenize_request_invalid_continue_on_error_type(self):
1037+
request = DetokenizeRequest(data=[{"token": "token123"}], continue_on_error="invalid")
1038+
with self.assertRaises(SkyflowError) as context:
1039+
validate_detokenize_request(self.logger, request)
1040+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value)
1041+
1042+
def test_validate_detokenize_request_invalid_redaction_type(self):
1043+
request = DetokenizeRequest(data=[{"token": "token123", "redaction": "invalid"}], continue_on_error=False)
1044+
with self.assertRaises(SkyflowError) as context:
1045+
validate_detokenize_request(self.logger, request)
1046+
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid"))))

0 commit comments

Comments
 (0)