7
7
#[ cfg( feature = "api-override" ) ]
8
8
use crate :: ApiEndpoint ;
9
9
use crate :: {
10
- proxy:: { AllowedClientsProvider , ApiConnectionMode , ConnectionModeProvider } ,
10
+ proxy:: { ApiConnectionMode , ConnectionModeProvider } ,
11
11
AddressCache ,
12
12
} ;
13
13
use async_trait:: async_trait;
@@ -16,8 +16,8 @@ use futures::{
16
16
StreamExt ,
17
17
} ;
18
18
use mullvad_types:: access_method:: { AccessMethod , AccessMethodSetting , Id , Settings } ;
19
- use std:: { marker :: PhantomData , net:: SocketAddr , path:: PathBuf } ;
20
- use talpid_types:: net:: { AllowedEndpoint , Endpoint , TransportProtocol } ;
19
+ use std:: { net:: SocketAddr , path:: PathBuf } ;
20
+ use talpid_types:: net:: { AllowedClients , AllowedEndpoint , Endpoint , TransportProtocol } ;
21
21
22
22
pub enum Message {
23
23
Get ( ResponseTx < ResolvedConnectionMode > ) ,
@@ -242,29 +242,25 @@ impl ConnectionModeProvider for AccessModeConnectionModeProvider {
242
242
/// [`ApiConnectionMode::Direct`]) via a bridge ([`ApiConnectionMode::Proxied`])
243
243
/// or via any supported custom proxy protocol
244
244
/// ([`talpid_types::net::proxy::CustomProxy`]).
245
- pub struct AccessModeSelector < P > {
245
+ pub struct AccessModeSelector < B : AccessMethodResolver > {
246
246
#[ cfg( feature = "api-override" ) ]
247
247
api_endpoint : ApiEndpoint ,
248
248
cmd_rx : mpsc:: UnboundedReceiver < Message > ,
249
249
cache_dir : PathBuf ,
250
- bridge_dns_proxy_provider : Box < dyn BridgeAndDNSProxy > ,
250
+ bridge_dns_proxy_provider : B ,
251
251
access_method_settings : Settings ,
252
252
address_cache : AddressCache ,
253
253
access_method_event_sender : mpsc:: UnboundedSender < ( AccessMethodEvent , oneshot:: Sender < ( ) > ) > ,
254
254
connection_mode_provider_sender : mpsc:: UnboundedSender < ApiConnectionMode > ,
255
255
current : ResolvedConnectionMode ,
256
256
/// `index` is used to keep track of the [`AccessMethodSetting`] to use.
257
257
index : usize ,
258
- provider : PhantomData < P > ,
259
258
}
260
259
261
- impl < P > AccessModeSelector < P >
262
- where
263
- P : AllowedClientsProvider + ' static ,
264
- {
260
+ impl < B : AccessMethodResolver + ' static > AccessModeSelector < B > {
265
261
pub async fn spawn (
266
262
cache_dir : PathBuf ,
267
- mut bridge_dns_proxy_provider : Box < dyn BridgeAndDNSProxy > ,
263
+ mut bridge_dns_proxy_provider : B ,
268
264
#[ cfg_attr( not( feature = "api-override" ) , allow( unused_mut) ) ]
269
265
mut access_method_settings : Settings ,
270
266
#[ cfg( feature = "api-override" ) ] api_endpoint : ApiEndpoint ,
@@ -283,18 +279,15 @@ where
283
279
284
280
// Always start looking from the position of `Direct`.
285
281
let ( index, next) = Self :: find_next_active ( 0 , & access_method_settings) ;
286
- let initial_connection_mode = Self :: resolve_inner_with_default (
287
- & next,
288
- & address_cache,
289
- & mut * bridge_dns_proxy_provider,
290
- )
291
- . await ;
282
+ let initial_connection_mode =
283
+ Self :: resolve_inner_with_default ( & next, & address_cache, & mut bridge_dns_proxy_provider)
284
+ . await ;
292
285
293
286
let ( change_tx, change_rx) = mpsc:: unbounded ( ) ;
294
287
295
288
let api_connection_mode = initial_connection_mode. connection_mode . clone ( ) ;
296
289
297
- let selector: AccessModeSelector < P > = AccessModeSelector {
290
+ let selector = AccessModeSelector {
298
291
#[ cfg( feature = "api-override" ) ]
299
292
api_endpoint,
300
293
cmd_rx,
@@ -306,7 +299,6 @@ where
306
299
connection_mode_provider_sender : change_tx,
307
300
current : initial_connection_mode,
308
301
index,
309
- provider : PhantomData ,
310
302
} ;
311
303
312
304
tokio:: spawn ( selector. into_future ( ) ) ;
@@ -527,25 +519,25 @@ where
527
519
Self :: resolve_inner (
528
520
& access_method,
529
521
& self . address_cache ,
530
- & mut * self . bridge_dns_proxy_provider ,
522
+ & mut self . bridge_dns_proxy_provider ,
531
523
)
532
524
. await
533
525
}
534
526
535
527
async fn resolve_inner (
536
- access_method : & AccessMethodSetting ,
528
+ method_setting : & AccessMethodSetting ,
537
529
address_cache : & AddressCache ,
538
- bridge_dns_proxy_provider : & mut dyn BridgeAndDNSProxy ,
530
+ bridge_dns_proxy_provider : & mut B ,
539
531
) -> Option < ResolvedConnectionMode > {
540
532
let connection_mode = bridge_dns_proxy_provider
541
- . match_access_method ( access_method )
533
+ . resolve_access_method_setting ( method_setting )
542
534
. await ?;
543
535
let endpoint =
544
- resolve_allowed_endpoint :: < P > ( & connection_mode, address_cache. get_address ( ) . await ) ;
536
+ resolve_allowed_endpoint :: < B > ( & connection_mode, address_cache. get_address ( ) . await ) ;
545
537
Some ( ResolvedConnectionMode {
546
538
connection_mode,
547
539
endpoint,
548
- setting : access_method . clone ( ) ,
540
+ setting : method_setting . clone ( ) ,
549
541
} )
550
542
}
551
543
@@ -558,21 +550,21 @@ where
558
550
Self :: resolve_inner_with_default (
559
551
& access_method,
560
552
& self . address_cache ,
561
- & mut * self . bridge_dns_proxy_provider ,
553
+ & mut self . bridge_dns_proxy_provider ,
562
554
)
563
555
. await
564
556
}
565
557
566
558
async fn resolve_inner_with_default (
567
559
access_method : & AccessMethodSetting ,
568
560
address_cache : & AddressCache ,
569
- bridge_dns_proxy_provider : & mut dyn BridgeAndDNSProxy ,
561
+ bridge_dns_proxy_provider : & mut B ,
570
562
) -> ResolvedConnectionMode {
571
563
match Self :: resolve_inner ( access_method, address_cache, bridge_dns_proxy_provider) . await {
572
564
Some ( resolved) => resolved,
573
565
None => {
574
566
log:: trace!( "Defaulting to direct API connection" ) ;
575
- let endpoint = resolve_allowed_endpoint :: < P > (
567
+ let endpoint = resolve_allowed_endpoint :: < B > (
576
568
& ApiConnectionMode :: Direct ,
577
569
address_cache. get_address ( ) . await ,
578
570
) ;
@@ -587,24 +579,26 @@ where
587
579
}
588
580
589
581
#[ async_trait]
590
- pub trait BridgeAndDNSProxy : Send + Sync {
591
- async fn match_access_method (
582
+ pub trait AccessMethodResolver : Send + Sync {
583
+ async fn resolve_access_method_setting (
592
584
& mut self ,
593
585
access_method : & AccessMethodSetting ,
594
586
) -> Option < ApiConnectionMode > ;
587
+
588
+ fn allowed_clients ( connection_mode : & ApiConnectionMode ) -> AllowedClients ;
595
589
}
596
590
597
- pub fn resolve_allowed_endpoint < P > (
591
+ pub fn resolve_allowed_endpoint < B > (
598
592
connection_mode : & ApiConnectionMode ,
599
593
fallback : SocketAddr ,
600
594
) -> AllowedEndpoint
601
595
where
602
- P : AllowedClientsProvider ,
596
+ B : AccessMethodResolver ,
603
597
{
604
598
let endpoint = match connection_mode. get_endpoint ( ) {
605
599
Some ( endpoint) => endpoint,
606
600
None => Endpoint :: from_socket_address ( fallback, TransportProtocol :: Tcp ) ,
607
601
} ;
608
- let clients = P :: allowed_clients ( connection_mode) ;
602
+ let clients = B :: allowed_clients ( connection_mode) ;
609
603
AllowedEndpoint { endpoint, clients }
610
604
}
0 commit comments