fix(security): require explicit trust for first-time TLS pins

This commit is contained in:
Peter Steinberger
2026-02-14 17:47:13 +01:00
parent d714ac7797
commit 054366dea4
16 changed files with 549 additions and 76 deletions

View File

@@ -2,6 +2,7 @@ import AVFoundation
import Contacts
import CoreLocation
import CoreMotion
import CryptoKit
import EventKit
import Foundation
import OpenClawKit
@@ -9,6 +10,7 @@ import Network
import Observation
import Photos
import ReplayKit
import Security
import Speech
import SwiftUI
import UIKit
@@ -16,14 +18,27 @@ import UIKit
@MainActor
@Observable
final class GatewayConnectionController {
struct TrustPrompt: Identifiable, Equatable {
let stableID: String
let gatewayName: String
let host: String
let port: Int
let fingerprintSha256: String
let isManual: Bool
var id: String { self.stableID }
}
private(set) var gateways: [GatewayDiscoveryModel.DiscoveredGateway] = []
private(set) var discoveryStatusText: String = "Idle"
private(set) var discoveryDebugLog: [GatewayDiscoveryModel.DebugLogEntry] = []
private(set) var pendingTrustPrompt: TrustPrompt?
private let discovery = GatewayDiscoveryModel()
private weak var appModel: NodeAppModel?
private var didAutoConnect = false
private var pendingServiceResolvers: [String: GatewayServiceResolver] = [:]
private var pendingTrustConnect: (url: URL, stableID: String, isManual: Bool)?
init(appModel: NodeAppModel, startDiscovery: Bool = true) {
self.appModel = appModel
@@ -58,12 +73,11 @@ final class GatewayConnectionController {
}
func connect(_ gateway: GatewayDiscoveryModel.DiscoveredGateway) async {
await self.connectDiscoveredGateway(gateway, allowTOFU: true)
await self.connectDiscoveredGateway(gateway)
}
private func connectDiscoveredGateway(
_ gateway: GatewayDiscoveryModel.DiscoveredGateway,
allowTOFU: Bool) async
_ gateway: GatewayDiscoveryModel.DiscoveredGateway) async
{
let instanceId = UserDefaults.standard.string(forKey: "node.instanceId")?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
@@ -73,21 +87,43 @@ final class GatewayConnectionController {
// Resolve the service endpoint (SRV/A/AAAA). TXT is unauthenticated; do not route via TXT.
guard let target = await self.resolveServiceEndpoint(gateway.endpoint) else { return }
let tlsParams = self.resolveDiscoveredTLSParams(gateway: gateway, allowTOFU: allowTOFU)
let stableID = gateway.stableID
// Discovery is a LAN operation; refuse unauthenticated plaintext connects.
let tlsRequired = true
let stored = GatewayTLSStore.loadFingerprint(stableID: stableID)
guard gateway.tlsEnabled || stored != nil else { return }
if tlsRequired, stored == nil {
guard let url = self.buildGatewayURL(host: target.host, port: target.port, useTLS: true)
else { return }
guard let fp = await self.probeTLSFingerprint(url: url) else { return }
self.pendingTrustConnect = (url: url, stableID: stableID, isManual: false)
self.pendingTrustPrompt = TrustPrompt(
stableID: stableID,
gatewayName: gateway.name,
host: target.host,
port: target.port,
fingerprintSha256: fp,
isManual: false)
self.appModel?.gatewayStatusText = "Verify gateway TLS fingerprint"
return
}
let tlsParams = stored.map { fp in
GatewayTLSParams(required: true, expectedFingerprint: fp, allowTOFU: false, storeKey: stableID)
}
guard let url = self.buildGatewayURL(
host: target.host,
port: target.port,
useTLS: tlsParams?.required == true)
else { return }
GatewaySettingsStore.saveLastGatewayConnection(
host: target.host,
port: target.port,
useTLS: tlsParams?.required == true,
stableID: gateway.stableID)
GatewaySettingsStore.saveLastGatewayConnectionDiscovered(stableID: stableID, useTLS: true)
self.didAutoConnect = true
self.startAutoConnect(
url: url,
gatewayStableID: gateway.stableID,
gatewayStableID: stableID,
tls: tlsParams,
token: token,
password: password)
@@ -102,19 +138,34 @@ final class GatewayConnectionController {
guard let resolvedPort = self.resolveManualPort(host: host, port: port, useTLS: resolvedUseTLS)
else { return }
let stableID = self.manualStableID(host: host, port: resolvedPort)
let tlsParams = self.resolveManualTLSParams(
stableID: stableID,
tlsEnabled: resolvedUseTLS,
allowTOFUReset: self.shouldForceTLS(host: host))
let stored = GatewayTLSStore.loadFingerprint(stableID: stableID)
if resolvedUseTLS, stored == nil {
guard let url = self.buildGatewayURL(host: host, port: resolvedPort, useTLS: true) else { return }
guard let fp = await self.probeTLSFingerprint(url: url) else { return }
self.pendingTrustConnect = (url: url, stableID: stableID, isManual: true)
self.pendingTrustPrompt = TrustPrompt(
stableID: stableID,
gatewayName: "\(host):\(resolvedPort)",
host: host,
port: resolvedPort,
fingerprintSha256: fp,
isManual: true)
self.appModel?.gatewayStatusText = "Verify gateway TLS fingerprint"
return
}
let tlsParams = stored.map { fp in
GatewayTLSParams(required: true, expectedFingerprint: fp, allowTOFU: false, storeKey: stableID)
}
guard let url = self.buildGatewayURL(
host: host,
port: resolvedPort,
useTLS: tlsParams?.required == true)
else { return }
GatewaySettingsStore.saveLastGatewayConnection(
GatewaySettingsStore.saveLastGatewayConnectionManual(
host: host,
port: resolvedPort,
useTLS: tlsParams?.required == true,
useTLS: resolvedUseTLS && tlsParams != nil,
stableID: stableID)
self.didAutoConnect = true
self.startAutoConnect(
@@ -127,36 +178,63 @@ final class GatewayConnectionController {
func connectLastKnown() async {
guard let last = GatewaySettingsStore.loadLastGatewayConnection() else { return }
switch last {
case let .manual(host, port, useTLS, _):
await self.connectManual(host: host, port: port, useTLS: useTLS)
case let .discovered(stableID, _):
guard let gateway = self.gateways.first(where: { $0.stableID == stableID }) else { return }
await self.connectDiscoveredGateway(gateway)
}
}
func clearPendingTrustPrompt() {
self.pendingTrustPrompt = nil
self.pendingTrustConnect = nil
}
func acceptPendingTrustPrompt() async {
guard let pending = self.pendingTrustConnect,
let prompt = self.pendingTrustPrompt,
pending.stableID == prompt.stableID
else { return }
GatewayTLSStore.saveFingerprint(prompt.fingerprintSha256, stableID: pending.stableID)
self.clearPendingTrustPrompt()
if pending.isManual {
GatewaySettingsStore.saveLastGatewayConnectionManual(
host: prompt.host,
port: prompt.port,
useTLS: true,
stableID: pending.stableID)
} else {
GatewaySettingsStore.saveLastGatewayConnectionDiscovered(stableID: pending.stableID, useTLS: true)
}
let instanceId = UserDefaults.standard.string(forKey: "node.instanceId")?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
let token = GatewaySettingsStore.loadGatewayToken(instanceId: instanceId)
let password = GatewaySettingsStore.loadGatewayPassword(instanceId: instanceId)
let resolvedUseTLS = last.useTLS
let tlsParams = self.resolveManualTLSParams(
stableID: last.stableID,
tlsEnabled: resolvedUseTLS,
allowTOFUReset: self.shouldForceTLS(host: last.host))
guard let url = self.buildGatewayURL(
host: last.host,
port: last.port,
useTLS: tlsParams?.required == true)
else { return }
if resolvedUseTLS != last.useTLS {
GatewaySettingsStore.saveLastGatewayConnection(
host: last.host,
port: last.port,
useTLS: resolvedUseTLS,
stableID: last.stableID)
}
let tlsParams = GatewayTLSParams(
required: true,
expectedFingerprint: prompt.fingerprintSha256,
allowTOFU: false,
storeKey: pending.stableID)
self.didAutoConnect = true
self.startAutoConnect(
url: url,
gatewayStableID: last.stableID,
url: pending.url,
gatewayStableID: pending.stableID,
tls: tlsParams,
token: token,
password: password)
}
func declinePendingTrustPrompt() {
self.clearPendingTrustPrompt()
self.appModel?.gatewayStatusText = "Offline"
}
private func updateFromDiscovery() {
let newGateways = self.discovery.gateways
self.gateways = newGateways
@@ -233,25 +311,30 @@ final class GatewayConnectionController {
}
if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection() {
let resolvedUseTLS = lastKnown.useTLS || self.shouldForceTLS(host: lastKnown.host)
let tlsParams = self.resolveManualTLSParams(
stableID: lastKnown.stableID,
tlsEnabled: resolvedUseTLS,
allowTOFUReset: self.shouldForceTLS(host: lastKnown.host))
guard let url = self.buildGatewayURL(
host: lastKnown.host,
port: lastKnown.port,
useTLS: tlsParams?.required == true)
else { return }
if case let .manual(host, port, useTLS, stableID) = lastKnown {
let resolvedUseTLS = useTLS || self.shouldForceTLS(host: host)
let stored = GatewayTLSStore.loadFingerprint(stableID: stableID)
let tlsParams = stored.map { fp in
GatewayTLSParams(required: true, expectedFingerprint: fp, allowTOFU: false, storeKey: stableID)
}
guard let url = self.buildGatewayURL(
host: host,
port: port,
useTLS: resolvedUseTLS && tlsParams != nil)
else { return }
self.didAutoConnect = true
self.startAutoConnect(
url: url,
gatewayStableID: lastKnown.stableID,
tls: tlsParams,
token: token,
password: password)
return
// Security: autoconnect only to previously trusted gateways (stored TLS pin).
guard tlsParams != nil else { return }
self.didAutoConnect = true
self.startAutoConnect(
url: url,
gatewayStableID: stableID,
tls: tlsParams,
token: token,
password: password)
return
}
}
let preferredStableID = defaults.string(forKey: "gateway.preferredStableID")?
@@ -270,7 +353,7 @@ final class GatewayConnectionController {
self.didAutoConnect = true
Task { [weak self] in
guard let self else { return }
await self.connectDiscoveredGateway(target, allowTOFU: false)
await self.connectDiscoveredGateway(target)
}
return
}
@@ -282,7 +365,7 @@ final class GatewayConnectionController {
self.didAutoConnect = true
Task { [weak self] in
guard let self else { return }
await self.connectDiscoveredGateway(gateway, allowTOFU: false)
await self.connectDiscoveredGateway(gateway)
}
return
}
@@ -359,7 +442,7 @@ final class GatewayConnectionController {
return GatewayTLSParams(
required: true,
expectedFingerprint: nil,
allowTOFU: allowTOFU,
allowTOFU: false,
storeKey: stableID)
}
@@ -376,13 +459,22 @@ final class GatewayConnectionController {
return GatewayTLSParams(
required: true,
expectedFingerprint: stored,
allowTOFU: stored == nil || allowTOFUReset,
allowTOFU: false,
storeKey: stableID)
}
return nil
}
private func probeTLSFingerprint(url: URL) async -> String? {
await withCheckedContinuation { continuation in
let probe = GatewayTLSFingerprintProbe(url: url, timeoutSeconds: 3) { fp in
continuation.resume(returning: fp)
}
probe.start()
}
}
private func resolveServiceEndpoint(_ endpoint: NWEndpoint) async -> (host: String, port: Int)? {
guard case let .service(name, type, domain, _) = endpoint else { return nil }
let key = "\(domain)|\(type)|\(name)"
@@ -692,3 +784,71 @@ extension GatewayConnectionController {
}
}
#endif
private final class GatewayTLSFingerprintProbe: NSObject, URLSessionDelegate {
private let url: URL
private let timeoutSeconds: Double
private let onComplete: (String?) -> Void
private var didFinish = false
private var session: URLSession?
private var task: URLSessionWebSocketTask?
init(url: URL, timeoutSeconds: Double, onComplete: @escaping (String?) -> Void) {
self.url = url
self.timeoutSeconds = timeoutSeconds
self.onComplete = onComplete
}
func start() {
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = self.timeoutSeconds
config.timeoutIntervalForResource = self.timeoutSeconds
let session = URLSession(configuration: config, delegate: self, delegateQueue: nil)
self.session = session
let task = session.webSocketTask(with: self.url)
self.task = task
task.resume()
DispatchQueue.global(qos: .utility).asyncAfter(deadline: .now() + self.timeoutSeconds) { [weak self] in
self?.finish(nil)
}
}
func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,
completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void
) {
guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust,
let trust = challenge.protectionSpace.serverTrust
else {
completionHandler(.performDefaultHandling, nil)
return
}
let fp = GatewayTLSFingerprintProbe.certificateFingerprint(trust)
completionHandler(.cancelAuthenticationChallenge, nil)
self.finish(fp)
}
private func finish(_ fingerprint: String?) {
objc_sync_enter(self)
defer { objc_sync_exit(self) }
guard !self.didFinish else { return }
self.didFinish = true
self.task?.cancel(with: .goingAway, reason: nil)
self.session?.invalidateAndCancel()
self.onComplete(fingerprint)
}
private static func certificateFingerprint(_ trust: SecTrust) -> String? {
guard let chain = SecTrustCopyCertificateChain(trust) as? [SecCertificate],
let cert = chain.first
else {
return nil
}
let data = SecCertificateCopyData(cert) as Data
let digest = SHA256.hash(data: data)
return digest.map { String(format: "%02x", $0) }.joined()
}
}

View File

@@ -13,6 +13,7 @@ enum GatewaySettingsStore {
private static let manualPortDefaultsKey = "gateway.manual.port"
private static let manualTlsDefaultsKey = "gateway.manual.tls"
private static let discoveryDebugLogsDefaultsKey = "gateway.discovery.debugLogs"
private static let lastGatewayKindDefaultsKey = "gateway.last.kind"
private static let lastGatewayHostDefaultsKey = "gateway.last.host"
private static let lastGatewayPortDefaultsKey = "gateway.last.port"
private static let lastGatewayTlsDefaultsKey = "gateway.last.tls"
@@ -114,25 +115,73 @@ enum GatewaySettingsStore {
account: self.gatewayPasswordAccount(instanceId: instanceId))
}
static func saveLastGatewayConnection(host: String, port: Int, useTLS: Bool, stableID: String) {
enum LastGatewayConnection: Equatable {
case manual(host: String, port: Int, useTLS: Bool, stableID: String)
case discovered(stableID: String, useTLS: Bool)
var stableID: String {
switch self {
case let .manual(_, _, _, stableID):
return stableID
case let .discovered(stableID, _):
return stableID
}
}
var useTLS: Bool {
switch self {
case let .manual(_, _, useTLS, _):
return useTLS
case let .discovered(_, useTLS):
return useTLS
}
}
}
private enum LastGatewayKind: String {
case manual
case discovered
}
static func saveLastGatewayConnectionManual(host: String, port: Int, useTLS: Bool, stableID: String) {
let defaults = UserDefaults.standard
defaults.set(LastGatewayKind.manual.rawValue, forKey: self.lastGatewayKindDefaultsKey)
defaults.set(host, forKey: self.lastGatewayHostDefaultsKey)
defaults.set(port, forKey: self.lastGatewayPortDefaultsKey)
defaults.set(useTLS, forKey: self.lastGatewayTlsDefaultsKey)
defaults.set(stableID, forKey: self.lastGatewayStableIDDefaultsKey)
}
static func loadLastGatewayConnection() -> (host: String, port: Int, useTLS: Bool, stableID: String)? {
static func saveLastGatewayConnectionDiscovered(stableID: String, useTLS: Bool) {
let defaults = UserDefaults.standard
defaults.set(LastGatewayKind.discovered.rawValue, forKey: self.lastGatewayKindDefaultsKey)
defaults.removeObject(forKey: self.lastGatewayHostDefaultsKey)
defaults.removeObject(forKey: self.lastGatewayPortDefaultsKey)
defaults.set(useTLS, forKey: self.lastGatewayTlsDefaultsKey)
defaults.set(stableID, forKey: self.lastGatewayStableIDDefaultsKey)
}
static func loadLastGatewayConnection() -> LastGatewayConnection? {
let defaults = UserDefaults.standard
let stableID = defaults.string(forKey: self.lastGatewayStableIDDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
guard !stableID.isEmpty else { return nil }
let useTLS = defaults.bool(forKey: self.lastGatewayTlsDefaultsKey)
let kindRaw = defaults.string(forKey: self.lastGatewayKindDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
let kind = LastGatewayKind(rawValue: kindRaw) ?? .manual
if kind == .discovered {
return .discovered(stableID: stableID, useTLS: useTLS)
}
let host = defaults.string(forKey: self.lastGatewayHostDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
let port = defaults.integer(forKey: self.lastGatewayPortDefaultsKey)
let useTLS = defaults.bool(forKey: self.lastGatewayTlsDefaultsKey)
let stableID = defaults.string(forKey: self.lastGatewayStableIDDefaultsKey)?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? ""
guard !host.isEmpty, port > 0, port <= 65535, !stableID.isEmpty else { return nil }
return (host: host, port: port, useTLS: useTLS, stableID: stableID)
// Back-compat: older builds persisted manual-style host/port without a kind marker.
guard !host.isEmpty, port > 0, port <= 65535 else { return nil }
return .manual(host: host, port: port, useTLS: useTLS, stableID: stableID)
}
static func loadGatewayClientIdOverride(stableID: String) -> String? {

View File

@@ -0,0 +1,42 @@
import SwiftUI
struct GatewayTrustPromptAlert: ViewModifier {
@Environment(GatewayConnectionController.self) private var gatewayController: GatewayConnectionController
private var promptBinding: Binding<GatewayConnectionController.TrustPrompt?> {
Binding(
get: { self.gatewayController.pendingTrustPrompt },
set: { newValue in
if newValue == nil {
self.gatewayController.clearPendingTrustPrompt()
}
})
}
func body(content: Content) -> some View {
content.alert(item: self.promptBinding) { prompt in
Alert(
title: Text("Trust this gateway?"),
message: Text(
"""
First-time TLS connection.
Verify this SHA-256 fingerprint out-of-band before trusting:
\(prompt.fingerprintSha256)
"""),
primaryButton: .cancel(Text("Cancel")) {
self.gatewayController.declinePendingTrustPrompt()
},
secondaryButton: .default(Text("Trust and connect")) {
Task { await self.gatewayController.acceptPendingTrustPrompt() }
})
}
}
}
extension View {
func gatewayTrustPromptAlert() -> some View {
self.modifier(GatewayTrustPromptAlert())
}
}

View File

@@ -21,6 +21,7 @@ struct GatewayOnboardingView: View {
}
.navigationTitle("Connect Gateway")
}
.gatewayTrustPromptAlert()
}
}

View File

@@ -52,6 +52,7 @@ struct RootCanvas: View {
CameraFlashOverlay(nonce: self.appModel.cameraFlashNonce)
}
}
.gatewayTrustPromptAlert()
.sheet(item: self.$presentedSheet) { sheet in
switch sheet {
case .settings:

View File

@@ -376,6 +376,7 @@ struct SettingsTab: View {
}
}
}
.gatewayTrustPromptAlert()
}
@ViewBuilder
@@ -388,11 +389,13 @@ struct SettingsTab: View {
.font(.footnote)
.foregroundStyle(.secondary)
if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection() {
if let lastKnown = GatewaySettingsStore.loadLastGatewayConnection(),
case let .manual(host, port, _, _) = lastKnown
{
Button {
Task { await self.connectLastKnown() }
} label: {
self.lastKnownButtonLabel(host: lastKnown.host, port: lastKnown.port)
self.lastKnownButtonLabel(host: host, port: port)
}
.disabled(self.connectingGatewayID != nil)
.buttonStyle(.borderedProminent)