Skip to content

feat: install and activate the tunnel provider as network extension #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Coder Desktop/Coder Desktop.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_ENTITLEMENTS = "Coder Desktop/Coder_Desktop.entitlements";
CODE_SIGN_IDENTITY = "Apple Development";
CODE_SIGN_STYLE = Automatic;
COMBINE_HIDPI_IMAGES = YES;
CURRENT_PROJECT_VERSION = 1;
Expand All @@ -788,6 +789,7 @@
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop";
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 6.0;
};
Expand All @@ -799,6 +801,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_ENTITLEMENTS = "Coder Desktop/Coder_Desktop.entitlements";
CODE_SIGN_IDENTITY = "Apple Development";
CODE_SIGN_STYLE = Automatic;
COMBINE_HIDPI_IMAGES = YES;
CURRENT_PROJECT_VERSION = 1;
Expand All @@ -818,6 +821,7 @@
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.coder.Coder-Desktop";
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 6.0;
};
Expand Down Expand Up @@ -901,6 +905,7 @@
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_ENTITLEMENTS = VPN/VPN.entitlements;
"CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
Expand Down Expand Up @@ -932,6 +937,7 @@
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_ENTITLEMENTS = VPN/VPN.entitlements;
"CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
Expand Down
2 changes: 2 additions & 0 deletions Coder Desktop/Coder Desktop/Coder_Desktop.entitlements
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
<array>
<string>packet-tunnel-provider</string>
</array>
<key>com.apple.developer.system-extension.install</key>
<true/>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
Expand Down
113 changes: 113 additions & 0 deletions Coder Desktop/Coder Desktop/NetworkExtension.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import NetworkExtension
import os

enum NetworkExtensionState: Equatable {
case unconfigured
case disbled
case enabled
case failed(String)

var description: String {
switch self {
case .unconfigured:
return "Not logged in to Coder"
case .enabled:
return "NetworkExtension tunnel enabled"
case .disbled:
return "NetworkExtension tunnel disabled"
case let .failed(error):
return "NetworkExtension config failed: \(error)"
}
}
}

/// An actor that handles configuring, enabling, and disabling the VPN tunnel via the
/// NetworkExtension APIs.
extension CoderVPNService {
func configureNetworkExtension(proto: NETunnelProviderProtocol) async {
// removing the old tunnels, rather than reconfiguring ensures that configuration changes
// are picked up.
do {
try await removeNetworkExtension()
} catch {
logger.error("remove tunnel failed: \(error)")
neState = .failed(error.localizedDescription)
return
}
logger.debug("inserting new tunnel")

let tm = NETunnelProviderManager()
tm.localizedDescription = "CoderVPN"
tm.protocolConfiguration = proto

logger.debug("saving new tunnel")
do {
try await tm.saveToPreferences()
} catch {
logger.error("save tunnel failed: \(error)")
neState = .failed(error.localizedDescription)
}
}

func removeNetworkExtension() async throws(VPNServiceError) {
do {
let tunnels = try await NETunnelProviderManager.loadAllFromPreferences()
for tunnel in tunnels {
try await tunnel.removeFromPreferences()
}
} catch {
throw .internalError("couldn't remove tunnels: \(error)")
}
}

func enableNetworkExtension() async {
do {
let tm = try await getTunnelManager()
if !tm.isEnabled {
tm.isEnabled = true
try await tm.saveToPreferences()
logger.debug("saved tunnel with enabled=true")
}
try tm.connection.startVPNTunnel()
} catch {
logger.error("enable network extension: \(error)")
neState = .failed(error.localizedDescription)
return
}
logger.debug("enabled and started tunnel")
neState = .enabled
}

func disableNetworkExtension() async {
do {
let tm = try await getTunnelManager()
tm.connection.stopVPNTunnel()
tm.isEnabled = false

try await tm.saveToPreferences()
} catch {
logger.error("disable network extension: \(error)")
neState = .failed(error.localizedDescription)
return
}
logger.debug("saved tunnel with enabled=false")
neState = .disbled
}

private func getTunnelManager() async throws(VPNServiceError) -> NETunnelProviderManager {
var tunnels: [NETunnelProviderManager] = []
do {
tunnels = try await NETunnelProviderManager.loadAllFromPreferences()
} catch {
throw .internalError("couldn't load tunnels: \(error)")
}
if tunnels.isEmpty {
throw .internalError("no tunnels found")
}
return tunnels.first!
}
}

