Skip to content

Commit a642fa5

Browse files
author
Rafal Mucha
committed
Issue-374 Add unit tests & docs
1 parent 2a32b8a commit a642fa5

File tree

6 files changed

+52
-14
lines changed

6 files changed

+52
-14
lines changed

py4j-java/src/test/java/py4j/commands/DirCommandTest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public class DirCommandTest {
6969
{
7070
// Defined in ExampleClass
7171
ExampleClassMethods.addAll(Arrays.asList(new String[] { "method1", "method2", "method3", "method4", "method5",
72-
"method6", "method7", "method8", "method9", "method10", "method11", "getList", "getField1", "setField1",
73-
"getStringArray", "getIntArray", "callHello", "callHello2", "static_method", "getInteger",
72+
"method6", "method7", "method8", "method9", "method10", "method11", "method12", "getList", "getField1",
73+
"setField1", "getStringArray", "getIntArray", "callHello", "callHello2", "static_method", "getInteger",
7474
"getBrokenStream", "getStream", "sleepFirstTimeOnly" }));
7575
// Defined in Object
7676
ExampleClassMethods.addAll(Arrays

py4j-java/src/test/java/py4j/examples/ExampleClass.java

+15
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.nio.channels.Channels;
3838
import java.nio.channels.ReadableByteChannel;
3939
import java.util.ArrayList;
40+
import java.util.HashSet;
4041
import java.util.List;
4142

4243
public class ExampleClass {
@@ -174,6 +175,20 @@ public BigInteger method11(BigInteger bi) {
174175
return bi.add(new BigInteger("1"));
175176
}
176177

178+
public int method12(HashSet<Object> set) {
179+
Object element = set.stream().findAny().get();
180+
if (element instanceof Long) {
181+
return 4;
182+
}
183+
if (element instanceof Integer) {
184+
return 1;
185+
}
186+
if (element instanceof String) {
187+
return 2;
188+
}
189+
return 3;
190+
}
191+
177192
@SuppressWarnings("unused")
178193
private int private_method() {
179194
return 0;

py4j-python/src/py4j/protocol.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
from base64 import standard_b64encode, standard_b64decode
2222

2323
from decimal import Decimal
24-
from enum import Enum
25-
from collections import namedtuple
2624

2725
from py4j.compat import (
2826
long, basestring, unicode, bytearray2,
@@ -72,11 +70,12 @@
7270
ITERATOR_TYPE = "g"
7371
PYTHON_PROXY_TYPE = "f"
7472

75-
class JavaType(Enum):
76-
PRIMITIVE_INT = INTEGER_TYPE
77-
PRIMITIVE_LONG = LONG_TYPE
78-
79-
TypeInt = namedtuple('TypeInt', ['value', 'java_type'])
73+
class TypeHint:
74+
"""Enables users to provide a hint to the Python to Java converter specifying the accurate data type for a given value.
75+
Essential to enforce i.e. correct number type, like Long."""
76+
def __init__(self, value, java_type):
77+
self.value = value
78+
self.java_type = java_type
8079

8180
# Protocol
8281
END = "e"
@@ -281,12 +280,12 @@ def get_command_part(parameter, python_proxy_pool=None):
281280

282281
if parameter is None:
283282
command_part = NULL_TYPE
283+
elif isinstance(parameter, TypeHint):
284+
command_part = parameter.java_type + smart_decode(parameter.value)
284285
elif isinstance(parameter, bool):
285286
command_part = BOOLEAN_TYPE + smart_decode(parameter)
286287
elif isinstance(parameter, Decimal):
287288
command_part = DECIMAL_TYPE + smart_decode(parameter)
288-
elif isinstance(parameter, TypeInt):
289-
command_part = parameter.java_type.value + smart_decode(parameter.value)
290289
elif isinstance(parameter, int) and parameter <= JAVA_MAX_INT\
291290
and parameter >= JAVA_MIN_INT:
292291
command_part = INTEGER_TYPE + smart_decode(parameter)

py4j-python/src/py4j/tests/java_dir_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
# overloaded
3131
"method10",
3232
"method11",
33+
"method12",
3334
"getList",
3435
"getField1",
3536
"setField1",

py4j-python/src/py4j/tests/java_gateway_test.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
set_default_callback_accept_timeout, GatewayConnectionGuard,
3434
get_java_class)
3535
from py4j.protocol import (
36-
Py4JError, Py4JJavaError, Py4JNetworkError, decode_bytearray,
37-
encode_bytearray, escape_new_line, unescape_new_line, smart_decode)
36+
Py4JError, Py4JJavaError, Py4JNetworkError, TypeHint, LONG_TYPE,
37+
decode_bytearray, encode_bytearray, escape_new_line, unescape_new_line, smart_decode)
3838

3939

4040
SERVER_PORT = 25333
@@ -607,7 +607,7 @@ def internal():
607607
class TypeConversionTest(unittest.TestCase):
608608
def setUp(self):
609609
self.p = start_example_app_process()
610-
self.gateway = JavaGateway()
610+
self.gateway = JavaGateway(auto_convert=True)
611611

612612
def tearDown(self):
613613
safe_shutdown(self)
@@ -619,6 +619,8 @@ def testLongInt(self):
619619
self.assertEqual(4, ex.method7(2147483648))
620620
self.assertEqual(4, ex.method7(-2147483649))
621621
self.assertEqual(4, ex.method7(long(2147483648)))
622+
self.assertEqual(4, ex.method7(TypeHint(1, LONG_TYPE)))
623+
self.assertEqual(4, ex.method12({TypeHint(1, LONG_TYPE)}))
622624
self.assertEqual(long(4), ex.method8(3))
623625
self.assertEqual(4, ex.method8(3))
624626
self.assertEqual(long(4), ex.method8(long(3)))

py4j-web/advanced_topics.rst

+21
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,27 @@ Java methods slightly less efficient because in the worst case, Py4J needs to
726726
go through all registered converters for all parameters. This is why automatic
727727
conversion is disabled by default.
728728

729+
.. _explicit_conversion:
730+
731+
Explicit converting Python objects to Java primitives
732+
-----------------------------------------------------
733+
734+
Sometimes, especially when ``auto_convert=True`` it is difficult to enforce correct type
735+
passed from Python to Java. Then, ``TypeHint`` from ``py4j.protocol`` may be used.
736+
``java_type`` argument of constructor should be one of Java types defined in ``py4j.protocol``.
737+
738+
So if you have method in Java like:
739+
740+
.. code-block:: java
741+
742+
void method(HashSet<Long> longs) {}
743+
744+
Then you can pass arguments with correct type to this method with ``TypeHint``
745+
746+
::
747+
748+
>>> set_with_longs = { TypeHint(1, LONG_TYPE), TypeHint(2, LONG_TYPE) }
749+
>>> gateway.jvm.my.Class().method(set_with_longs)
729750

730751
.. _py4j_exceptions:
731752

0 commit comments

Comments
 (0)