diff --git a/main.py b/main.py index d1863f8..fe22397 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ import psutil from Crypto.Cipher import AES from Crypto.Random import get_random_bytes +from Crypto.Hash import SHA256 from Crypto.Util.Padding import pad, unpad from config import * @@ -39,25 +40,31 @@ def load_registered_usbs(self): logging.error(f"Error loading registered USBs: {e}") def encrypt_data(self, data: bytes) -> bytes: - """Encrypt data using AES-256.""" + """Encrypt data using AES-256. (Better though 😎)""" if not ENCRYPTION_KEY: raise ValueError("Encryption key not set") - key = ENCRYPTION_KEY.encode().ljust(32)[:32] + key = SHA256.new(ENCRYPTION_KEY.encode()).digest() cipher = AES.new(key, AES.MODE_CBC) ct_bytes = cipher.encrypt(pad(data, AES.block_size)) return cipher.iv + ct_bytes def decrypt_data(self, encrypted_data: bytes) -> bytes: - """Decrypt data using AES-256.""" + """Decrypt data using AES-256. Same thing 😎""" if not ENCRYPTION_KEY: - raise ValueError("Encryption key not set") + raise ValueError("Encryption key not set") + if len(encrypted_data) < 16: + raise ValueError("Invalid encrypted data (too short)") - key = ENCRYPTION_KEY.encode().ljust(32)[:32] + key = SHA256.new(ENCRYPTION_KEY.encode()).digest() iv = encrypted_data[:16] ct = encrypted_data[16:] - cipher = AES.new(key, AES.MODE_CBC, iv) - return unpad(cipher.decrypt(ct), AES.block_size) + cipher = AES.new(key, AES.MODE_CBC, iv=iv) + + try: + return unpad(cipher.decrypt(ct), AES.block_size) + except ValueError as e: + raise ValueError("Decryption failed - possibly corrupt data") from e def get_usb_uuid(self, device_path: str) -> Optional[str]: """Get UUID of a USB device.""" @@ -123,7 +130,7 @@ def monitor_usb_events(self): if self.verify_usb(device): self.unlock_system() else: - logging.warning(f"Unauthorized USB device detected: {device}") + logging.warning(f"Unauthorized or Unsecure USB device detected: {device}") # Check for removed devices removed_devices = previous_devices - current_devices @@ -145,4 +152,4 @@ def main(): sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main()