// we're going to mark NETunnelProviderManager as Sendable since there are official APIs that return
// it async.
extension NETunnelProviderManager: @unchecked @retroactive Sendable {}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import NetworkExtension
import SwiftUI

class PreviewSession: Session {
Expand All @@ -21,4 +22,8 @@ class PreviewSession: Session {
hasSession = false
sessionToken = nil
}

func tunnelProviderProtocol() -> NETunnelProviderProtocol? {
return nil
}
}
11 changes: 8 additions & 3 deletions Coder Desktop/Coder Desktop/Preview Content/PreviewVPN.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import NetworkExtension
import SwiftUI

@MainActor
Expand Down Expand Up @@ -28,10 +29,10 @@ final class PreviewVPN: Coder_Desktop.VPNService {
do {
try await Task.sleep(for: .seconds(10))
} catch {
state = .failed(.exampleError)
state = .failed(.longTestError)
return
}
state = shouldFail ? .failed(.exampleError) : .connected
state = shouldFail ? .failed(.longTestError) : .connected
}

func stop() async {
Expand All @@ -40,9 +41,13 @@ final class PreviewVPN: Coder_Desktop.VPNService {
do {
try await Task.sleep(for: .seconds(10))
} catch {
state = .failed(.exampleError)
state = .failed(.longTestError)
return
}
state = .disabled
}

func configureTunnelProviderProtocol(proto _: NETunnelProviderProtocol?) {
state = .connecting
}
}
22 changes: 20 additions & 2 deletions Coder Desktop/Coder Desktop/Session.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Foundation
import KeychainAccess
import NetworkExtension

protocol Session: ObservableObject {
var hasSession: Bool { get }
Expand All @@ -8,9 +9,12 @@ protocol Session: ObservableObject {

func store(baseAccessURL: URL, sessionToken: String)
func clear()
func tunnelProviderProtocol() -> NETunnelProviderProtocol?
}

class SecureSession: ObservableObject {
class SecureSession: ObservableObject, Session {
let appId = Bundle.main.bundleIdentifier!

// Stored in UserDefaults
@Published private(set) var hasSession: Bool {
didSet {
Expand All @@ -31,9 +35,21 @@ class SecureSession: ObservableObject {
}
}

func tunnelProviderProtocol() -> NETunnelProviderProtocol? {
if !hasSession { return nil }
let proto = NETunnelProviderProtocol()
proto.providerBundleIdentifier = "\(appId).VPN"
proto.passwordReference = keychain[attributes: Keys.sessionToken]?.persistentRef
proto.serverAddress = baseAccessURL!.absoluteString
return proto
}

private let keychain: Keychain

public init() {
let onChange: ((NETunnelProviderProtocol?) -> Void)?

public init(onChange: ((NETunnelProviderProtocol?) -> Void)? = nil) {
self.onChange = onChange
keychain = Keychain(service: Bundle.main.bundleIdentifier!)
_hasSession = Published(initialValue: UserDefaults.standard.bool(forKey: Keys.hasSession))
_baseAccessURL = Published(initialValue: UserDefaults.standard.url(forKey: Keys.baseAccessURL))
Expand All @@ -46,11 +62,13 @@ class SecureSession: ObservableObject {
hasSession = true
self.baseAccessURL = baseAccessURL
self.sessionToken = sessionToken
if let onChange { onChange(tunnelProviderProtocol()) }
}

public func clear() {
hasSession = false
sessionToken = nil
if let onChange { onChange(tunnelProviderProtocol()) }
}

private func keychainGet(for key: String) -> String? {
Expand Down
Loading