4
4
//! [`ApiConnectionMode`], which in turn is used by `mullvad-api` for
5
5
//! establishing connections when performing API requests.
6
6
7
+ use crate :: proxy:: { ApiConnectionMode , ConnectionModeProvider } ;
7
8
#[ cfg( feature = "api-override" ) ]
8
9
use crate :: ApiEndpoint ;
9
- use crate :: {
10
- proxy:: { AllowedClientsProvider , ApiConnectionMode , ConnectionModeProvider } ,
11
- AddressCache ,
12
- } ;
13
10
use async_trait:: async_trait;
14
11
use futures:: {
15
12
channel:: { mpsc, oneshot} ,
16
13
StreamExt ,
17
14
} ;
18
15
use mullvad_types:: access_method:: { AccessMethod , AccessMethodSetting , Id , Settings } ;
19
- use std:: { marker :: PhantomData , net :: SocketAddr , path:: PathBuf } ;
20
- use talpid_types:: net:: { AllowedEndpoint , Endpoint , TransportProtocol } ;
16
+ use std:: path:: PathBuf ;
17
+ use talpid_types:: net:: AllowedEndpoint ;
21
18
22
19
pub enum Message {
23
20
Get ( ResponseTx < ResolvedConnectionMode > ) ,
@@ -242,34 +239,28 @@ impl ConnectionModeProvider for AccessModeConnectionModeProvider {
242
239
/// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`])
243
240
/// or via any supported custom proxy protocol
244
241
/// ([`talpid_types::net::proxy::CustomProxy`]).
245
- pub struct AccessModeSelector < P > {
242
+ pub struct AccessModeSelector < B : AccessMethodResolver > {
246
243
#[ cfg( feature = "api-override" ) ]
247
244
api_endpoint : ApiEndpoint ,
248
245
cmd_rx : mpsc:: UnboundedReceiver < Message > ,
249
246
cache_dir : PathBuf ,
250
- bridge_dns_proxy_provider : Box < dyn BridgeAndDNSProxy > ,
247
+ bridge_dns_proxy_provider : B ,
251
248
access_method_settings : Settings ,
252
- address_cache : AddressCache ,
253
249
access_method_event_sender : mpsc:: UnboundedSender < ( AccessMethodEvent , oneshot:: Sender < ( ) > ) > ,
254
250
connection_mode_provider_sender : mpsc:: UnboundedSender < ApiConnectionMode > ,
255
251
current : ResolvedConnectionMode ,
256
252
/// `index` is used to keep track of the [`AccessMethodSetting`] to use.
257
253
index : usize ,
258
- provider : PhantomData < P > ,
259
254
}
260
255
261
- impl < P > AccessModeSelector < P >
262
- where
263
- P : AllowedClientsProvider + ' static ,
264
- {
256
+ impl < B : AccessMethodResolver + ' static > AccessModeSelector < B > {
265
257
pub async fn spawn (
266
258
cache_dir : PathBuf ,
267
- mut bridge_dns_proxy_provider : Box < dyn BridgeAndDNSProxy > ,
259
+ mut bridge_dns_proxy_provider : B ,
268
260
#[ cfg_attr( not( feature = "api-override" ) , allow( unused_mut) ) ]
269
261
mut access_method_settings : Settings ,
270
262
#[ cfg( feature = "api-override" ) ] api_endpoint : ApiEndpoint ,
271
263
access_method_event_sender : mpsc:: UnboundedSender < ( AccessMethodEvent , oneshot:: Sender < ( ) > ) > ,
272
- address_cache : AddressCache ,
273
264
) -> Result < ( AccessModeSelectorHandle , AccessModeConnectionModeProvider ) > {
274
265
let ( cmd_tx, cmd_rx) = mpsc:: unbounded ( ) ;
275
266
@@ -283,30 +274,24 @@ where
283
274
284
275
// Always start looking from the position of `Direct`.
285
276
let ( index, next) = Self :: find_next_active ( 0 , & access_method_settings) ;
286
- let initial_connection_mode = Self :: resolve_inner_with_default (
287
- & next,
288
- & address_cache,
289
- & mut * bridge_dns_proxy_provider,
290
- )
291
- . await ;
277
+ let initial_connection_mode =
278
+ Self :: resolve_with_default ( & next, & mut bridge_dns_proxy_provider) . await ;
292
279
293
280
let ( change_tx, change_rx) = mpsc:: unbounded ( ) ;
294
281
295
282
let api_connection_mode = initial_connection_mode. connection_mode . clone ( ) ;
296
283
297
- let selector: AccessModeSelector < P > = AccessModeSelector {
284
+ let selector = AccessModeSelector {
298
285
#[ cfg( feature = "api-override" ) ]
299
286
api_endpoint,
300
287
cmd_rx,
301
288
cache_dir,
302
289
bridge_dns_proxy_provider,
303
290
access_method_settings,
304
- address_cache,
305
291
access_method_event_sender,
306
292
connection_mode_provider_sender : change_tx,
307
293
current : initial_connection_mode,
308
294
index,
309
- provider : PhantomData ,
310
295
} ;
311
296
312
297
tokio:: spawn ( selector. into_future ( ) ) ;
@@ -408,7 +393,8 @@ where
408
393
}
409
394
410
395
async fn set_current ( & mut self , access_method : AccessMethodSetting ) {
411
- let resolved = self . resolve_with_default ( access_method) . await ;
396
+ let resolved =
397
+ Self :: resolve_with_default ( & access_method, & mut self . bridge_dns_proxy_provider ) . await ;
412
398
413
399
// Note: If the daemon is busy waiting for a call to this function
414
400
// to complete while we wait for the daemon to fully handle this
@@ -522,89 +508,49 @@ where
522
508
523
509
async fn resolve (
524
510
& mut self ,
525
- access_method : AccessMethodSetting ,
526
- ) -> Option < ResolvedConnectionMode > {
527
- Self :: resolve_inner (
528
- & access_method,
529
- & self . address_cache ,
530
- & mut * self . bridge_dns_proxy_provider ,
531
- )
532
- . await
533
- }
534
-
535
- async fn resolve_inner (
536
- access_method : & AccessMethodSetting ,
537
- address_cache : & AddressCache ,
538
- bridge_dns_proxy_provider : & mut dyn BridgeAndDNSProxy ,
511
+ method_setting : AccessMethodSetting ,
539
512
) -> Option < ResolvedConnectionMode > {
540
- let connection_mode = bridge_dns_proxy_provider
541
- . match_access_method ( access_method)
513
+ let ( endpoint, connection_mode) = self
514
+ . bridge_dns_proxy_provider
515
+ . resolve_access_method_setting ( & method_setting. access_method )
542
516
. await ?;
543
- let endpoint =
544
- resolve_allowed_endpoint :: < P > ( & connection_mode, address_cache. get_address ( ) . await ) ;
545
517
Some ( ResolvedConnectionMode {
546
518
connection_mode,
547
519
endpoint,
548
- setting : access_method . clone ( ) ,
520
+ setting : method_setting ,
549
521
} )
550
522
}
551
523
552
524
/// Resolve an access method into a set of connection details - fall back to
553
525
/// [`ApiConnectionMode::Direct`] in case `access_method` does not yield anything.
554
526
async fn resolve_with_default (
555
- & mut self ,
556
- access_method : AccessMethodSetting ,
557
- ) -> ResolvedConnectionMode {
558
- Self :: resolve_inner_with_default (
559
- & access_method,
560
- & self . address_cache ,
561
- & mut * self . bridge_dns_proxy_provider ,
562
- )
563
- . await
564
- }
565
-
566
- async fn resolve_inner_with_default (
567
- access_method : & AccessMethodSetting ,
568
- address_cache : & AddressCache ,
569
- bridge_dns_proxy_provider : & mut dyn BridgeAndDNSProxy ,
527
+ method_setting : & AccessMethodSetting ,
528
+ bridge_dns_proxy_provider : & mut B ,
570
529
) -> ResolvedConnectionMode {
571
- match Self :: resolve_inner ( access_method, address_cache, bridge_dns_proxy_provider) . await {
530
+ let ( endpoint, connection_mode) = match bridge_dns_proxy_provider
531
+ . resolve_access_method_setting ( & method_setting. access_method )
532
+ . await
533
+ {
572
534
Some ( resolved) => resolved,
573
- None => {
574
- log:: trace!( "Defaulting to direct API connection" ) ;
575
- let endpoint = resolve_allowed_endpoint :: < P > (
576
- & ApiConnectionMode :: Direct ,
577
- address_cache. get_address ( ) . await ,
578
- ) ;
579
- ResolvedConnectionMode {
580
- connection_mode : ApiConnectionMode :: Direct ,
581
- endpoint,
582
- setting : access_method. clone ( ) ,
583
- }
584
- }
535
+ None => (
536
+ bridge_dns_proxy_provider. default_connection_mode ( ) . await ,
537
+ ApiConnectionMode :: Direct ,
538
+ ) ,
539
+ } ;
540
+ ResolvedConnectionMode {
541
+ connection_mode,
542
+ endpoint,
543
+ setting : method_setting. clone ( ) ,
585
544
}
586
545
}
587
546
}
588
547
589
548
#[ async_trait]
590
- pub trait BridgeAndDNSProxy : Send + Sync {
591
- async fn match_access_method (
549
+ pub trait AccessMethodResolver : Send + Sync {
550
+ async fn resolve_access_method_setting (
592
551
& mut self ,
593
- access_method : & AccessMethodSetting ,
594
- ) -> Option < ApiConnectionMode > ;
595
- }
552
+ access_method : & AccessMethod ,
553
+ ) -> Option < ( AllowedEndpoint , ApiConnectionMode ) > ;
596
554
597
- pub fn resolve_allowed_endpoint < P > (
598
- connection_mode : & ApiConnectionMode ,
599
- fallback : SocketAddr ,
600
- ) -> AllowedEndpoint
601
- where
602
- P : AllowedClientsProvider ,
603
- {
604
- let endpoint = match connection_mode. get_endpoint ( ) {
605
- Some ( endpoint) => endpoint,
606
- None => Endpoint :: from_socket_address ( fallback, TransportProtocol :: Tcp ) ,
607
- } ;
608
- let clients = P :: allowed_clients ( connection_mode) ;
609
- AllowedEndpoint { endpoint, clients }
555
+ async fn default_connection_mode ( & self ) -> AllowedEndpoint ;
610
556
}
0 commit comments