commit 73018cba2753a83ab3fb22a769542eb671d1da0b
parent 4adcdc5d20042ac1cd1afa491db8ae82a919ea1e
Author: erai <erai@omiltem.net>
Date: Wed, 5 Jun 2024 11:03:29 -0400
basic sshd command io
Diffstat:
M | sshd.c | | | 315 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------- |
M | syscall.c | | | 2 | ++ |
2 files changed, 257 insertions(+), 60 deletions(-)
diff --git a/sshd.c b/sshd.c
@@ -311,7 +311,7 @@ enum {
SSH_CR_EXEC,
SSH_CR_WINCH,
SSH_CR_SIGNAL,
- SSH_CR_EXIT,
+ SSH_CR_EXIT_STATUS,
SSH_CR_EXIT_SIGNAL,
}
@@ -818,7 +818,7 @@ struct ssh_channel_open {
channel_type: ssh_str;
channel: int;
initial_window_size: int;
- max_packet_size: int;
+ maximum_packet_size: int;
}
decode_channel_open(p: *ssh_channel_open, ctx: *sshd_ctx) {
@@ -831,7 +831,7 @@ decode_channel_open(p: *ssh_channel_open, ctx: *sshd_ctx) {
decode_str(&p.channel_type, ctx);
decode_u32(&p.channel, ctx);
decode_u32(&p.initial_window_size, ctx);
- decode_u32(&p.max_packet_size, ctx);
+ decode_u32(&p.maximum_packet_size, ctx);
if ctx.index != ctx.framelen {
die("trailing data");
@@ -1066,8 +1066,8 @@ encode_channel_request(p: *ssh_channel_request, ctx: *sshd_ctx) {
tag = SSH_MSG_CHANNEL_REQUEST;
encode_u8(&tag, ctx);
encode_u32(&p.channel, ctx);
- if p.kind == SSH_CR_EXIT {
- set_str(&p.request, "exit");
+ if p.kind == SSH_CR_EXIT_STATUS {
+ set_str(&p.request, "exit-status");
encode_str(&p.request, ctx);
p.want_reply = 0;
encode_bool(&p.want_reply, ctx);
@@ -1360,10 +1360,15 @@ dosession(ctx: *sshd_ctx) {
die("not session");
}
+ ctx.stdout_window = cco.initial_window_size;
+ if cco.maximum_packet_size < 1024 {
+ die("max packet size too small");
+ }
+
scc.channel = 0;
scc.sender_channel = 0;
- scc.initial_window_size = 16 * 1024;
- scc.maximum_packet_size = 16 * 1024;
+ scc.initial_window_size = ctx.stdin_size;
+ scc.maximum_packet_size = ctx.bufsz - 64;
encode_channel_open_confirmation(&scc, ctx);
write_frame(ctx);
@@ -1378,19 +1383,26 @@ dodisconnect(ctx: *sshd_ctx) {
dowindow(ctx: *sshd_ctx) {
var cwa: ssh_channel_window_adjust;
decode_channel_window_adjust(&cwa, ctx);
- // increase window
+ ctx.stdout_window = ctx.stdout_window + cwa.window;
}
dodata(ctx: *sshd_ctx) {
var cd: ssh_channel_data;
decode_channel_data(&cd, ctx);
- // buffer data
- // send window
+ if ctx.stdin_eof {
+ return;
+ }
+ if cd.data.len > ctx.stdin_size - ctx.stdin_fill {
+ die("stdin overflow");
+ }
+ memcpy(&ctx.stdin_buf[ctx.stdin_fill], cd.data.s, cd.data.len);
+ ctx.stdin_fill = ctx.stdin_fill + cd.data.len;
}
doeof(ctx: *sshd_ctx) {
var ce: ssh_channel_eof;
decode_channel_eof(&ce, ctx);
+ ctx.stdin_eof = 1;
}
doclose(ctx: *sshd_ctx) {
@@ -1467,18 +1479,20 @@ ssh_spawn(ctx: *sshd_ctx, argv: **byte) {
ctx.child_stderr = stderr_read;
}
-dopty(cr: *ssh_cr_pty, ctx: *sshd_ctx) {
+dopty(cr: *ssh_cr_pty, ctx: *sshd_ctx): int {
// allocate pty
+ return 0;
}
-doshell(cr: *ssh_cr_shell, ctx: *sshd_ctx) {
+doshell(cr: *ssh_cr_shell, ctx: *sshd_ctx): int {
var argv: _argv4;
argv.arg0 = "/bin/sh";
argv.arg1 = 0:*byte;
ssh_spawn(ctx, &argv.arg0);
+ return 1;
}
-doexec(cr: *ssh_cr_exec, ctx: *sshd_ctx) {
+doexec(cr: *ssh_cr_exec, ctx: *sshd_ctx): int {
var argv: _argv4;
var cmd: *byte;
cmd = alloc(ctx.a, cr.command.len + 1);
@@ -1490,32 +1504,36 @@ doexec(cr: *ssh_cr_exec, ctx: *sshd_ctx) {
argv.arg3 = 0:*byte;
ssh_spawn(ctx, &argv.arg0);
free(ctx.a, cmd);
+ return 1;
}
-dowinch(cr: *ssh_cr_winch, ctx: *sshd_ctx) {
+dowinch(cr: *ssh_cr_winch, ctx: *sshd_ctx): int {
// window change
+ return 0;
}
-dosignal(cr: *ssh_cr_signal, ctx: *sshd_ctx) {
+dosignal(cr: *ssh_cr_signal, ctx: *sshd_ctx): int {
// signal
+ return 0;
}
dorequest(ctx: *sshd_ctx) {
var cr: ssh_channel_request;
var ss: ssh_channel_success;
var sf: ssh_channel_failure;
+ var ok: int;
decode_channel_request(&cr, ctx);
if cr.kind == SSH_CR_PTY {
- dopty(&cr.pty, ctx);
+ ok = dopty(&cr.pty, ctx);
} else if cr.kind == SSH_CR_SHELL {
- doshell(&cr.shell, ctx);
+ ok = doshell(&cr.shell, ctx);
} else if cr.kind == SSH_CR_EXEC {
- doexec(&cr.exec, ctx);
+ ok = doexec(&cr.exec, ctx);
} else if cr.kind == SSH_CR_WINCH {
- dowinch(&cr.winch, ctx);
+ ok = dowinch(&cr.winch, ctx);
} else if cr.kind == SSH_CR_SIGNAL {
- dosignal(&cr.signal, ctx);
+ ok = dosignal(&cr.signal, ctx);
} else {
die("unknown request");
if cr.want_reply {
@@ -1527,9 +1545,15 @@ dorequest(ctx: *sshd_ctx) {
}
if cr.want_reply {
- ss.channel = 0;
- encode_channel_success(&ss, ctx);
- write_frame(ctx);
+ if ok {
+ ss.channel = 0;
+ encode_channel_success(&ss, ctx);
+ write_frame(ctx);
+ } else {
+ sf.channel = 0;
+ encode_channel_failure(&sf, ctx);
+ write_frame(ctx);
+ }
}
}
@@ -1542,24 +1566,47 @@ struct pfd4 {
reset_pfd(pfd: *int, ctx: *sshd_ctx): int {
var n: int;
+ var events: int;
n = 0;
- pfd[n] = ctx.fd | (POLLIN << 32);
+ if (
+ // There is a window update to send
+ (ctx.stdin_window > 1024)
+ // There is stdout to send
+ || (ctx.stdout_fill > 0 && ctx.stdout_window > 0)
+ // There is stderr to send
+ || (ctx.stderr_fill > 0 && ctx.stdout_window > 0)
+ // Child exited and there's no more stdout/stderr
+ || (ctx.has_child
+ && ctx.stdout_fill == 0
+ && ctx.stderr_fill == 0
+ && ctx.child_stdout == -1
+ && ctx.child_stderr == -1
+ && ctx.child_pid == -1)
+ ) {
+ events = POLLIN | POLLOUT;
+ } else {
+ events = POLLIN;
+ }
+ pfd[n] = ctx.fd | (events << 32);
n = n + 1;
- if ctx.child_stdin >= 0 {
- pfd[n] = ctx.child_stdin | (POLLOUT << 32);
+ if ctx.child_stdin >= 0 && (ctx.stdin_fill > 0 || ctx.stdin_eof) {
+ events = POLLOUT;
+ pfd[n] = ctx.child_stdin | (events << 32);
n = n + 1;
}
- if ctx.child_stdout >= 0 {
- pfd[n] = ctx.child_stdout | (POLLIN << 32);
+ if ctx.child_stdout >= 0 && ctx.stdout_fill < ctx.stderr_size {
+ events = POLLIN;
+ pfd[n] = ctx.child_stdout | (events << 32);
n = n + 1;
}
- if ctx.child_stderr >= 0 {
- pfd[n] = ctx.child_stderr | (POLLIN << 32);
+ if ctx.child_stderr >= 0 && ctx.stderr_fill < ctx.stderr_size {
+ events = POLLIN;
+ pfd[n] = ctx.child_stderr | (events << 32);
n = n + 1;
}
@@ -1568,48 +1615,171 @@ reset_pfd(pfd: *int, ctx: *sshd_ctx): int {
poll_client(revents: int, ctx: *sshd_ctx) {
var tag: int;
+ var len: int;
+ var swa: ssh_channel_window_adjust;
+ var sd: ssh_channel_data;
+ var sed: ssh_channel_extended_data;
+ var sr: ssh_channel_request;
+ var sc: ssh_channel_close;
+ var se: ssh_channel_eof;
+ if revents & POLLOUT {
+ if ctx.stdin_window > 1024 {
+ swa.channel = 0;
+ swa.window = ctx.stdin_window;
+ encode_channel_window_adjust(&swa, ctx);
+ write_frame(ctx);
+ ctx.stdin_window = 0;
+ } else if ctx.stdout_fill > 0 && ctx.stdout_window > 0 {
+ len = ctx.stdout_fill;
+ if len > 1024 {
+ len = 1024;
+ }
- read_frame(ctx);
+ if len > ctx.stdout_window {
+ len = ctx.stdout_window;
+ }
+
+ sd.channel = 0;
+ sd.data.s = ctx.stdout_buf;
+ sd.data.len = len;
+
+ memcpy(ctx.stdout_buf, &ctx.stdout_buf[len], ctx.stdout_fill - len);
+ ctx.stdout_window = ctx.stdout_window - len;
+ ctx.stdout_fill = ctx.stdout_fill - len;
+
+ encode_channel_data(&sd, ctx);
+ write_frame(ctx);
+ } else if ctx.stderr_fill > 0 && ctx.stdout_window > 0 {
+ len = ctx.stderr_fill;
+ if len > 1024 {
+ len = 1024;
+ }
+
+ if len > ctx.stdout_window {
+ len = ctx.stdout_window;
+ }
+
+ sed.channel = 0;
+ sed.code = 1;
+ sed.data.s = ctx.stderr_buf;
+ sed.data.len = len;
+
+ memcpy(ctx.stderr_buf, &ctx.stderr_buf[len], ctx.stderr_fill - len);
+ ctx.stdout_window = ctx.stdout_window - len;
+ ctx.stderr_fill = ctx.stderr_fill - len;
+
+ encode_channel_extended_data(&sed, ctx);
+ write_frame(ctx);
+ } else {
+ se.channel = 0;
+ encode_channel_eof(&se, ctx);
+ write_frame(ctx);
+
+ sr.channel = 0;
+ sr.want_reply = 0;
+ sr.kind = SSH_CR_EXIT_STATUS;
+ sr.exit.status = ctx.exit_status;
+ encode_channel_request(&sr, ctx);
+ write_frame(ctx);
+
+ sc.channel = 0;
+ encode_channel_close(&sc, ctx);
+ write_frame(ctx);
+
+ loop {
+ read_frame(ctx);
+ if ctx.frame[0]:int == SSH_MSG_DISCONNECT {
+ break;
+ }
+ }
- tag = ctx.frame[0]:int;
- if tag == SSH_MSG_DISCONNECT {
- dodisconnect(ctx);
- } else if tag == SSH_MSG_KEXINIT {
- dokex(ctx);
- } else if tag == SSH_MSG_CHANNEL_WINDOW_ADJUST {
- dowindow(ctx);
- } else if tag == SSH_MSG_CHANNEL_DATA {
- dodata(ctx);
- } else if tag == SSH_MSG_CHANNEL_EOF {
- doeof(ctx);
- } else if tag == SSH_MSG_CHANNEL_CLOSE {
- doclose(ctx);
- } else if tag == SSH_MSG_CHANNEL_REQUEST {
- dorequest(ctx);
+ exit(0);
+ }
} else {
- die("invalid packet");
+ read_frame(ctx);
+
+ tag = ctx.frame[0]:int;
+ if tag == SSH_MSG_DISCONNECT {
+ dodisconnect(ctx);
+ } else if tag == SSH_MSG_KEXINIT {
+ dokex(ctx);
+ } else if tag == SSH_MSG_CHANNEL_WINDOW_ADJUST {
+ dowindow(ctx);
+ } else if tag == SSH_MSG_CHANNEL_DATA {
+ dodata(ctx);
+ } else if tag == SSH_MSG_CHANNEL_EOF {
+ doeof(ctx);
+ } else if tag == SSH_MSG_CHANNEL_CLOSE {
+ doclose(ctx);
+ } else if tag == SSH_MSG_CHANNEL_REQUEST {
+ dorequest(ctx);
+ } else {
+ die("invalid packet");
+ }
}
}
-poll_stdin(ctx: *sshd_ctx) {
- if revents & POLLERR {
+poll_stdin(revents: int, ctx: *sshd_ctx) {
+ var ret: int;
+ if ctx.stdin_fill == 0 && ctx.stdin_eof {
+ close(ctx.child_stdin);
+ ctx.child_stdin = -1;
+ return;
+ }
+ ret = write(ctx.child_stdin, ctx.stdin_buf, ctx.stdin_fill);
+ if ret == -EINTR {
+ return;
+ }
+ if ret == -EPIPE {
close(ctx.child_stdin);
ctx.child_stdin = -1;
+ return;
}
+ memcpy(ctx.stdin_buf, &ctx.stdin_buf[ret], ctx.stdin_fill - ret);
+ ctx.stdin_fill = ctx.stdin_fill - ret;
+ ctx.stdin_window = ctx.stdin_window + ret;
}
-poll_stdout(ctx: *sshd_ctx) {
- var a: byte;
- if read(ctx.child_stdout, &a, 1) == 0 {
+poll_stdout(revents: int, ctx: *sshd_ctx) {
+ var ret: int;
+ ret = read(ctx.child_stdout, &ctx.stdout_buf[ctx.stdout_fill], ctx.stdout_size - ctx.stdout_fill);
+ if ret == -EINTR {
+ return;
+ }
+ if ret == 0 {
close(ctx.child_stdout);
ctx.child_stdout = -1;
+ return;
}
+ ctx.stdout_fill = ctx.stdout_fill + ret;
}
-poll_stderr(ctx: *sshd_ctx) {
- if read(ctx.child_stderr, &a, 1) == 0 {
+poll_stderr(revents: int, ctx: *sshd_ctx) {
+ var ret: int;
+ ret = read(ctx.child_stderr, &ctx.stderr_buf[ctx.stderr_fill], ctx.stderr_size - ctx.stderr_fill);
+ if ret == -EINTR {
+ return;
+ }
+ if ret == 0 {
close(ctx.child_stderr);
ctx.child_stderr = -1;
+ return;
+ }
+ ctx.stderr_fill = ctx.stderr_fill + ret;
+}
+
+poll_exit(ctx: *sshd_ctx) {
+ var ret: int;
+ var status: int;
+ loop {
+ ret = wait(-1, &status, WNOHANG);
+ if ret <= 0 {
+ break;
+ }
+ if ret == ctx.child_pid {
+ ctx.child_pid = -1;
+ ctx.exit_status = status;
+ }
}
}
@@ -1624,6 +1794,8 @@ client_loop(ctx: *sshd_ctx) {
p = &_p.p0;
loop {
+ poll_exit(ctx);
+
n = reset_pfd(p, ctx);
if poll(p, n, -1) == -EINTR {
continue;
@@ -1724,6 +1896,19 @@ struct sshd_ctx {
child_stdin: int;
child_stdout: int;
child_stderr: int;
+ stdin_window: int;
+ stdout_window: int;
+ stdin_buf: *byte;
+ stdin_size: int;
+ stdin_fill: int;
+ stdin_eof: int;
+ stdout_buf: *byte;
+ stdout_fill: int;
+ stdout_size: int;
+ stderr_buf: *byte;
+ stderr_fill: int;
+ stderr_size: int;
+ exit_status: int;
}
format_key(d: **byte, dlen: *int, k: *byte, ctx: *sshd_ctx) {
@@ -1742,11 +1927,6 @@ format_key(d: **byte, dlen: *int, k: *byte, ctx: *sshd_ctx) {
}
dosigchld() {
- loop {
- if wait(-1, 0:*int, WNOHANG) < 0 {
- break;
- }
- }
}
_restorer();
@@ -1770,6 +1950,7 @@ main(argc: int, argv: **byte, envp: **byte) {
setup_alloc(&a);
signal(SIGCHLD, dosigchld);
+ signal(SIGPIPE, SIG_IGN:func());
bzero((&ctx):*byte, sizeof(ctx));
@@ -1779,6 +1960,7 @@ main(argc: int, argv: **byte, envp: **byte) {
ed25519_pub((&ctx.pub):*byte, (&ctx.priv):*byte);
+ ctx.child_pid = -1;
ctx.child_stdin = -1;
ctx.child_stdout = -1;
ctx.child_stderr = -1;
@@ -1790,6 +1972,13 @@ main(argc: int, argv: **byte, envp: **byte) {
ctx.ckex = alloc(ctx.a, 4096);
ctx.sver = "SSH-2.0-omiltem";
+ ctx.stdin_size = 64 * 1024;
+ ctx.stdin_buf = alloc(ctx.a, ctx.stdin_size);
+ ctx.stdout_size = 64 * 1024;
+ ctx.stdout_buf = alloc(ctx.a, ctx.stdout_size);
+ ctx.stderr_size = 64 * 1024;
+ ctx.stderr_buf = alloc(ctx.a, ctx.stderr_size);
+
ctx.username = "erai";
format_key(&ctx.hostkey, &ctx.hostkeylen, (&ctx.pub):*byte, &ctx);
format_key(&ctx.userkey, &ctx.userkeylen, (&ctx.userpub):*byte, &ctx);
@@ -1799,7 +1988,7 @@ main(argc: int, argv: **byte, envp: **byte) {
die("failed to open socket");
}
- port = 2222;
+ port = 22;
sa.fpa = AF_INET | ((port & 0xff) << 24) | (((port >> 8) & 0xff) << 16);
sa.pad = 0;
if bind(fd, (&sa):*byte, sizeof(sa)) != 0 {
@@ -1811,6 +2000,12 @@ main(argc: int, argv: **byte, envp: **byte) {
}
loop {
+ loop {
+ if wait(-1, 0:*int, WNOHANG) <= 0 {
+ break;
+ }
+ }
+
ctx.fd = accept(fd, 0:*byte, 0:*int);
if ctx.fd == -EINTR {
continue;
diff --git a/syscall.c b/syscall.c
@@ -8,6 +8,7 @@ enum {
O_DIRECTORY = 0x1000,
EINTR = 4,
+ EPIPE = 32,
AF_INET = 2,
SOCK_STREAM = 1,
@@ -25,6 +26,7 @@ enum {
SIG_IGN = 1,
SIGINT = 2,
+ SIGPIPE = 13,
SIGALRM = 14,
SIGCHLD = 17,
SIGWINCH = 28,