Files
openclaw/apps/shared/OpenClawKit/Sources/OpenClawKit/GatewayTLSPinning.swift
2026-03-03 17:13:52 +01:00

167 lines
6.2 KiB
Swift

import CryptoKit
import Foundation
import Security
/// TLS trust policy for one gateway connection, consumed by `GatewayTLSPinningSession`.
public struct GatewayTLSParams: Sendable {
    /// When `true`, a certificate that fails system trust evaluation is rejected.
    /// Only consulted when neither a pinned fingerprint nor trust-on-first-use
    /// settles the server-trust challenge.
    public let required: Bool
    /// Expected SHA-256 fingerprint of the server's leaf certificate, as hex.
    /// Separators and a `sha256:` prefix are tolerated (it is normalized before comparison).
    public let expectedFingerprint: String?
    /// When no fingerprint is expected, accept the presented certificate and record it.
    public let allowTOFU: Bool
    /// `GatewayTLSStore` stable ID under which a trust-on-first-use fingerprint is saved.
    public let storeKey: String?

    public init(required: Bool, expectedFingerprint: String?, allowTOFU: Bool, storeKey: String?) {
        self.storeKey = storeKey
        self.allowTOFU = allowTOFU
        self.expectedFingerprint = expectedFingerprint
        self.required = required
    }
}
/// Persists pinned gateway certificate fingerprints in the Keychain.
///
/// Each fingerprint is a generic-password item with service `ai.openclaw.tls-pinning`
/// and account equal to the caller's stable ID. Values that earlier versions kept in
/// the `ai.openclaw.shared` UserDefaults suite are moved into the Keychain lazily on load.
public enum GatewayTLSStore {
    private static let keychainService = "ai.openclaw.tls-pinning"
    // Legacy UserDefaults location used before Keychain migration.
    private static let legacySuiteName = "ai.openclaw.shared"
    private static let legacyKeyPrefix = "gateway.tls."

    /// Returns the stored fingerprint for `stableID`, trimmed of surrounding whitespace,
    /// or `nil` when nothing (or only whitespace) is stored.
    public static func loadFingerprint(stableID: String) -> String? {
        migrateFromUserDefaultsIfNeeded(stableID: stableID)
        guard let stored = keychainLoad(account: stableID)?
            .trimmingCharacters(in: .whitespacesAndNewlines),
            !stored.isEmpty
        else { return nil }
        return stored
    }

    /// Stores `value` as the fingerprint for `stableID`, replacing any previous value.
    public static func saveFingerprint(_ value: String, stableID: String) {
        keychainSave(value, account: stableID)
    }

    // MARK: - Migration

    /// Copies a non-empty legacy UserDefaults fingerprint into the Keychain (unless the
    /// Keychain already has one for this account), then deletes the legacy entry.
    /// The legacy entry is kept when the Keychain write fails, so a later load retries.
    private static func migrateFromUserDefaultsIfNeeded(stableID: String) {
        guard let defaults = UserDefaults(suiteName: legacySuiteName) else { return }
        let legacyKey = legacyKeyPrefix + stableID
        let legacyValue = defaults.string(forKey: legacyKey)?
            .trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
        guard !legacyValue.isEmpty else { return }

        let alreadyInKeychain = keychainLoad(account: stableID) != nil
        if !alreadyInKeychain, !keychainSave(legacyValue, account: stableID) {
            return
        }
        defaults.removeObject(forKey: legacyKey)
    }

    // MARK: - Self-contained Keychain helpers (OpenClawKit can't import iOS KeychainStore)

    /// Query attributes identifying the generic-password item for `account`.
    private static func baseQuery(account: String) -> [String: Any] {
        [
            kSecClass as String: kSecClassGenericPassword,
            kSecAttrService as String: keychainService,
            kSecAttrAccount as String: account,
        ]
    }

    /// Reads the item for `account` as a UTF-8 string; `nil` on any Keychain error
    /// or when the stored bytes are not valid UTF-8.
    private static func keychainLoad(account: String) -> String? {
        var query = baseQuery(account: account)
        query[kSecReturnData as String] = true
        query[kSecMatchLimit as String] = kSecMatchLimitOne

        var result: CFTypeRef?
        guard SecItemCopyMatching(query as CFDictionary, &result) == errSecSuccess,
              let data = result as? Data
        else { return nil }
        return String(data: data, encoding: .utf8)
    }

