@@ -1251,6 +1251,35 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
1251
1251
self .assertIn ("Setting initial namespace not supported by the DBR version" ,
1252
1252
str (cm .exception ))
1253
1253
1254
+ @patch ("databricks.sql.thrift_backend.TCLIService.Client" )
1255
+ @patch ("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response" )
1256
+ def test_execute_command_sets_complex_type_fields_correctly (self , mock_handle_execute_response ,
1257
+ tcli_service_class ):
1258
+ tcli_service_instance = tcli_service_class .return_value
1259
+ # Iterate through each possible combination of native types (True, False and unset)
1260
+ for (complex , timestamp , decimals ) in itertools .product (
1261
+ [True , False , None ], [True , False , None ], [True , False , None ]):
1262
+ complex_arg_types = {}
1263
+ if complex is not None :
1264
+ complex_arg_types ["_use_arrow_native_complex_types" ] = complex
1265
+ if timestamp is not None :
1266
+ complex_arg_types ["_use_arrow_native_timestamps" ] = timestamp
1267
+ if decimals is not None :
1268
+ complex_arg_types ["_use_arrow_native_decimals" ] = decimals
1269
+
1270
+ thrift_backend = ThriftBackend ("foobar" , 443 , "path" , [], ** complex_arg_types )
1271
+ thrift_backend .execute_command (Mock (), Mock (), 100 , 100 , Mock ())
1272
+
1273
+ t_execute_statement_req = tcli_service_instance .ExecuteStatement .call_args [0 ][0 ]
1274
+ # If the value is unset, the native type should default to True
1275
+ self .assertEqual (t_execute_statement_req .useArrowNativeTypes .timestampAsArrow ,
1276
+ complex_arg_types .get ("_use_arrow_native_timestamps" , True ))
1277
+ self .assertEqual (t_execute_statement_req .useArrowNativeTypes .decimalAsArrow ,
1278
+ complex_arg_types .get ("_use_arrow_native_decimals" , True ))
1279
+ self .assertEqual (t_execute_statement_req .useArrowNativeTypes .complexTypesAsArrow ,
1280
+ complex_arg_types .get ("_use_arrow_native_complex_types" , True ))
1281
+ self .assertFalse (t_execute_statement_req .useArrowNativeTypes .intervalTypesAsArrow )
1282
+
1254
1283
1255
1284
if __name__ == '__main__' :
1256
1285
unittest .main ()
0 commit comments