Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/py/src/ext/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class BinaryProtocol : public ProtocolBase<BinaryProtocol> {
return encodeValue(value, parsedspec.type, parsedspec.typeargs);
}

void writeUuid(char* value) {
void writeUuid(const char* value) {
writeBuffer(value, 16);
}

Expand Down
2 changes: 1 addition & 1 deletion lib/py/src/ext/compact.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class CompactProtocol : public ProtocolBase<CompactProtocol> {

void writeFieldStop() { writeByte(0); }

void writeUuid(char* value) {
void writeUuid(const char* value) {
writeBuffer(value, 16);
}

Expand Down
2 changes: 1 addition & 1 deletion lib/py/src/ext/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ProtocolBase {
return true;
}

bool writeBuffer(char* data, size_t len);
bool writeBuffer(const char* data, size_t len);

void writeByte(uint8_t val) { writeBuffer(reinterpret_cast<char*>(&val), 1); }

Expand Down
20 changes: 17 additions & 3 deletions lib/py/src/ext/protocol.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ PyObject* ProtocolBase<Impl>::getEncodedValue() {
}

template <typename Impl>
inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
inline bool ProtocolBase<Impl>::writeBuffer(const char* data, size_t size) {
if (!PycStringIO) {
PycString_IMPORT;
}
Expand Down Expand Up @@ -169,7 +169,7 @@ PyObject* ProtocolBase<Impl>::getEncodedValue() {
}

template <typename Impl>
inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
inline bool ProtocolBase<Impl>::writeBuffer(const char* data, size_t size) {
size_t need = size + output_->pos;
if (output_->buf.capacity() < need) {
try {
Expand Down Expand Up @@ -456,18 +456,32 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type

case T_STRING: {
ScopedPyObject nval;
Py_ssize_t len;

if (PyUnicode_Check(value)) {
#if PY_VERSION_HEX >= 0x03030000
const char* str = PyUnicode_AsUTF8AndSize(value, &len);
if (!str) {
return false;
}
if (!detail::check_ssize_t_32(len)) {
return false;
}

impl()->writeI32(static_cast<int32_t>(len));
return writeBuffer(str, static_cast<size_t>(len));
#else
nval.reset(PyUnicode_AsUTF8String(value));
if (!nval) {
return false;
}
#endif
} else {
Py_INCREF(value);
nval.reset(value);
}

Py_ssize_t len = PyBytes_Size(nval.get());
len = PyBytes_Size(nval.get());
if (!detail::check_ssize_t_32(len)) {
return false;
}
Expand Down
39 changes: 39 additions & 0 deletions lib/py/test/thrift_TBinaryProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import uuid

import _import_local_thrift # noqa
from thrift.Thrift import TApplicationException
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory
Comment on lines +25 to +27
from thrift.protocol.TProtocol import TProtocolException
from thrift.transport import TTransport

Expand Down Expand Up @@ -167,6 +169,16 @@ def testField(type, data):
protocol.readStructEnd()


APPLICATION_EXCEPTION_TYPEARGS = [
TApplicationException,
(
None,
(1, 11, "message", "UTF8", None),
(2, 8, "type", None, None),
),
]
Comment on lines +172 to +179


def testMessage(data, strict=True):
message = {}
message['name'] = data[0]
Expand Down Expand Up @@ -196,6 +208,13 @@ def testMessage(data, strict=True):

class TestTBinaryProtocol(unittest.TestCase):

def setUp(self):
try:
from thrift.protocol import fastbinary # noqa: F401
self._has_fastbinary = True
except ImportError:
self._has_fastbinary = False

def test_TBinaryProtocol_write_read(self):
try:
testNaked('Byte', 123)
Expand Down Expand Up @@ -280,6 +299,26 @@ def test_TBinaryProtocol_write_read(self):
print("Assertion fail")
raise e

def test_accelerated_utf8_roundtrip_on_application_exception(self):
if not self._has_fastbinary:
self.skipTest("C extension not available")

original = TApplicationException(
type=TApplicationException.PROTOCOL_ERROR,
message=("snowman-\u2603-rocket-\U0001F680-" * 32),
)

otrans = TTransport.TMemoryBuffer()
oproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(otrans)
oproto.trans.write(oproto._fast_encode(original, APPLICATION_EXCEPTION_TYPEARGS))

itrans = TTransport.TMemoryBuffer(otrans.getvalue())
iproto = TBinaryProtocolAcceleratedFactory(fallback=False).getProtocol(itrans)
decoded = iproto._fast_decode(None, iproto, APPLICATION_EXCEPTION_TYPEARGS)

self.assertEqual(decoded.message, original.message)
self.assertEqual(decoded.type, original.type)

def test_TBinaryProtocol_no_strict_write_read(self):
TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
test_data = [("short message name", TMessageType['T_CALL'], 0),
Expand Down
Loading