    /// Replaces the item for `account` with `value`; returns whether the add succeeded.
    @discardableResult
    private static func keychainSave(_ value: String, account: String) -> Bool {
        let match = baseQuery(account: account)
        // Delete-then-add so the stored item always carries the accessibility
        // attribute set below.
        SecItemDelete(match as CFDictionary)

        var attributes = match
        attributes[kSecValueData as String] = Data(value.utf8)
        attributes[kSecAttrAccessible as String] = kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly
        return SecItemAdd(attributes as CFDictionary, nil) == errSecSuccess
    }
}
/// `WebSocketSessioning` implementation whose `URLSession` delegate applies the TLS
/// policy described by a `GatewayTLSParams` value.
///
/// Server-trust challenges are resolved in this order:
/// 1. If `expectedFingerprint` is set, the SHA-256 fingerprint of the leaf certificate
///    must equal it (after normalization); otherwise the challenge is cancelled.
///    System trust is not consulted for pinned connections.
/// 2. Otherwise, if `allowTOFU` is set, the presented certificate is accepted and its
///    fingerprint is saved in `GatewayTLSStore` under `storeKey` (when one is given).
/// 3. Otherwise system trust evaluation decides; a certificate that fails it is still
///    accepted when `required` is `false`.
/// Non-server-trust challenges get default handling.
public final class GatewayTLSPinningSession: NSObject, WebSocketSessioning, URLSessionDelegate, @unchecked Sendable {
    private let params: GatewayTLSParams

    // NOTE(review): URLSession keeps a strong reference to its delegate until the session
    // is invalidated, and nothing here invalidates it — confirm this lifetime is intended.
    private lazy var session: URLSession = {
        let config = URLSessionConfiguration.default
        config.waitsForConnectivity = true
        return URLSession(configuration: config, delegate: self, delegateQueue: nil)
    }()

    public init(params: GatewayTLSParams) {
        self.params = params
        super.init()
    }

    /// Creates a WebSocket task on the pinned session, allowing messages up to 4 MiB.
    public func makeWebSocketTask(url: URL) -> WebSocketTaskBox {
        let task = self.session.webSocketTask(with: url)
        task.maximumMessageSize = 4 * 1024 * 1024
        return WebSocketTaskBox(task: task)
    }

    public func urlSession(
        _ session: URLSession,
        didReceive challenge: URLAuthenticationChallenge,
        completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
    ) {
        guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust,
              let trust = challenge.protectionSpace.serverTrust
        else {
            completionHandler(.performDefaultHandling, nil)
            return
        }
        let fingerprint = certificateFingerprint(trust)

        if let expected = self.params.expectedFingerprint.map(normalizeFingerprint) {
            // A configured pin is authoritative. Fail closed when the leaf fingerprint
            // cannot be computed, instead of falling back to system trust evaluation,
            // which would let an unpinned certificate through (and, with
            // `required == false`, any certificate at all).
            if let fingerprint, fingerprint == expected {
                completionHandler(.useCredential, URLCredential(trust: trust))
            } else {
                completionHandler(.cancelAuthenticationChallenge, nil)
            }
            return
        }

        if self.params.allowTOFU, let fingerprint {
            // Trust on first use: accept the presented certificate and remember it.
            if let storeKey = self.params.storeKey {
                GatewayTLSStore.saveFingerprint(fingerprint, stableID: storeKey)
            }
            completionHandler(.useCredential, URLCredential(trust: trust))
            return
        }

        let trusted = SecTrustEvaluateWithError(trust, nil)
        if trusted || !self.params.required {
            completionHandler(.useCredential, URLCredential(trust: trust))
        } else {
            completionHandler(.cancelAuthenticationChallenge, nil)
        }
    }
}
/// SHA-256 fingerprint (lowercase hex) of the DER encoding of the first certificate
/// in `trust`'s chain, or `nil` when the chain is unavailable or empty.
private func certificateFingerprint(_ trust: SecTrust) -> String? {
    guard let leaf = (SecTrustCopyCertificateChain(trust) as? [SecCertificate])?.first else {
        return nil
    }
    let der = SecCertificateCopyData(leaf) as Data
    return sha256Hex(der)
}
/// Lowercase, zero-padded hex encoding of the SHA-256 digest of `data` (64 characters).
private func sha256Hex(_ data: Data) -> String {
    var hex = ""
    for byte in SHA256.hash(data: data) {
        hex += String(format: "%02x", byte)
    }
    return hex
}
/// Reduces a user- or tool-supplied SHA-256 fingerprint to bare lowercase hex.
///
/// An optional leading label is stripped first: `sha256:`, `SHA-256 `, or OpenSSL's
/// `sha256 Fingerprint=` (case-insensitive, surrounding whitespace allowed). Every
/// remaining non-hex character (e.g. `:` or space separators) is then dropped.
/// The whole label must go before that filter. Otherwise the filter keeps letters such
/// as the `a` in `sha` or the `F`/`e` in `Fingerprint`, and the result can never match.
private func normalizeFingerprint(_ raw: String) -> String {
    let withoutLabel = raw.replacingOccurrences(
        of: #"(?i)^\s*(?:sha-?256)?\s*(?:fingerprint)?\s*[:=]?\s*"#,
        with: "",
        options: .regularExpression)
    return withoutLabel.lowercased().filter(\.isHexDigit)
}