1111#include <pb_encode.h>
1212#include <squareup/subzero/internal.pb.h>
1313
14- static size_t get_serialized_message_size (const InternalCommandRequest * const cmd ) {
14+ // Helper which returns the size of a buffer that would be needed to hold the serialized version of the given
15+ // protobuf structure, assuming that pb_encode_delimited() serialization will be used.
16+ static size_t get_serialized_message_size (const pb_field_t fields [], const void * const proto ) {
1517 pb_ostream_t stream = PB_OSTREAM_SIZING ;
16- if (!pb_encode_delimited (& stream , InternalCommandRequest_fields , cmd )) {
18+ if (!pb_encode_delimited (& stream , fields , proto )) {
1719 ERROR ("%s: pb_encode_delimited() failed: %s" , __func__ , PB_GET_ERROR (& stream ));
1820 return 0 ;
1921 }
2022 return stream .bytes_written ;
2123}
2224
25+ // Serializes the given protobuf structure to the given buffer of the given size, using pb_encode_delimited().
26+ // Returns true on success or false on failure.
27+ // Caller brings their own buffer memory, this function does not allocate.
28+ static bool
29+ serialize_to_buf (void * const buffer , const size_t buffer_size , const pb_field_t fields [], const void * const proto ) {
30+ pb_ostream_t ostream = pb_ostream_from_buffer (buffer , buffer_size );
31+ if (!pb_encode_delimited (& ostream , fields , proto )) {
32+ ERROR ("%s: pb_encode_delimited() failed: %s" , __func__ , PB_GET_ERROR (& ostream ));
33+ return false;
34+ }
35+ return true;
36+ }
37+
38+ // Deserializes the given buffer into the given protobuf structure, using pb_decode_delimited().
39+ // Returns true on success or false on failure.
40+ // Caller brings their own protobuf structure, this function does not allocate.
41+ static bool
42+ deserialize_from_buf (const void * const buffer , const size_t buffer_size , const pb_field_t fields [], void * const proto ) {
43+ pb_istream_t istream = pb_istream_from_buffer (buffer , buffer_size );
44+ if (!pb_decode_delimited (& istream , fields , proto )) {
45+ ERROR ("%s: pb_decode_delimited() failed: %s" , __func__ , PB_GET_ERROR (& istream ));
46+ return false;
47+ }
48+ return true;
49+ }
50+
2351int verify_rpc_oversized_message_rejected (void ) {
2452 int result = 0 ;
2553 uint8_t * serialized_request = NULL ;
2654 uint8_t * serialized_response = NULL ;
2755
56+ // Construct an initial InternalCommandRequest which holds an InitWallet command
57+ // with a maximum-allowed-length random_bytes field.
2858 InternalCommandRequest cmd = InternalCommandRequest_init_default ;
2959 cmd .version = VERSION ;
3060 cmd .wallet_id = 1 ; // dummy value
@@ -35,12 +65,15 @@ int verify_rpc_oversized_message_rejected(void) {
3565 cmd .command .InitWallet .random_bytes .size = MASTER_SEED_SIZE ;
3666 random_buffer (cmd .command .InitWallet .random_bytes .bytes , MASTER_SEED_SIZE );
3767
38- size_t serialized_size = get_serialized_message_size (& cmd );
68+ // Compute the size of the serialized struct.
69+ size_t serialized_size = get_serialized_message_size (InternalCommandRequest_fields , & cmd );
3970 if (serialized_size == 0 ) {
71+ ERROR ("%s: error computing serialized request size" , __func__ );
4072 result = -1 ;
4173 goto out ;
4274 }
4375
76+ // Allocate a buffer to hold the serialized struct.
4477 // Note that we allocate 1 extra byte because we'll be extending the message.
4578 serialized_request = (uint8_t * ) calloc (1 , serialized_size + 1 );
4679 if (NULL == serialized_request ) {
@@ -49,30 +82,30 @@ int verify_rpc_oversized_message_rejected(void) {
4982 goto out ;
5083 }
5184
52- pb_ostream_t ostream = pb_ostream_from_buffer ( serialized_request , serialized_size );
53- if (!pb_encode_delimited ( & ostream , InternalCommandRequest_fields , & cmd )) {
54- ERROR ("%s: pb_encode_delimited () failed: %s " , __func__ , PB_GET_ERROR ( & ostream ) );
85+ // Serialize the struct to a byte array.
86+ if (!serialize_to_buf ( serialized_request , serialized_size , InternalCommandRequest_fields , & cmd )) {
87+ ERROR ("%s: serialize_to_buf () failed" , __func__ );
5588 result = -1 ;
5689 goto out ;
5790 }
5891
5992 // Helper macro used to check our assumptions in the gnarly protobuf mangling code below
60- #define ASSERT_BYTE_EQUALS (buf , idx , expected_val ) \
61- do { \
62- const uint8_t* buf_ = (buf); \
63- const size_t idx_ = (idx); \
64- const uint8_t expected_val_ = (expected_val); \
65- const uint8_t actual_val_ = buf_[idx_]; \
66- if (actual_val_ != expected_val_) { \
67- ERROR( \
68- "%s: buf[%zu] contains unexpected value: %hhu, expected: %hhu", \
69- __func__, \
70- idx_, \
71- actual_val_, \
72- expected_val_); \
73- result = -1; \
74- goto out; \
75- } \
93+ #define ASSERT_BYTE_EQUALS (buf , idx , expected_val ) \
94+ do { \
95+ const uint8_t* buf_ = (buf); \
96+ const size_t idx_ = (idx); \
97+ const uint8_t expected_val_ = (expected_val); \
98+ const uint8_t actual_val_ = buf_[idx_]; \
99+ if (actual_val_ != expected_val_) { \
100+ ERROR( \
101+ "%s: buf[%zu] contains an unexpected value: %hhu, expected: %hhu", \
102+ __func__, \
103+ idx_, \
104+ actual_val_, \
105+ expected_val_); \
106+ result = -1; \
107+ goto out; \
108+ } \
76109 } while (0)
77110
78111 // Corrupt the message by making the random_bytes field 1 byte longer than the max allowed size.
@@ -87,23 +120,23 @@ int verify_rpc_oversized_message_rejected(void) {
87120 // length will actually take more than 1 byte, shifting everything after
88121 // it by a byte.
89122 // *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
90- // serialized_request[1] - field id (1 << 3) + tag (0) for field 1 (version). Should equal 0x08 .
123+ // serialized_request[1] - field id (1 << 3) + tag (0) for field 1 (version). Should equal 8 .
91124 // serialized_request[2..3] - varint-encoded value for field 1. Leave this alone, it's the
92125 // contents of the 'version' field (210 at the time of writing). If
93126 // version ever exceeds 16383, this will start taking up an extra byte
94127 // and shift everything after it by a byte.
95- // serialized_request[4] - field id (2 << 3) + tag (0) for field 2 (wallet_id). Should equal 0x10 .
96- // serialized_request[5] - varint-encoded value for field 2. Leave this allone , it's the dummy
97- // 'wallet' field which we set to 1 above. Should equal 0x01 .
128+ // serialized_request[4] - field id (2 << 3) + tag (0) for field 2 (wallet_id). Should equal 16 .
129+ // serialized_request[5] - varint-encoded value for field 2. Leave this alone , it's the dummy
130+ // 'wallet' field which we set to 1 above. Should equal 1 .
98131 // serialized_request[6] - field id (5 << 3) + tag (2, for 'LEN') for field 5 (command.InitWallet).
99- // Should equal 0x2a .
132+ // Should equal 42 .
100133 // serialized_request[7] - varint-encoded LEN of the InitWalletRequest submessage.
101- // Should equal 0x42 (decimal 66) .
134+ // Should equal 66 .
102135 // *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
103136 // serialized_request[8] - field id (1 << 3) + tag (2, for 'LEN') for field 1 of sub-message.
104- // Should equal 0x0a .
137+ // Should equal 10 .
105138 // serialized_request[9] - varint-encoded LEN of field 1 (random_bytes) of sub-message.
106- // Should equal 0x40 (decimal 64) .
139+ // Should equal 64 .
107140 // *** NOTE: WE NEED TO INCREMENT THIS BY 1. ***
108141 // serialized_request[10..73] - the contents of the random_bytes field. Should be 64 bytes in length.
109142 // serialized_request[74] - doesn't exist in the original message. We add an extra data byte here.
@@ -130,17 +163,21 @@ int verify_rpc_oversized_message_rejected(void) {
130163 serialized_request [7 ]++ ; // increment LEN byte for top-level field 5
131164 serialized_request [9 ]++ ; // increment LEN byte for nested field 1
132165 serialized_request [serialized_size ] = 0xaa ; // set the last byte to an arbitrary value
166+ serialized_size ++ ; // increment serialized_size since we added a byte
133167
134- pb_istream_t istream = pb_istream_from_buffer (serialized_request , serialized_size + 1 );
168+ // Create a stream which will read from the corrupted serialized buffer.
169+ pb_istream_t istream = pb_istream_from_buffer (serialized_request , serialized_size );
170+
171+ // Allocate a buffer for the serialized response.
135172 const size_t response_buffer_size = 2048 ; // 2048 bytes should be more than enough
136173 serialized_response = (uint8_t * ) calloc (1 , response_buffer_size );
137174 if (NULL == serialized_response ) {
138175 ERROR ("%s: calloc(1, %zu) failed" , __func__ , response_buffer_size );
139176 result = -1 ;
140177 goto out ;
141178 }
142- pb_ostream_t ostream2 = pb_ostream_from_buffer ( serialized_response , response_buffer_size );
143- ERROR ( "(next line is expected to show red text...)" );
179+ // Create a stream which will write to the response buffer.
180+ pb_ostream_t ostream = pb_ostream_from_buffer ( serialized_response , response_buffer_size );
144181
145182 // Now that we have a serialized buffer, try to pass it to handle_incoming_message().
146183 // This should fail because the InitWallet.random_bytes field has a length of 65 bytes,
@@ -149,23 +186,24 @@ int verify_rpc_oversized_message_rejected(void) {
149186 // NOTE: when building for nCipher, there are command hooks that would reject the command
150187 // because it's missing the tickets for key use authorization. But this doesn't matter for
151188 // this test case, because the protobuf parsing happens before that and fails first.
152- handle_incoming_message (& istream , & ostream2 );
153- const size_t actual_response_size = ostream2 .bytes_written ;
189+ ERROR ("(next line is expected to show red text...)" );
190+ handle_incoming_message (& istream , & ostream );
191+
192+ // Extract the response structure from the serialized_response buffer. It should be an error.
193+ const size_t actual_response_size = ostream .bytes_written ;
154194 if (actual_response_size == 0 ) {
155- ERROR ("%s: no response received from handle_incoming_message(): %s" , __func__ , PB_GET_ERROR (& ostream2 ));
195+ ERROR ("%s: no response received from handle_incoming_message(): %s" , __func__ , PB_GET_ERROR (& ostream ));
156196 result = -1 ;
157197 goto out ;
158198 }
159- pb_istream_t istream2 = pb_istream_from_buffer (serialized_response , actual_response_size );
160- InternalCommandResponse response ; // note: no need to init, pb_decode_delimited() does it
161- if (!pb_decode_delimited (& istream2 , InternalCommandResponse_fields , & response )) {
162- ERROR (
163- "%s: pb_decode_delimited(..., InternalCommandResponse_fields, ...) failed: %s" ,
164- __func__ ,
165- PB_GET_ERROR (& istream2 ));
199+ InternalCommandResponse response ; // note: no need to init, deserialize_from_buf() does it via pb_decode_delimited().
200+ if (!deserialize_from_buf (serialized_response , actual_response_size , InternalCommandResponse_fields , & response )) {
201+ ERROR ("%s: deserialize_from_buf() failed" , __func__ );
166202 result = -1 ;
167203 goto out ;
168204 }
205+
206+ // Check that the response contains an error.
169207 if (response .which_response != InternalCommandResponse_Error_tag ) {
170208 ERROR (
171209 "%s: wrong response tag: %d, expected: %d" ,
@@ -175,6 +213,7 @@ int verify_rpc_oversized_message_rejected(void) {
175213 result = -1 ;
176214 goto out ;
177215 }
216+ // Check that the error response contains the expected error code.
178217 if (response .response .Error .code != Result_COMMAND_DECODE_FAILED ) {
179218 ERROR (
180219 "%s: wrong response error code: %d, expected: %d" ,
@@ -184,11 +223,13 @@ int verify_rpc_oversized_message_rejected(void) {
184223 result = -1 ;
185224 goto out ;
186225 }
226+ // Check that the error response contains some message.
187227 if (!response .response .Error .has_message ) {
188228 ERROR ("%s: error response does not contain a 'message' field" , __func__ );
189229 result = -1 ;
190230 goto out ;
191231 }
232+ // Check that the error response contains the expected message.
192233 if (0 != strcmp ("Decode Input failed: bytes overflow" , response .response .Error .message )) {
193234 ERROR ("%s: error response contains unexpected message: %s" , __func__ , response .response .Error .message );
194235 result = -1 ;
@@ -198,5 +239,8 @@ int verify_rpc_oversized_message_rejected(void) {
198239out :
199240 free (serialized_request );
200241 free (serialized_response );
242+ if (result == 0 ) {
243+ INFO ("%s: ok" , __func__ );
244+ }
201245 return result ;
202246}
0 commit comments