22from unittest .mock import Mock , patch , MagicMock
33import tempfile
44import os
5+
56from skyflow .error import SkyflowError
67from skyflow .utils .validations ._validations import (
78 validate_required_field , validate_api_key , validate_credentials ,
1415 validate_deidentify_text_request , validate_reidentify_text_request , validate_deidentify_file_request
1516)
1617from skyflow .utils import SkyflowMessages
17- from skyflow .utils .enums import DetectEntities
18- from skyflow .vault .detect import DeidentifyTextRequest , Transformations , DateTransformation , ReidentifyTextRequest , FileInput
19- from skyflow .vault .detect ._deidentify_file_request import DeidentifyFileRequest
20-
18+ from skyflow .utils .enums import DetectEntities , RedactionType
19+ from skyflow .vault .data import GetRequest , UpdateRequest
20+ from skyflow .vault .detect import DeidentifyTextRequest , Transformations , DateTransformation , ReidentifyTextRequest , \
21+ FileInput , DeidentifyFileRequest
22+ from skyflow .vault .tokens import DetokenizeRequest
23+ from skyflow .vault .connection ._invoke_connection_request import InvokeConnectionRequest
2124
2225class TestValidations (unittest .TestCase ):
2326 @classmethod
@@ -154,6 +157,10 @@ class InvalidEnum:
154157 validate_log_level (self .logger , invalid_log_level )
155158 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_LOG_LEVEL .value )
156159
160+ def test_validate_log_level_none (self ):
161+ with self .assertRaises (SkyflowError ) as context :
162+ validate_log_level (self .logger , None )
163+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_LOG_LEVEL .value )
157164
158165 def test_validate_keys_valid (self ):
159166 config = {"vault_id" : "test_id" , "cluster_id" : "test_cluster" }
@@ -389,7 +396,7 @@ def test_validate_query_request_empty_query(self):
389396
390397 def test_validate_query_request_invalid_query_type (self ):
391398 request = MagicMock ()
392- request .query = 123 # Invalid type
399+ request .query = 123
393400 with self .assertRaises (SkyflowError ) as context :
394401 validate_query_request (self .logger , request )
395402 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_QUERY_TYPE .value .format (str (type (123 ))))
@@ -438,7 +445,7 @@ def test_validate_get_request_valid(self):
438445
439446 def test_validate_get_request_invalid_table_type (self ):
440447 request = MagicMock ()
441- request .table = 123 # Invalid type
448+ request .table = 123
442449 with self .assertRaises (SkyflowError ) as context :
443450 validate_get_request (self .logger , request )
444451 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TABLE_VALUE .value )
@@ -450,6 +457,63 @@ def test_validate_get_request_empty_table(self):
450457 validate_get_request (self .logger , request )
451458 self .assertEqual (context .exception .message , SkyflowMessages .Error .EMPTY_TABLE_VALUE .value )
452459
460+ def test_validate_get_request_invalid_redaction_type (self ):
461+ request = GetRequest (
462+ table = "test_table" ,
463+ fields = "invalid" ,
464+ ids = ["id1" , "id2" ],
465+ redaction_type = "invalid"
466+ )
467+
468+ with self .assertRaises (SkyflowError ) as context :
469+ validate_get_request (self .logger , request )
470+ self .assertEqual (context .exception .message ,
471+ SkyflowMessages .Error .INVALID_REDACTION_TYPE .value .format (type (request .redaction_type )))
472+
473+ def test_validate_get_request_invalid_fields_type (self ):
474+ request = GetRequest (
475+ table = "test_table" ,
476+ fields = "invalid"
477+ )
478+ with self .assertRaises (SkyflowError ) as context :
479+ validate_get_request (self .logger , request )
480+ self .assertEqual (context .exception .message ,
481+ SkyflowMessages .Error .INVALID_FIELDS_VALUE .value .format (type (request .fields )))
482+
483+ def test_validate_get_request_empty_fields (self ):
484+ request = GetRequest (
485+ table = "test_table" ,
486+ ids = [],
487+ fields = []
488+ )
489+ with self .assertRaises (SkyflowError ) as context :
490+ validate_get_request (self .logger , request )
491+ self .assertEqual (context .exception .message ,
492+ SkyflowMessages .Error .INVALID_FIELDS_VALUE .value .format (type (request .fields )))
493+
494+ def test_validate_get_request_invalid_column_values_type (self ):
495+ request = GetRequest (
496+ table = "test_table" ,
497+ column_name = "test_column" ,
498+ column_values = "invalid" ,
499+ )
500+
501+ with self .assertRaises (SkyflowError ) as context :
502+ validate_get_request (self .logger , request )
503+ self .assertEqual (context .exception .message ,
504+ SkyflowMessages .Error .INVALID_COLUMN_VALUE .value .format (type (request .column_values )))
505+
506+ def test_validate_get_request_tokens_with_redaction (self ):
507+ request = GetRequest (
508+ table = "test_table" ,
509+ return_tokens = True ,
510+ redaction_type = RedactionType .PLAIN_TEXT
511+ )
512+
513+ with self .assertRaises (SkyflowError ) as context :
514+ validate_get_request (self .logger , request )
515+ self .assertEqual (context .exception .message ,
516+ SkyflowMessages .Error .REDACTION_WITH_TOKENS_NOT_SUPPORTED .value )
453517
454518 def test_validate_query_request_valid_complex (self ):
455519 request = MagicMock ()
@@ -473,16 +537,25 @@ def test_validate_update_request_valid(self):
473537 request .token_mode = None
474538 request .tokens = None
475539 validate_update_request (self .logger , request )
476-
477540
478541 def test_validate_update_request_invalid_table_type (self ):
479- request = MagicMock ()
480- request .table = 123 # Invalid type
481- request .data = {"skyflow_id" : "id123" }
542+ request = UpdateRequest (
543+ table = 123 ,
544+ data = {"skyflow_id" : "id123" }
545+ )
482546 with self .assertRaises (SkyflowError ) as context :
483547 validate_update_request (self .logger , request )
484548 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TABLE_VALUE .value )
485549
550+ def test_validate_update_request_invalid_token_mode (self ):
551+ request = UpdateRequest (
552+ table = "test_table" ,
553+ data = {"skyflow_id" : "id123" , "field1" : "value1" },
554+ token_mode = "invalid"
555+ )
556+ with self .assertRaises (SkyflowError ) as context :
557+ validate_update_request (self .logger , request )
558+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TOKEN_MODE_TYPE .value )
486559
487560 def test_validate_detokenize_request_valid (self ):
488561 request = MagicMock ()
@@ -500,7 +573,7 @@ def test_validate_detokenize_request_empty_data(self):
500573
501574 def test_validate_detokenize_request_invalid_token (self ):
502575 request = MagicMock ()
503- request .data = [{"token" : 123 }]
576+ request .data = [{"token" : 123 }] # Invalid token type
504577 request .continue_on_error = False
505578 with self .assertRaises (SkyflowError ) as context :
506579 validate_detokenize_request (self .logger , request )
@@ -515,15 +588,15 @@ def test_validate_tokenize_request_valid(self):
515588
516589 def test_validate_tokenize_request_invalid_values_type (self ):
517590 request = MagicMock ()
518- request .values = "invalid"
591+ request .values = "invalid" # Should be list
519592 with self .assertRaises (SkyflowError ) as context :
520593 validate_tokenize_request (self .logger , request )
521594 self .assertEqual (context .exception .message ,
522595 SkyflowMessages .Error .INVALID_TOKENIZE_PARAMETERS .value .format (type (request .values )))
523596
524597 def test_validate_tokenize_request_empty_values (self ):
525598 request = MagicMock ()
526- request .values = []
599+ request .values = [] # Empty list
527600 with self .assertRaises (SkyflowError ) as context :
528601 validate_tokenize_request (self .logger , request )
529602 self .assertEqual (context .exception .message , SkyflowMessages .Error .EMPTY_TOKENIZE_PARAMETERS .value )
@@ -541,18 +614,56 @@ def test_validate_invoke_connection_params_valid(self):
541614 path_params = {"path1" : "value1" }
542615 validate_invoke_connection_params (self .logger , query_params , path_params )
543616
544- def test_validate_invoke_connection_params_invalid_path_params (self ):
545- query_params = {"param1" : "value1" }
546- path_params = "invalid" # Should be dict
617+ def test_validate_invoke_connection_params_invalid_path_params_type (self ):
618+ request = InvokeConnectionRequest (
619+ method = "GET" ,
620+ query_params = {"param1" : "value1" },
621+ path_params = "invalid"
622+ )
547623 with self .assertRaises (SkyflowError ) as context :
548- validate_invoke_connection_params (self .logger , query_params , path_params )
624+ validate_invoke_connection_params (self .logger , request . query_params , request . path_params )
549625 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_PATH_PARAMS .value )
550626
551- def test_validate_invoke_connection_params_invalid_query_params (self ):
552- query_params = "invalid" # Should be dict
553- path_params = {"path1" : "value1" }
627+ def test_validate_invoke_connection_params_invalid_query_params_type (self ):
628+ request = InvokeConnectionRequest (
629+ method = "GET" ,
630+ query_params = "invalid" ,
631+ path_params = {"path1" : "value1" }
632+ )
554633 with self .assertRaises (SkyflowError ) as context :
555- validate_invoke_connection_params (self .logger , query_params , path_params )
634+ validate_invoke_connection_params (self .logger , request .query_params , request .path_params )
635+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_QUERY_PARAMS .value )
636+
637+ def test_validate_invoke_connection_params_non_string_path_param (self ):
638+ request = InvokeConnectionRequest (
639+ method = "GET" ,
640+ query_params = {"param1" : "value1" },
641+ path_params = {1 : "value1" }
642+ )
643+ with self .assertRaises (SkyflowError ) as context :
644+ validate_invoke_connection_params (self .logger , request .query_params , request .path_params )
645+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_PATH_PARAMS .value )
646+
647+ def test_validate_invoke_connection_params_non_string_query_param_key (self ):
648+ request = InvokeConnectionRequest (
649+ method = "GET" ,
650+ query_params = {1 : "value1" },
651+ path_params = {"path1" : "value1" }
652+ )
653+ with self .assertRaises (SkyflowError ) as context :
654+ validate_invoke_connection_params (self .logger , request .query_params , request .path_params )
655+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_QUERY_PARAMS .value )
656+
657+ def test_validate_invoke_connection_params_non_serializable_query_params (self ):
658+ class NonSerializable :
659+ pass
660+ request = InvokeConnectionRequest (
661+ method = "GET" ,
662+ query_params = {"param1" : NonSerializable ()},
663+ path_params = {"path1" : "value1" }
664+ )
665+ with self .assertRaises (SkyflowError ) as context :
666+ validate_invoke_connection_params (self .logger , request .query_params , request .path_params )
556667 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_QUERY_PARAMS .value )
557668
558669 def test_validate_deidentify_text_request_valid (self ):
@@ -717,6 +828,16 @@ def test_validate_reidentify_text_request_invalid_redacted_entities(self):
717828 self .assertEqual (context .exception .message ,
718829 SkyflowMessages .Error .INVALID_REDACTED_ENTITIES_IN_REIDENTIFY .value )
719830
831+ def test_validate_reidentify_text_request_invalid_plain_text_entities (self ):
832+ request = ReidentifyTextRequest (
833+ text = "test text" ,
834+ plain_text_entities = "invalid"
835+ )
836+ with self .assertRaises (SkyflowError ) as context :
837+ validate_reidentify_text_request (self .logger , request )
838+ self .assertEqual (context .exception .message ,
839+ SkyflowMessages .Error .INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY .value )
840+
720841 def test_validate_deidentify_file_request_valid (self ):
721842 file_input = FileInput (file_path = self .temp_file_path )
722843 request = DeidentifyFileRequest (
@@ -886,3 +1007,40 @@ def test_validate_deidentify_file_request_invalid_wait_time(self):
8861007 with self .assertRaises (SkyflowError ) as context :
8871008 validate_deidentify_file_request (self .logger , request )
8881009 self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_WAIT_TIME .value )
1010+
1011+ def test_validate_detokenize_request_valid (self ):
1012+ request = DetokenizeRequest (
1013+ data = [{"token" : "token123" , "redaction" : RedactionType .PLAIN_TEXT }],
1014+ continue_on_error = False
1015+ )
1016+ validate_detokenize_request (self .logger , request )
1017+
1018+ def test_validate_detokenize_request_empty_data (self ):
1019+ request = DetokenizeRequest (data = [], continue_on_error = False )
1020+ with self .assertRaises (SkyflowError ) as context :
1021+ validate_detokenize_request (self .logger , request )
1022+ self .assertEqual (context .exception .message , SkyflowMessages .Error .EMPTY_TOKENS_LIST_VALUE .value )
1023+
1024+ def test_validate_detokenize_request_invalid_token_type (self ):
1025+ request = DetokenizeRequest (data = [{"token" : 123 }], continue_on_error = False )
1026+ with self .assertRaises (SkyflowError ) as context :
1027+ validate_detokenize_request (self .logger , request )
1028+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TOKEN_TYPE .value .format ("DETOKENIZE" ))
1029+
1030+ def test_validate_detokenize_request_missing_token_key (self ):
1031+ request = DetokenizeRequest (data = [{"not_token" : "value" }], continue_on_error = False )
1032+ with self .assertRaises (SkyflowError ) as context :
1033+ validate_detokenize_request (self .logger , request )
1034+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_TOKENS_LIST_VALUE .value .format (str (type (request .data ))))
1035+
1036+ def test_validate_detokenize_request_invalid_continue_on_error_type (self ):
1037+ request = DetokenizeRequest (data = [{"token" : "token123" }], continue_on_error = "invalid" )
1038+ with self .assertRaises (SkyflowError ) as context :
1039+ validate_detokenize_request (self .logger , request )
1040+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_CONTINUE_ON_ERROR_TYPE .value )
1041+
1042+ def test_validate_detokenize_request_invalid_redaction_type (self ):
1043+ request = DetokenizeRequest (data = [{"token" : "token123" , "redaction" : "invalid" }], continue_on_error = False )
1044+ with self .assertRaises (SkyflowError ) as context :
1045+ validate_detokenize_request (self .logger , request )
1046+ self .assertEqual (context .exception .message , SkyflowMessages .Error .INVALID_REDACTION_TYPE .value .format (str (type ("invalid" ))))
0 commit comments