diff --git a/modules/lua/lua_request.c b/modules/lua/lua_request.c
index ba63584..c1ba74a 100644
--- a/modules/lua/lua_request.c
+++ b/modules/lua/lua_request.c
@@ -2193,23 +2193,38 @@ static int lua_websocket_greet(lua_State *L)
     return 0;
 }
 
-static apr_status_t lua_websocket_readbytes(conn_rec* c, char* buffer,
-                                            apr_off_t len)
+/* Read exactly len bytes from the connection's input filters into buffer.
+ * brigade is caller-supplied scratch space; it is left empty on return.
+ * Returns APR_SUCCESS once all len bytes have been copied, APR_INCOMPLETE
+ * if a read yields no data before then, or the error that ended the read.
+ */
+static apr_status_t lua_websocket_readbytes(conn_rec* c,
+                                            apr_bucket_brigade *brigade,
+                                            char* buffer, apr_off_t len)
 {
-    apr_bucket_brigade *brigade = apr_brigade_create(c->pool, c->bucket_alloc);
-    apr_status_t rv;
-    rv = ap_get_brigade(c->input_filters, brigade, AP_MODE_READBYTES,
-                        APR_BLOCK_READ, len);
-    if (rv == APR_SUCCESS) {
-        if (!APR_BRIGADE_EMPTY(brigade)) {
-            apr_bucket* bucket = APR_BRIGADE_FIRST(brigade);
-            const char* data = NULL;
-            apr_size_t data_length = 0;
-            rv = apr_bucket_read(bucket, &data, &data_length, APR_BLOCK_READ);
-            if (rv == APR_SUCCESS) {
-                memcpy(buffer, data, len);
-            }
-            apr_bucket_delete(bucket);
+    apr_size_t delivered;
+    apr_status_t rv = APR_SUCCESS;
+
+    /* AP_MODE_READBYTES hands back *up to* len bytes and a frame may be
+     * split across TCP segments or TLS records, so keep reading until the
+     * buffer is full. */
+    while ((rv == APR_SUCCESS) && (len > 0)) {
+        rv = ap_get_brigade(c->input_filters, brigade, AP_MODE_READBYTES,
+                            APR_BLOCK_READ, len);
+        if (rv == APR_SUCCESS) {
+            delivered = (apr_size_t)len;
+            rv = apr_brigade_flatten(brigade, buffer, &delivered);
+            /* Drop the copied data before the next ap_get_brigade(). */
+            apr_brigade_cleanup(brigade);
+            if (rv == APR_SUCCESS) {
+                if (delivered == 0) {
+                    /* Nothing arrived (e.g. end of stream), so the buffer
+                     * can never be filled. */
+                    rv = APR_INCOMPLETE;
+                }
+                buffer += delivered;
+                len -= (apr_off_t)delivered;
+            }
         }
     }
     apr_brigade_cleanup(brigade);
@@ -2239,35 +2254,29 @@ static int lua_websocket_peek(lua_State *L)
 
 static int lua_websocket_read(lua_State *L)
 {
-    apr_socket_t *sock;
     apr_status_t rv;
     int do_read = 1;
     int n = 0;
-    apr_size_t len = 1;
     apr_size_t plen = 0;
     unsigned short payload_short = 0;
     apr_uint64_t payload_long = 0;
     unsigned char *mask_bytes;
     char byte;
-    int plaintext;
-
-
+    apr_bucket_brigade *brigade;
+    conn_rec* c;
+
     request_rec *r = ap_lua_check_request_rec(L, 1);
-    plaintext = ap_lua_ssl_is_https(r->connection) ? 0 : 1;
+    c = r->connection;
 
 
     mask_bytes = apr_pcalloc(r->pool, 4);
-    sock = ap_get_conn_socket(r->connection);
+
+    brigade = apr_brigade_create(r->pool, c->bucket_alloc);
 
     while (do_read) {
         do_read = 0;
         /* Get opcode and FIN bit */
-        if (plaintext) {
-            rv = apr_socket_recv(sock, &byte, &len);
-        }
-        else {
-            rv = lua_websocket_readbytes(r->connection, &byte, 1);
-        }
+        rv = lua_websocket_readbytes(c, brigade, &byte, 1);
         if (rv == APR_SUCCESS) {
             unsigned char ubyte, fin, opcode, mask, payload;
             ubyte = (unsigned char)byte;
@@ -2277,12 +2286,7 @@ static int lua_websocket_read(lua_State *L)
             opcode = ubyte & 0xf;
 
             /* Get the payload length and mask bit */
-            if (plaintext) {
-                rv = apr_socket_recv(sock, &byte, &len);
-            }
-            else {
-                rv = lua_websocket_readbytes(r->connection, &byte, 1);
-            }
+            rv = lua_websocket_readbytes(c, brigade, &byte, 1);
             if (rv == APR_SUCCESS) {
                 ubyte = (unsigned char)byte;
                 /* Mask is the first bit */
@@ -2293,40 +2297,25 @@ static int lua_websocket_read(lua_State *L)
 
                 /* Extended payload? */
                 if (payload == 126) {
-                    len = 2;
-                    if (plaintext) {
-                        /* XXX: apr_socket_recv does not receive len bits, only up to len bits! */
-                        rv = apr_socket_recv(sock, (char*) &payload_short, &len);
-                    }
-                    else {
-                        rv = lua_websocket_readbytes(r->connection,
-                                 (char*) &payload_short, 2);
-                    }
-                    payload_short = ntohs(payload_short);
+                    rv = lua_websocket_readbytes(c, brigade,
+                                                 (char*) &payload_short, 2);
 
-                    if (rv == APR_SUCCESS) {
-                        plen = payload_short;
-                    }
-                    else {
+                    if (rv != APR_SUCCESS) {
                         return 0;
                     }
+
+                    plen = ntohs(payload_short);
                 }
                 /* Super duper extended payload? */
                 if (payload == 127) {
-                    len = 8;
-                    if (plaintext) {
-                        rv = apr_socket_recv(sock, (char*) &payload_long, &len);
-                    }
-                    else {
-                        rv = lua_websocket_readbytes(r->connection,
-                                 (char*) &payload_long, 8);
-                    }
-                    if (rv == APR_SUCCESS) {
-                        plen = ap_ntoh64(&payload_long);
-                    }
-                    else {
+                    rv = lua_websocket_readbytes(c, brigade,
+                                                 (char*) &payload_long, 8);
+
+                    if (rv != APR_SUCCESS) {
                         return 0;
                     }
+
+                    plen = ap_ntoh64(&payload_long);
                 }
                 ap_log_rerror(APLOG_MARK, APLOG_DEBUG, 0, r, APLOGNO(03210)
                         "Websocket: Reading %" APR_SIZE_T_FMT " (%s) bytes, masking is %s. %s",
@@ -2335,46 +2324,27 @@ static int lua_websocket_read(lua_State *L)
                         mask ? "on" : "off",
                         fin ? "This is a final frame" : "more to follow");
                 if (mask) {
-                    len = 4;
-                    if (plaintext) {
-                        rv = apr_socket_recv(sock, (char*) mask_bytes, &len);
-                    }
-                    else {
-                        rv = lua_websocket_readbytes(r->connection,
-                                 (char*) mask_bytes, 4);
-                    }
+                    rv = lua_websocket_readbytes(c, brigade,
+                                                 (char*) mask_bytes, 4);
+
                     if (rv != APR_SUCCESS) {
                         return 0;
                     }
                 }
                 if (plen < (HUGE_STRING_LEN*1024) && plen > 0) {
                     apr_size_t remaining = plen;
-                    apr_size_t received;
-                    apr_off_t at = 0;
                     char *buffer = apr_palloc(r->pool, plen+1);
                     buffer[plen] = 0;
 
-                    if (plaintext) {
-                        while (remaining > 0) {
-                            received = remaining;
-                            rv = apr_socket_recv(sock, buffer+at, &received);
-                            if (received > 0 ) {
-                                remaining -= received;
-                                at += received;
-                            }
-                        }
-                        ap_log_rerror(APLOG_MARK, APLOG_TRACE1, 0, r,
-                                "Websocket: Frame contained %" APR_OFF_T_FMT " bytes, pushed to Lua stack",
-                                at);
-                    }
-                    else {
-                        rv = lua_websocket_readbytes(r->connection, buffer,
-                                remaining);
-                        ap_log_rerror(APLOG_MARK, APLOG_TRACE1, 0, r,
-                                "Websocket: SSL Frame contained %" APR_SIZE_T_FMT " bytes, "\
-                                "pushed to Lua stack",
-                                remaining);
+                    rv = lua_websocket_readbytes(c, brigade, buffer, remaining);
+
+                    if (rv != APR_SUCCESS) {
+                        return 0;
                     }
+
+                    ap_log_rerror(APLOG_MARK, APLOG_TRACE1, 0, r,
+                                  "Websocket: Frame contained %" APR_SIZE_T_FMT \
+                                  " bytes, pushed to Lua stack", remaining);
                     if (mask) {
                         for (n = 0; n < plen; n++) {
                             buffer[n] ^= mask_bytes[n%4];
@@ -2386,14 +2356,29 @@ static int lua_websocket_read(lua_State *L)
                     return 2;
                 }
 
-
                 /* Decide if we need to react to the opcode or not */
                 if (opcode == 0x09) { /* ping */
                     char frame[2];
-                    plen = 2;
+                    apr_bucket *b;
+
                     frame[0] = 0x8A;
                     frame[1] = 0;
-                    apr_socket_send(sock, frame, &plen); /* Pong! */
+
+                    /* Pong! */
+                    b = apr_bucket_transient_create(frame, 2, c->bucket_alloc);
+                    APR_BRIGADE_INSERT_TAIL(brigade, b);
+                    /* Flush, or the pong may sit in the output filters
+                     * while we block reading the client's next frame. */
+                    b = apr_bucket_flush_create(c->bucket_alloc);
+                    APR_BRIGADE_INSERT_TAIL(brigade, b);
+
+                    rv = ap_pass_brigade(c->output_filters, brigade);
+                    apr_brigade_cleanup(brigade);
+
+                    if (rv != APR_SUCCESS) {
+                        return 0;
+                    }
+
                     do_read = 1;
                 }
             }