Skip to content

Commit 6c7e485

Browse files
committed
cleanup branch: test, file structure, docstring and linting
Signed-off-by: kgao <kevin.leo.gao@gmail.com>
1 parent e652b08 commit 6c7e485

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

econml/tests/test_federated_learning.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import unittest
66
from econml.dml import LinearDML
77
from econml.inference import StatsModelsInference
8-
from econml.sklearn_extensions.federated_learning import FederatedLearner
8+
from econml.federated_learning import FederatedEstimator
99

1010

1111
class FunctionRegressor:
12+
"""A simple model that ignores the data it is fitted on, always just using the specified function to predict"""
1213
def __init__(self, func):
1314
self.func = func
1415

@@ -20,7 +21,20 @@ def predict(self, X):
2021

2122

2223
class TestFederatedLearning(unittest.TestCase):
24+
"""
25+
A set of unit tests for the FederatedLearner class.
2326
27+
These tests check various scenarios of splitting, aggregation, and comparison
28+
between FederatedLearner and individual LinearDML estimators.
29+
30+
Parameters
31+
----------
32+
None
33+
34+
Returns
35+
-------
36+
None
37+
"""
2438
def test_splitting_works(self):
2539

2640
num_samples = 1000
@@ -69,15 +83,19 @@ def test_splitting_works(self):
6983
sample_weight=weights, freq_weight=freq_weights, sample_var=sample_var,
7084
inference=StatsModelsInference(cov_type=cov_type))
7185
est_h1.fit(Y1, T1, X=X1, W=W1,
72-
sample_weight=weights1, freq_weight=freq_weights1, sample_var=sample_var1,
86+
sample_weight=weights1,
87+
freq_weight=freq_weights1,
88+
sample_var=sample_var1,
7389
inference=StatsModelsInference(cov_type=cov_type))
7490
est_h2.fit(Y2, T2, X=X2, W=W2,
75-
sample_weight=weights2, freq_weight=freq_weights2, sample_var=sample_var2,
91+
sample_weight=weights2,
92+
freq_weight=freq_weights2,
93+
sample_var=sample_var2,
7694
inference=StatsModelsInference(cov_type=cov_type))
7795

78-
est_fed1 = FederatedLearner([est_all])
96+
est_fed1 = FederatedEstimator([est_all])
7997

80-
est_fed2 = FederatedLearner([est_h1, est_h2])
98+
est_fed2 = FederatedEstimator([est_h1, est_h2])
8199

82100
np.testing.assert_allclose(est_fed1.model_final_._param,
83101
est_fed2.model_final_._param)

0 commit comments

Comments
 (0)