Skip to content

Commit c175ac9

Browse files
committed
Extend mqtt5_socks5_app to test IAM
1 parent cc74395 commit c175ac9

2 files changed

Lines changed: 318 additions & 12 deletions

File tree

bin/mqtt5_socks5_app/main.cpp

Lines changed: 228 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
*/
55

66
#include <aws/crt/Api.h>
7+
#include <aws/crt/auth/Credentials.h>
78
#include <aws/crt/crypto/Hash.h>
89
#include <aws/crt/http/HttpConnection.h>
910
#include <aws/crt/http/HttpRequestResponse.h>
1011
#include <aws/crt/io/Socks5ProxyOptions.h>
1112
#include <aws/crt/io/Uri.h>
1213

1314
#include <aws/crt/mqtt/Mqtt5Packets.h>
15+
#include <aws/iot/MqttCommon.h>
1416

1517
#include <aws/common/allocator.h>
1618
#include <aws/common/error.h>
@@ -27,6 +29,13 @@
2729
using namespace Aws::Crt;
2830
using namespace Aws::Crt::Mqtt5;
2931

32+
enum class CredentialsProviderSource {
33+
DefaultChain,
34+
Environment,
35+
Profile,
36+
Static
37+
};
38+
3039
struct app_ctx {
3140
Allocator *allocator = nullptr;
3241
Io::Uri uri;
@@ -48,6 +57,15 @@ struct app_ctx {
4857

4958
bool enable_tls = false;
5059
bool use_websocket = false;
60+
Aws::Crt::String region;
61+
CredentialsProviderSource credentials_source = CredentialsProviderSource::DefaultChain;
62+
Aws::Crt::String profile_name;
63+
Aws::Crt::String config_file;
64+
Aws::Crt::String credentials_file;
65+
Aws::Crt::String access_key_id;
66+
Aws::Crt::String secret_access_key;
67+
Aws::Crt::String session_token;
68+
bool port_overridden = false;
5169
};
5270

5371
static bool s_parse_proxy_uri(app_ctx &ctx, const char *proxy_arg) {
@@ -97,7 +115,16 @@ static void s_usage(int exit_code)
97115
fprintf(stderr, " --cert FILE: Client certificate file path (PEM format)\n");
98116
fprintf(stderr, " --key FILE: Private key file path (PEM format)\n");
99117
fprintf(stderr, " --ca-file FILE: CA certificate file path (PEM format)\n");
100-
fprintf(stderr, " --websocket: Use MQTT over WebSocket\n");
118+
fprintf(stderr, " --websocket: Use MQTT over WebSocket with SigV4 authentication\n");
119+
fprintf(stderr, " --region REGION: AWS Region for SigV4 signing when using WebSocket\n");
120+
fprintf(stderr,
121+
" --credential-source SOURCE: Credentials provider source (default-chain, environment, profile, static)\n");
122+
fprintf(stderr, " --profile NAME: AWS profile to use when credential source is profile\n");
123+
fprintf(stderr, " --config-file PATH: AWS config file override for profile credential source\n");
124+
fprintf(stderr, " --credentials-file PATH: AWS credentials file override for profile credential source\n");
125+
fprintf(stderr, " --access-key KEY: AWS access key for static credential source\n");
126+
fprintf(stderr, " --secret-key KEY: AWS secret access key for static credential source\n");
127+
fprintf(stderr, " --session-token TOKEN: AWS session token for static credential source (optional)\n");
101128
fprintf(stderr, " --verbose: Print detailed logging\n");
102129
fprintf(stderr, " --help: Display this message and exit\n");
103130
exit(exit_code);
@@ -111,6 +138,14 @@ static struct aws_cli_option s_long_options[] = {
111138
{"key", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'K'},
112139
{"ca-file", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'A'},
113140
{"websocket", AWS_CLI_OPTIONS_NO_ARGUMENT, NULL, 'W'},
141+
{"region", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'R'},
142+
{"credential-source", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'S'},
143+
{"profile", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'P'},
144+
{"config-file", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'F'},
145+
{"credentials-file", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'G'},
146+
{"access-key", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'I'},
147+
{"secret-key", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'J'},
148+
{"session-token", AWS_CLI_OPTIONS_REQUIRED_ARGUMENT, NULL, 'T'},
114149
{"verbose", AWS_CLI_OPTIONS_NO_ARGUMENT, NULL, 'v'},
115150
{"help", AWS_CLI_OPTIONS_NO_ARGUMENT, NULL, 'h'},
116151
{NULL, AWS_CLI_OPTIONS_NO_ARGUMENT, NULL, 0}, // Ensure proper termination
@@ -126,7 +161,7 @@ static void s_parse_options(int argc, char **argv, struct app_ctx &ctx)
126161
while (true)
127162
{
128163
int option_index = 0;
129-
int c = aws_cli_getopt_long(argc, argv, "b:p:x:C:K:A:Wvh", s_long_options, &option_index);
164+
int c = aws_cli_getopt_long(argc, argv, "b:p:x:C:K:A:WR:S:P:F:G:I:J:T:vh", s_long_options, &option_index);
130165
if (c == -1)
131166
{
132167
break;
@@ -139,6 +174,7 @@ static void s_parse_options(int argc, char **argv, struct app_ctx &ctx)
139174
break;
140175
case 'p':
141176
ctx.port = static_cast<uint16_t>(atoi(aws_cli_optarg));
177+
ctx.port_overridden = true;
142178
break;
143179
case 'x':
144180
if (!s_parse_proxy_uri(ctx, aws_cli_optarg)) {
@@ -157,6 +193,47 @@ static void s_parse_options(int argc, char **argv, struct app_ctx &ctx)
157193
case 'W':
158194
ctx.use_websocket = true;
159195
break;
196+
case 'R':
197+
ctx.region = aws_cli_optarg;
198+
break;
199+
case 'S': {
200+
Aws::Crt::String source = aws_cli_optarg;
201+
std::transform(source.begin(), source.end(), source.begin(), [](unsigned char ch) {
202+
return static_cast<char>(std::tolower(ch));
203+
});
204+
if (source == "default-chain") {
205+
ctx.credentials_source = CredentialsProviderSource::DefaultChain;
206+
} else if (source == "environment") {
207+
ctx.credentials_source = CredentialsProviderSource::Environment;
208+
} else if (source == "profile") {
209+
ctx.credentials_source = CredentialsProviderSource::Profile;
210+
} else if (source == "static") {
211+
ctx.credentials_source = CredentialsProviderSource::Static;
212+
} else {
213+
std::cerr << "Unknown credential source '" << aws_cli_optarg
214+
<< "'. Expected one of: default-chain, environment, profile, static." << std::endl;
215+
s_usage(1);
216+
}
217+
break;
218+
}
219+
case 'P':
220+
ctx.profile_name = aws_cli_optarg;
221+
break;
222+
case 'F':
223+
ctx.config_file = aws_cli_optarg;
224+
break;
225+
case 'G':
226+
ctx.credentials_file = aws_cli_optarg;
227+
break;
228+
case 'I':
229+
ctx.access_key_id = aws_cli_optarg;
230+
break;
231+
case 'J':
232+
ctx.secret_access_key = aws_cli_optarg;
233+
break;
234+
case 'T':
235+
ctx.session_token = aws_cli_optarg;
236+
break;
160237
case 'v':
161238
ctx.LogLevel = Aws::Crt::LogLevel::Trace;
162239
break;
@@ -169,6 +246,14 @@ static void s_parse_options(int argc, char **argv, struct app_ctx &ctx)
169246
}
170247
}
171248

249+
if (ctx.use_websocket)
250+
{
251+
ctx.enable_tls = true;
252+
if (!ctx.port_overridden && ctx.port == 1883 && !ctx.uri.GetPort())
253+
{
254+
ctx.port = 443;
255+
}
256+
}
172257
if (!ctx.enable_tls) {
173258
ctx.enable_tls = ctx.cacert || ctx.cert || ctx.key;
174259
}
@@ -197,14 +282,54 @@ static void s_parse_options(int argc, char **argv, struct app_ctx &ctx)
197282

198283
void PrintAppOptions(const app_ctx &ctx) {
199284
std::cout << "================= MQTT5 SOCKS5 APP OPTIONS =================" << std::endl;
200-
Aws::Crt::String hostNameStr = Aws::Crt::String((const char*)ctx.uri.GetHostName().ptr, ctx.uri.GetHostName().len);
285+
Aws::Crt::String hostNameStr =
286+
Aws::Crt::String((const char *)ctx.uri.GetHostName().ptr, ctx.uri.GetHostName().len);
201287
std::cout << "Broker Host: " << hostNameStr << std::endl;
202288
std::cout << "Broker Port: " << ctx.port << std::endl;
203289
std::cout << "TLS Enabled: " << (ctx.enable_tls ? "yes" : "no") << std::endl;
204290
if (ctx.cacert) std::cout << "CA Cert: " << ctx.cacert << std::endl;
205-
if (ctx.cert) std::cout << "Client Cert: " << ctx.cert << std::endl;
206-
if (ctx.key) std::cout << "Client Key: " << ctx.key << std::endl;
291+
if (ctx.cert && !ctx.use_websocket) std::cout << "Client Cert: " << ctx.cert << std::endl;
292+
if (ctx.key && !ctx.use_websocket) std::cout << "Client Key: " << ctx.key << std::endl;
207293
std::cout << "Connect Timeout (ms): " << ctx.connect_timeout << std::endl;
294+
if (ctx.use_websocket) {
295+
std::cout << "Using WebSocket: yes" << std::endl;
296+
if (!ctx.region.empty()) {
297+
std::cout << "AWS Region: " << ctx.region << std::endl;
298+
}
299+
std::cout << "Credentials Source: ";
300+
switch (ctx.credentials_source) {
301+
case CredentialsProviderSource::DefaultChain:
302+
std::cout << "default-chain";
303+
break;
304+
case CredentialsProviderSource::Environment:
305+
std::cout << "environment";
306+
break;
307+
case CredentialsProviderSource::Profile:
308+
std::cout << "profile";
309+
if (!ctx.profile_name.empty()) {
310+
std::cout << " (profile=" << ctx.profile_name << ")";
311+
}
312+
if (!ctx.config_file.empty()) {
313+
std::cout << " (config-file=" << ctx.config_file << ")";
314+
}
315+
if (!ctx.credentials_file.empty()) {
316+
std::cout << " (credentials-file=" << ctx.credentials_file << ")";
317+
}
318+
break;
319+
case CredentialsProviderSource::Static:
320+
std::cout << "static";
321+
if (!ctx.access_key_id.empty()) {
322+
std::cout << " (access-key provided)";
323+
}
324+
if (!ctx.session_token.empty()) {
325+
std::cout << " (session token provided)";
326+
}
327+
break;
328+
}
329+
std::cout << std::endl;
330+
} else {
331+
std::cout << "Using WebSocket: no" << std::endl;
332+
}
208333
if (ctx.use_proxy && ctx.socks5_proxy_options && !ctx.proxy_host_storage.empty()) {
209334
std::cout << "SOCKS5 Proxy Host: " << ctx.proxy_host_storage << std::endl;
210335
std::cout << "SOCKS5 Proxy Port: " << ctx.proxy_port << std::endl;
@@ -241,6 +366,30 @@ int main(int argc, char **argv)
241366
app_ctx.port = app_ctx.uri.GetPort();
242367
}
243368

369+
if (app_ctx.use_websocket)
370+
{
371+
if (app_ctx.region.empty())
372+
{
373+
std::cerr << "[ERROR] --region must be specified when using --websocket for SigV4 authentication."
374+
<< std::endl;
375+
return 1;
376+
}
377+
378+
if (app_ctx.credentials_source == CredentialsProviderSource::Static &&
379+
(app_ctx.access_key_id.empty() || app_ctx.secret_access_key.empty()))
380+
{
381+
std::cerr << "[ERROR] Static credentials require both --access-key and --secret-key when using WebSocket."
382+
<< std::endl;
383+
return 1;
384+
}
385+
386+
if (app_ctx.cert || app_ctx.key)
387+
{
388+
std::cout << "[INFO] Client certificate and key are ignored when using WebSocket SigV4 authentication."
389+
<< std::endl;
390+
}
391+
}
392+
244393
/**********************************************************
245394
* LOGGING
246395
**********************************************************/
@@ -255,7 +404,7 @@ int main(int argc, char **argv)
255404
apiHandle.InitializeLogging(app_ctx.LogLevel, stderr);
256405
}
257406

258-
bool useTls = app_ctx.enable_tls;
407+
bool useTls = app_ctx.use_websocket || app_ctx.enable_tls;
259408

260409
auto hostName = app_ctx.uri.GetHostName();
261410

@@ -268,7 +417,18 @@ int main(int argc, char **argv)
268417
Io::TlsConnectionOptions tlsConnectionOptions;
269418
if (useTls)
270419
{
271-
if (app_ctx.cert && app_ctx.key)
420+
if (app_ctx.use_websocket)
421+
{
422+
std::cout << "MQTT5: Configuring TLS for WebSocket connection with SigV4 authentication." << std::endl;
423+
tlsCtxOptions = Io::TlsContextOptions::InitDefaultClient();
424+
if (!tlsCtxOptions)
425+
{
426+
std::cout << "Failed to create TLS options for WebSocket with error "
427+
<< aws_error_debug_str(tlsCtxOptions.LastError()) << std::endl;
428+
exit(1);
429+
}
430+
}
431+
else if (app_ctx.cert && app_ctx.key)
272432
{
273433
std::cout << "MQTT5: Configuring TLS with cert " << app_ctx.cert << " and key " << app_ctx.key
274434
<< std::endl;
@@ -375,15 +535,71 @@ int main(int argc, char **argv)
375535
}
376536

377537
// Configure WebSocket if requested
538+
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> websocketCredentialsProvider;
378539
if (app_ctx.use_websocket) {
379540
std::cout << "**********************************************************" << std::endl;
380-
std::cout << "MQTT5: Configuring WebSocket...." << std::endl;
541+
std::cout << "MQTT5: Configuring WebSocket with SigV4 authentication...." << std::endl;
381542
std::cout << "**********************************************************" << std::endl;
382-
// Use the default handshake transform (no-op)
543+
544+
Aws::Iot::WebsocketConfig websocketConfig(app_ctx.region, &clientBootstrap, app_ctx.allocator);
545+
546+
switch (app_ctx.credentials_source) {
547+
case CredentialsProviderSource::DefaultChain:
548+
// Already handled by the default constructor using the client bootstrap.
549+
break;
550+
case CredentialsProviderSource::Environment:
551+
websocketCredentialsProvider =
552+
Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderEnvironment(app_ctx.allocator);
553+
break;
554+
case CredentialsProviderSource::Profile: {
555+
Aws::Crt::Auth::CredentialsProviderProfileConfig profileConfig;
556+
profileConfig.Bootstrap = &clientBootstrap;
557+
if (!app_ctx.profile_name.empty()) {
558+
profileConfig.ProfileNameOverride = aws_byte_cursor_from_c_str(app_ctx.profile_name.c_str());
559+
}
560+
if (!app_ctx.config_file.empty()) {
561+
profileConfig.ConfigFileNameOverride =
562+
aws_byte_cursor_from_c_str(app_ctx.config_file.c_str());
563+
}
564+
if (!app_ctx.credentials_file.empty()) {
565+
profileConfig.CredentialsFileNameOverride =
566+
aws_byte_cursor_from_c_str(app_ctx.credentials_file.c_str());
567+
}
568+
websocketCredentialsProvider =
569+
Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderProfile(profileConfig, app_ctx.allocator);
570+
break;
571+
}
572+
case CredentialsProviderSource::Static: {
573+
Aws::Crt::Auth::CredentialsProviderStaticConfig staticConfig;
574+
staticConfig.AccessKeyId = aws_byte_cursor_from_c_str(app_ctx.access_key_id.c_str());
575+
staticConfig.SecretAccessKey = aws_byte_cursor_from_c_str(app_ctx.secret_access_key.c_str());
576+
if (!app_ctx.session_token.empty()) {
577+
staticConfig.SessionToken = aws_byte_cursor_from_c_str(app_ctx.session_token.c_str());
578+
}
579+
websocketCredentialsProvider =
580+
Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderStatic(staticConfig, app_ctx.allocator);
581+
break;
582+
}
583+
}
584+
585+
if (app_ctx.credentials_source != CredentialsProviderSource::DefaultChain) {
586+
if (!websocketCredentialsProvider) {
587+
std::cerr << "[ERROR] Failed to create credentials provider for WebSocket connection." << std::endl;
588+
return 1;
589+
}
590+
websocketConfig = Aws::Iot::WebsocketConfig(app_ctx.region, websocketCredentialsProvider, app_ctx.allocator);
591+
}
592+
593+
auto websocketConfigShared = std::make_shared<Aws::Iot::WebsocketConfig>(websocketConfig);
594+
383595
mqtt5OptionsBuilder.WithWebsocketHandshakeTransformCallback(
384-
[](std::shared_ptr<Aws::Crt::Http::HttpRequest> req,
385-
const Aws::Crt::Mqtt5::OnWebSocketHandshakeInterceptComplete &onComplete) {
386-
onComplete(req, AWS_ERROR_SUCCESS);
596+
[websocketConfigShared](std::shared_ptr<Aws::Crt::Http::HttpRequest> req,
597+
const Aws::Crt::Mqtt5::OnWebSocketHandshakeInterceptComplete &onComplete) {
598+
auto signingComplete = [onComplete](
599+
const std::shared_ptr<Aws::Crt::Http::HttpRequest> &signedRequest,
600+
int errorCode) { onComplete(signedRequest, errorCode); };
601+
auto signerConfig = websocketConfigShared->CreateSigningConfigCb();
602+
websocketConfigShared->Signer->SignRequest(req, *signerConfig, signingComplete);
387603
});
388604
}
389605

0 commit comments

Comments
 (0)