Skip to content

Commit ceb24b7

Browse files
authored
Merge pull request #6 from messier16/fix-bugged_unseen_labels
Fix error in CategoryEncoder transform
2 parents f1fb37b + cbd8d4f commit ceb24b7

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

m16_mlutils/pipeline/CategoryEncoder.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,6 @@ def transform(self, X):
4848
o = self.one_hot_encoders[c].transform(
4949
values.reshape(len(values), 1))
5050
except (KeyError, ValueError):
51-
o = np.zeros((1, len(self.label_encoders[c].classes_)))
51+
o = np.zeros((len(X), len(self.label_encoders[c].classes_)))
5252
one_hots.append(o)
5353
return np.concatenate(one_hots, axis=1)

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from setuptools.command.install import install
1414

1515
# Package version
16-
VERSION = "0.4.3"
16+
VERSION = "0.4.4"
1717

1818
class VerifyVersionCommand(install):
1919
"""Custom command to verify that the git tag matches our version"""

tests/test_CategoryEncoder.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -86,14 +86,17 @@ def test_previously_unseen(self):
8686
})
8787

8888
test = pd.DataFrame({
89-
'l': ['z']
89+
'l': ['z', 'y']
9090
})
9191

9292
encoder = CategoryEncoder()
9393

94-
expected = self.array([[0,0,0]])
94+
expected = self.array([
95+
[0, 0, 0],
96+
[0, 0, 0],
97+
])
9598

9699
encoder.fit(train)
97100
actual = encoder.transform(test)
98101

99-
self.assertArrayEqual(expected, actual)
102+
self.assertArrayEqual(expected, actual)

0 commit comments

Comments
 (0)