Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import Foundation

public enum MSSQLTLSClassifier {
public static func classifySSLError(_ message: String) -> MSSQLTLSFailureKind? {
let lower = message.lowercased()
if lower.contains("encryption is required") || lower.contains("server requires encryption") {
return .serverRejectedPlaintext
}
if lower.contains("encryption not supported") || lower.contains("server does not support encryption") {
return .serverRequiresPlaintext
}
if lower.contains("certificate verify failed") || lower.contains("certificate is not trusted") {
return .untrustedCertificate
}
if lower.contains("does not match host") {
return .hostnameMismatch
}
if lower.contains("ssl handshake") || lower.contains("tls handshake") || lower.contains("openssl error") {
return .cipherMismatch
}
return nil
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import Testing
@testable import TableProMSSQLCore

@Suite("MSSQL TLS Classifier")
struct MSSQLTLSClassifierTests {
@Test("Server requires encryption → serverRejectedPlaintext")
func testServerRequires() {
guard case .serverRejectedPlaintext = MSSQLTLSClassifier.classifySSLError("Server requires encryption") else {
Issue.record("Expected serverRejectedPlaintext")
return
}
}

@Test("Server does not support encryption → serverRequiresPlaintext")
func testServerNoSupport() {
guard case .serverRequiresPlaintext = MSSQLTLSClassifier.classifySSLError("encryption not supported by server") else {
Issue.record("Expected serverRequiresPlaintext")
return
}
}

@Test("Certificate verify failed → untrustedCertificate")
func testUntrustedCertificate() {
guard case .untrustedCertificate = MSSQLTLSClassifier.classifySSLError("certificate verify failed") else {
Issue.record("Expected untrustedCertificate")
return
}
}

@Test("Hostname mismatch → hostnameMismatch")
func testHostnameMismatch() {
guard case .hostnameMismatch = MSSQLTLSClassifier.classifySSLError("certificate does not match host name") else {
Issue.record("Expected hostnameMismatch")
return
}
}

@Test("OpenSSL handshake → cipherMismatch")
func testOpenSSL() {
guard case .cipherMismatch = MSSQLTLSClassifier.classifySSLError("OpenSSL error during SSL handshake") else {
Issue.record("Expected cipherMismatch")
return
}
}

@Test("Non-TLS error returns nil")
func testNonTLS() {
#expect(MSSQLTLSClassifier.classifySSLError("Login failed for user 'sa'") == nil)
}
}
16 changes: 1 addition & 15 deletions Plugins/CassandraDriverPlugin/CassandraConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,24 +197,10 @@ actor CassandraConnectionActor {
let keyResult = cass_ssl_set_private_key(ssl, keyString, passphrase)
if keyResult != CASS_OK {
cleanup()
throw Self.privateKeyLoadError(keyPEM: keyString, hasPassphrase: passphrase != nil, keyPath: keyPath)
throw CassandraClientKeyClassifier.privateKeyLoadError(keyPEM: keyString, hasPassphrase: passphrase != nil, keyPath: keyPath)
}
}

static func isEncryptedPrivateKey(_ pem: String) -> Bool {
pem.contains("ENCRYPTED PRIVATE KEY") || (pem.contains("Proc-Type:") && pem.contains("ENCRYPTED"))
}

static func privateKeyLoadError(keyPEM: String, hasPassphrase: Bool, keyPath: String) -> SSLHandshakeError {
guard isEncryptedPrivateKey(keyPEM) else {
return .clientKeyInvalid(serverMessage: "The client key at \(keyPath) is not a valid private key")
}
if hasPassphrase {
return .clientKeyPassphraseIncorrect(serverMessage: "The passphrase for the client key at \(keyPath) is incorrect")
}
return .clientKeyPassphraseRequired(serverMessage: "The client key at \(keyPath) is encrypted. Enter its passphrase.")
}

func close() {
if let session {
let closeFuture = cass_session_close(session)
Expand Down
26 changes: 1 addition & 25 deletions Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable {
session = nil
lock.unlock()
Self.logger.error("Connection test failed: \(error.localizedDescription)")
if let sslError = Self.classifySSLError(error) {
if let sslError = ClickHouseSSLClassifier.classifySSLError(error) {
throw sslError
}
throw ClickHouseError.connectionFailed
Expand Down Expand Up @@ -709,30 +709,6 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable {
func generateDropIndexSQL(table: String, indexName: String) -> String? {
"ALTER TABLE \(quoteIdentifier(table)) DROP INDEX \(quoteIdentifier(indexName))"
}

static func classifySSLError(_ error: Error) -> SSLHandshakeError? {
let urlError = error as? URLError ?? (error as NSError).underlyingErrors.compactMap { $0 as? URLError }.first
if let urlError {
switch urlError.code {
case .serverCertificateUntrusted, .serverCertificateNotYetValid, .serverCertificateHasUnknownRoot, .serverCertificateHasBadDate:
return .untrustedCertificate(serverMessage: urlError.localizedDescription)
case .clientCertificateRequired, .clientCertificateRejected:
return .clientCertRequired(serverMessage: urlError.localizedDescription)
case .secureConnectionFailed:
return .cipherMismatch(serverMessage: urlError.localizedDescription)
default:
break
}
}
let message = error.localizedDescription.lowercased()
if message.contains("certificate") && (message.contains("untrusted") || message.contains("verify failed")) {
return .untrustedCertificate(serverMessage: error.localizedDescription)
}
if message.contains("hostname") {
return .hostnameMismatch(serverMessage: error.localizedDescription)
}
return nil
}
}

// MARK: - TLS Delegate
Expand Down
22 changes: 1 addition & 21 deletions Plugins/MSSQLDriverPlugin/FreeTDSConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
init(options: MSSQLConnectionOptions) {
self.options = options
self.queue = DispatchQueue(label: "com.TablePro.freetds.\(options.host).\(options.port)", qos: .userInitiated)
_ = freetdsInitOnce

Check warning on line 136 in Plugins/MSSQLDriverPlugin/FreeTDSConnection.swift

View workflow job for this annotation

GitHub Actions / Run iOS Tests

main actor-isolated let 'freetdsInitOnce' can not be referenced from a nonisolated context
}

func connect() async throws {
Expand All @@ -148,7 +148,7 @@
}
defer { dbloginfree(login) }

for parameter in MSSQLLoginParameters.build(

Check warning on line 151 in Plugins/MSSQLDriverPlugin/FreeTDSConnection.swift

View workflow job for this annotation

GitHub Actions / Run iOS Tests

call to main actor-isolated static method 'build(user:password:applicationName:encryptionFlag:database:)' in a synchronous nonisolated context
user: options.user,
password: options.password,
applicationName: options.applicationName,
Expand All @@ -169,7 +169,7 @@
guard let proc = dbopen(login, serverName) else {
let detail = freetdsGetError(for: nil)
let msg = detail.isEmpty ? "Check host, port, credentials, and TLS settings" : detail
if let kind = FreeTDSConnection.classifySSLError(detail) {
if let kind = MSSQLTLSClassifier.classifySSLError(detail) {
throw MSSQLCoreError.tlsHandshakeFailed(kind: kind, serverMessage: detail)
}
throw MSSQLCoreError.connectionFailed("Failed to connect to \(options.host):\(options.port): \(msg)")
Expand Down Expand Up @@ -203,7 +203,7 @@
if let handle {
freetdsUnregister(handle)
queue.async {
_ = dbclose(handle)

Check warning on line 206 in Plugins/MSSQLDriverPlugin/FreeTDSConnection.swift

View workflow job for this annotation

GitHub Actions / macOS App Tests

capture of 'handle' with non-Sendable type 'UnsafeMutablePointer<DBPROCESS>' (aka 'UnsafeMutablePointer<dbprocess>') in a '@sendable' closure

Check warning on line 206 in Plugins/MSSQLDriverPlugin/FreeTDSConnection.swift

View workflow job for this annotation

GitHub Actions / macOS App Tests

capture of 'handle' with non-Sendable type 'UnsafeMutablePointer<DBPROCESS>' (aka 'UnsafeMutablePointer<dbprocess>') in a '@sendable' closure

Check warning on line 206 in Plugins/MSSQLDriverPlugin/FreeTDSConnection.swift

View workflow job for this annotation

GitHub Actions / macOS App Tests

capture of 'handle' with non-Sendable type 'UnsafeMutablePointer<DBPROCESS>' (aka 'UnsafeMutablePointer<dbprocess>') in a '@sendable' closure
}
}
}
Expand Down Expand Up @@ -526,26 +526,6 @@
}
return raw
}

static func classifySSLError(_ message: String) -> MSSQLTLSFailureKind? {
let lower = message.lowercased()
if lower.contains("encryption is required") || lower.contains("server requires encryption") {
return .serverRejectedPlaintext
}
if lower.contains("encryption not supported") || lower.contains("server does not support encryption") {
return .serverRequiresPlaintext
}
if lower.contains("certificate verify failed") || lower.contains("certificate is not trusted") {
return .untrustedCertificate
}
if lower.contains("does not match host") {
return .hostnameMismatch
}
if lower.contains("ssl handshake") || lower.contains("tls handshake") || lower.contains("openssl error") {
return .cipherMismatch
}
return nil
}
}

private extension MSSQLLoginField {
Expand Down
38 changes: 1 addition & 37 deletions Plugins/MongoDBDriverPlugin/MongoDBConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ final class MongoDBConnection: @unchecked Sendable {
let errorMsg = bsonErrorMessage(&error)
mongoc_client_destroy(newClient)
logger.error("MongoDB ping failed: \(errorMsg)")
if let sslError = Self.classifySSLError(errorMsg) {
if let sslError = MongoDBSSLClassifier.classifySSLError(errorMsg) {
throw sslError
}
throw MongoDBError(code: error.code, message: errorMsg)
Expand Down Expand Up @@ -805,42 +805,6 @@ extension MongoDBConnection {
return nil
#endif
}

static func classifySSLError(_ message: String) -> SSLHandshakeError? {
let lower = message.lowercased()
if lower.contains("certificate verify failed") || lower.contains("ssl certificate") {
return .untrustedCertificate(serverMessage: message)
}
if lower.contains("hostname") && lower.contains("verification") {
return .hostnameMismatch(serverMessage: message)
}
if lower.contains("tls required") || lower.contains("ssl required") {
return .serverRejectedPlaintext(serverMessage: message)
}
if lower.contains("client certificate required") || lower.contains("peer did not return a certificate") {
return .clientCertRequired(serverMessage: message)
}
if isCipherOrProtocolMismatch(lower) {
return .cipherMismatch(serverMessage: message)
}
if lower.contains("ssl handshake failed") || lower.contains("tls handshake failed") {
return .unknown(serverMessage: message)
}
return nil
}

static func isCipherOrProtocolMismatch(_ lower: String) -> Bool {
let signatures = [
"no shared cipher",
"sslv3 alert handshake failure",
"wrong version number",
"unsupported protocol",
"no protocols available",
"alert protocol version",
"protocol version",
]
return signatures.contains { lower.contains($0) }
}
}

// bsonToDict and bsonToJson take bson_t parameters (a CLibMongoc type),
Expand Down
26 changes: 3 additions & 23 deletions Plugins/MySQLDriverPlugin/MariaDBPluginConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
func mysqlTypeToString(_ fieldPtr: UnsafePointer<MYSQL_FIELD>) -> String {
let field = fieldPtr.pointee
let flags = UInt(field.flags)
let length = field.length

Check warning on line 70 in Plugins/MySQLDriverPlugin/MariaDBPluginConnection.swift

View workflow job for this annotation

GitHub Actions / macOS App Tests

initialization of immutable value 'length' was never used; consider replacing with assignment to '_' or removing it

// MariaDB extended metadata: detect JSON stored as LONGTEXT.
// `MARIADB_CONST_STRING` is length-prefixed (not null-terminated), so we must read
Expand Down Expand Up @@ -216,30 +216,24 @@

// MARK: - Connection Management

private static let sslOnlyErrorCodes: Set<UInt32> = [
2_026,
2_012,
1_043
]

func connect() async throws {
try await pluginDispatchAsync(on: queue) { [self] in
let mode = self.sslConfig.mode
let handle: UnsafeMutablePointer<MYSQL>
do {
handle = try self.attemptConnect(enforceSSL: mode != .disabled)
} catch let error as MariaDBPluginError where mode == .preferred && Self.sslOnlyErrorCodes.contains(error.code) {
} catch let error as MariaDBPluginError where mode == .preferred && MariaDBSSLClassifier.sslOnlyErrorCodes.contains(error.code) {
logger.notice("MySQL SSL handshake failed (code \(error.code)); falling back to plaintext for .preferred mode")
do {
handle = try self.attemptConnect(enforceSSL: false)
} catch let fallbackError as MariaDBPluginError {
if let sslError = Self.classifySSLError(fallbackError) {
if let sslError = MariaDBSSLClassifier.classifySSLError(code: fallbackError.code, message: fallbackError.message) {
throw sslError
}
throw fallbackError
}
} catch let error as MariaDBPluginError {
if let sslError = Self.classifySSLError(error) {
if let sslError = MariaDBSSLClassifier.classifySSLError(code: error.code, message: error.message) {
throw sslError
}
throw error
Expand All @@ -256,20 +250,6 @@
}
}

static func classifySSLError(_ error: MariaDBPluginError) -> SSLHandshakeError? {
let lower = error.message.lowercased()
if lower.contains("insecure transport") || lower.contains("require_secure_transport") {
return .serverRejectedPlaintext(serverMessage: error.message)
}
if Self.sslOnlyErrorCodes.contains(error.code) {
if lower.contains("certificate") {
return .untrustedCertificate(serverMessage: error.message)
}
return .cipherMismatch(serverMessage: error.message)
}
return nil
}

private func attemptConnect(enforceSSL: Bool) throws -> UnsafeMutablePointer<MYSQL> {
guard let mysql = mysql_init(nil) else {
throw MariaDBPluginError.initFailed
Expand Down
24 changes: 3 additions & 21 deletions Plugins/OracleDriverPlugin/OracleConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ final class OracleConnectionWrapper: @unchecked Sendable {
} catch let sqlError as OracleSQLError {
let detail = Self.connectFailureDetail(sqlError)
osLogger.error("Oracle connection failed: \(detail)")
if let sslError = Self.classifySSLError(detail) {
if let sslError = OracleSSLClassifier.classifySSLError(detail) {
throw sslError
}
let category = classifyConnectError(sqlError)
Expand All @@ -203,35 +203,17 @@ final class OracleConnectionWrapper: @unchecked Sendable {
} catch let nioSslError as NIOSSLError {
let detail = String(describing: nioSslError)
osLogger.error("Oracle TLS error: \(detail)")
throw Self.classifySSLError(detail) ?? SSLHandshakeError.unknown(serverMessage: detail)
throw OracleSSLClassifier.classifySSLError(detail) ?? SSLHandshakeError.unknown(serverMessage: detail)
} catch {
let detail = String(describing: error)
osLogger.error("Oracle connection failed: \(detail)")
if let sslError = Self.classifySSLError(detail) {
if let sslError = OracleSSLClassifier.classifySSLError(detail) {
throw sslError
}
throw OracleError(message: detail, category: .connectionFailed)
}
}

static func classifySSLError(_ message: String) -> SSLHandshakeError? {
let lower = message.lowercased()
if lower.contains("ora-28759") || lower.contains("failure to open file") && lower.contains("wallet") {
return .clientCertRequired(serverMessage: message)
}
if lower.contains("ora-29024") {
return .cipherMismatch(serverMessage: message)
}
if lower.contains("ora-28860") {
return .cipherMismatch(serverMessage: message)
}
if lower.contains("certificate") && (lower.contains("verify") || lower.contains("untrusted")) {
return .untrustedCertificate(serverMessage: message)
}
return nil
}


private func classifyConnectError(_ error: OracleSQLError) -> OracleError.Category {
let codeDescription = error.code.description
if codeDescription.hasPrefix("unsupportedVerifierType") {
Expand Down
28 changes: 1 addition & 27 deletions Plugins/PostgreSQLDriverPlugin/LibPQPluginConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ final class LibPQPluginConnection: @unchecked Sendable {
if PQstatus(connection) != CONNECTION_OK {
let error = self.getError(from: connection)
PQfinish(connection)
if let sslError = Self.classifySSLError(error.message) {
if let sslError = LibPQSSLClassifier.classifySSLError(error.message) {
throw sslError
}
throw error
Expand Down Expand Up @@ -865,32 +865,6 @@ final class LibPQPluginConnection: @unchecked Sendable {

// MARK: - Private Helpers

static func classifySSLError(_ message: String) -> SSLHandshakeError? {
let lower = message.lowercased()
if lower.contains("no pg_hba.conf entry") && lower.contains("no encryption") {
return .serverRejectedPlaintext(serverMessage: message)
}
if lower.contains("no pg_hba.conf entry") && lower.contains("ssl") {
return .serverRequiresPlaintext(serverMessage: message)
}
if lower.contains("server does not support ssl") || lower.contains("ssl is not enabled on the server") {
return .serverRequiresPlaintext(serverMessage: message)
}
if lower.contains("certificate verify failed") || lower.contains("self-signed certificate") || lower.contains("unable to get local issuer certificate") {
return .untrustedCertificate(serverMessage: message)
}
if lower.contains("server certificate") && lower.contains("does not match host name") {
return .hostnameMismatch(serverMessage: message)
}
if lower.contains("certificate required") || lower.contains("connection requires a valid client certificate") {
return .clientCertRequired(serverMessage: message)
}
if lower.contains("ssl error") || lower.contains("tls handshake") || lower.contains("ssl handshake") {
return .cipherMismatch(serverMessage: message)
}
return nil
}

private func getError(from conn: OpaquePointer) -> LibPQPluginError {
var message = "Unknown error"
if let msgPtr = PQerrorMessage(conn) {
Expand Down
Loading
Loading