5
5
import unittest
6
6
from econml .dml import LinearDML
7
7
from econml .inference import StatsModelsInference
8
- from econml .sklearn_extensions . federated_learning import FederatedLearner
8
+ from econml .federated_learning import FederatedEstimator
9
9
10
10
11
11
class FunctionRegressor :
12
+ """A simple model that ignores the data it is fitted on, always just using the specified function to predict"""
12
13
def __init__ (self , func ):
13
14
self .func = func
14
15
@@ -20,7 +21,20 @@ def predict(self, X):
20
21
21
22
22
23
class TestFederatedLearning (unittest .TestCase ):
24
+ """
25
+ A set of unit tests for the FederatedLearner class.
23
26
27
+ These tests check various scenarios of splitting, aggregation, and comparison
28
+ between FederatedLearner and individual LinearDML estimators.
29
+
30
+ Parameters
31
+ ----------
32
+ None
33
+
34
+ Returns
35
+ -------
36
+ None
37
+ """
24
38
def test_splitting_works (self ):
25
39
26
40
num_samples = 1000
@@ -69,15 +83,19 @@ def test_splitting_works(self):
69
83
sample_weight = weights , freq_weight = freq_weights , sample_var = sample_var ,
70
84
inference = StatsModelsInference (cov_type = cov_type ))
71
85
est_h1 .fit (Y1 , T1 , X = X1 , W = W1 ,
72
- sample_weight = weights1 , freq_weight = freq_weights1 , sample_var = sample_var1 ,
86
+ sample_weight = weights1 ,
87
+ freq_weight = freq_weights1 ,
88
+ sample_var = sample_var1 ,
73
89
inference = StatsModelsInference (cov_type = cov_type ))
74
90
est_h2 .fit (Y2 , T2 , X = X2 , W = W2 ,
75
- sample_weight = weights2 , freq_weight = freq_weights2 , sample_var = sample_var2 ,
91
+ sample_weight = weights2 ,
92
+ freq_weight = freq_weights2 ,
93
+ sample_var = sample_var2 ,
76
94
inference = StatsModelsInference (cov_type = cov_type ))
77
95
78
- est_fed1 = FederatedLearner ([est_all ])
96
+ est_fed1 = FederatedEstimator ([est_all ])
79
97
80
- est_fed2 = FederatedLearner ([est_h1 , est_h2 ])
98
+ est_fed2 = FederatedEstimator ([est_h1 , est_h2 ])
81
99
82
100
np .testing .assert_allclose (est_fed1 .model_final_ ._param ,
83
101
est_fed2 .model_final_ ._param )
0 commit comments