diff --git a/lib/net.js b/lib/net.js index 4f659ba06c3..f513a9a8029 100644 --- a/lib/net.js +++ b/lib/net.js @@ -17,6 +17,8 @@ var accept = process.accept; var close = process.close; var shutdown = process.shutdown; var read = process.read; +var recvMsg = process.recvMsg; +var sendFD = process.sendFD; var write = process.write; var toRead = process.toRead; var setNoDelay = process.setNoDelay; @@ -28,7 +30,6 @@ var EINPROGRESS = process.EINPROGRESS; var ENOENT = process.ENOENT; var END_OF_FILE = 42; - function Socket (peerInfo) { process.EventEmitter.call(); @@ -37,7 +38,7 @@ function Socket (peerInfo) { // Allocated on demand. self.recvBuffer = null; - self.readWatcher = new IOWatcher() + self.readWatcher = new IOWatcher(); self.readWatcher.callback = function () { // If this is the first recv (recvBuffer doesn't exist) or we've used up // most of the recvBuffer, allocate a new one. @@ -47,10 +48,23 @@ function Socket (peerInfo) { } debug('recvBuffer.used ' + self.recvBuffer.used); - var bytesRead = read(self.fd, + var bytesRead; + var receivedFd = -1; + + if (self.type == "unix") { + var msgInfo = recvMsg(self.fd, + self.recvBuffer, + self.recvBuffer.used, + self.recvBuffer.length - self.recvBuffer.used); + bytesRead = msgInfo[0]; + receivedFd = msgInfo[1]; + debug('receivedFd ' + receivedFd); + } else { + bytesRead = read(self.fd, self.recvBuffer, self.recvBuffer.used, self.recvBuffer.length - self.recvBuffer.used); + } debug('bytesRead ' + bytesRead + '\n'); if (bytesRead == 0) { @@ -59,10 +73,15 @@ function Socket (peerInfo) { self.emit('eof'); if (!self.writable) self.forceClose(); } else { - var slice = self.recvBuffer.slice(self.recvBuffer.used, - self.recvBuffer.used + bytesRead); - self.recvBuffer.used += bytesRead; - self.emit('data', slice); + if (receivedFd == -1) { + var slice = self.recvBuffer.slice(self.recvBuffer.used, + self.recvBuffer.used + bytesRead); + self.recvBuffer.used += bytesRead; + self.emit('data', slice); + } else { + self.recvBuffer.used += bytesRead; + self.emit('fd', receivedFd); + } } }; self.readable = false; @@ -70,10 +89,16 @@ function Socket (peerInfo) { self.sendQueue = []; // queue of buffers that need to be written to socket // XXX use link list? self.sendQueueSize = 0; // in bytes, not to be confused with sendQueue.length! + self.sendMessageQueueSize = 0; // number of messages remaining to be sent self._doFlush = function () { - assert(self.sendQueueSize > 0); + /* Socket becomes writeable on connect() but don't flush if there's + * nothing actually to write */ + if ((self.sendQueueSize == 0) && (self.sendMessageQueueSize == 0)) { + return; + } if (self.flush()) { assert(self.sendQueueSize == 0); + assert(self.sendMessageQueueSize == 0); self.emit("drain"); } }; @@ -134,6 +159,7 @@ Socket.prototype._allocateSendBuffer = function () { var b = new process.Buffer(1024); b.used = 0; b.sent = 0; + b.isMsg = false; this.sendQueue.push(b); return b; }; @@ -154,6 +180,7 @@ Socket.prototype._sendString = function (data, encoding) { } } // if we didn't find one, take the last + // TODO what if this isn't empty but encoding == fd ? if (!buffer) { buffer = self._sendQueueLast(); // if last buffer is used up @@ -166,13 +193,22 @@ Socket.prototype._sendString = function (data, encoding) { var charsWritten; var bytesWritten; - if (encoding.toLowerCase() == 'utf8') { + // The special encoding "fd" means that data is an integer FD and we want + // to pass the FD on the socket with sendmsg() + if (encoding == "fd") { + buffer.isFd = true; + // TODO is this OK -- does it guarantee that the fd is the only thing in the buffer? + charsWritten = buffer.asciiWrite(data, buffer.used, buffer.length - buffer.used); + bytesWritten = charsWritten; + } else if (encoding.toLowerCase() == 'utf8') { + buffer.isFd = false; charsWritten = buffer.utf8Write(data, buffer.used, buffer.length - buffer.used); bytesWritten = process.Buffer.utf8Length(data.slice(0, charsWritten)); } else { // ascii + buffer.isFd = false; charsWritten = buffer.asciiWrite(data, buffer.used, buffer.length - buffer.used); @@ -180,7 +216,11 @@ Socket.prototype._sendString = function (data, encoding) { } buffer.used += bytesWritten; - self.sendQueueSize += bytesWritten; + if (buffer.isFd) { + self.sendMessageQueueSize += 1; + } else { + self.sendQueueSize += bytesWritten; + } debug('charsWritten ' + charsWritten); debug('buffer.used ' + buffer.used); @@ -235,6 +275,27 @@ Socket.prototype.send = function (data, encoding) { return this.flush(); }; +// Sends a file descriptor over a unix socket +Socket.prototype.sendFD = function(socketToPass) { + var self = this; + + if (!self.writable) throw new Error('Socket is not writable'); + + if (self._sendQueueLast == END_OF_FILE) { + throw new Error('socket.close() called already; cannot write.'); + } + + if (self.type != "unix") { + throw new Error('FD passing only available on unix sockets'); + } + + if (! socketToPass instanceof Socket) { + throw new Error('Provided arg is not a socket'); + } + + return self.send(socketToPass.fd.toString(), "fd"); +}; + // Flushes the write buffer out. Emits "drain" if the buffer is empty. Socket.prototype.flush = function () { @@ -253,23 +314,35 @@ Socket.prototype.flush = function () { if (b.sent == b.used) { // this can be improved - save the buffer for later? - self.sendQueue.shift() + self.sendQueue.shift(); continue; } - bytesWritten = write(self.fd, - b, - b.sent, - b.used - b.sent); + var fdToSend = null; + if (b.isFd) { + fdToSend = parseInt(b.asciiSlice(b.sent, b.used - b.sent)); + bytesWritten = sendFD(self.fd, fdToSend); + } else { + bytesWritten = write(self.fd, + b, + b.sent, + b.used - b.sent); + } if (bytesWritten === null) { // could not flush everything self.writeWatcher.start(); assert(self.sendQueueSize > 0); return false; } - b.sent += bytesWritten; - self.sendQueueSize -= bytesWritten; - debug('bytes sent: ' + b.sent); + if (b.isFd) { + b.sent = b.used; + self.sendMessageQueueSize -= 1; + debug('sent fd: ' + fdToSend); + } else { + b.sent += bytesWritten; + self.sendQueueSize -= bytesWritten; + debug('bytes sent: ' + b.sent); + } } self.writeWatcher.stop(); return true; @@ -299,11 +372,11 @@ Socket.prototype.connect = function () { // socketError() if there isn't an error, we're connected. AFAIK this a // platform independent way determining when a non-blocking connection // is established, but I have only seen it documented in the Linux - // Manual Page connect(2) under the error code EINPROGRESS. + // Manual Page connect(2) under the error code EINPROGRESS. self.writeWatcher.set(self.fd, false, true); self.writeWatcher.start(); self.writeWatcher.callback = function () { - var errno = socketError(self.fd); + var errno = socketError(self.fd); if (errno == 0) { // connection established self.readWatcher.start(); @@ -340,7 +413,6 @@ Socket.prototype.address = function () { return getsockname(this.fd); }; - Socket.prototype.setNoDelay = function (v) { if (this.type == 'tcp') setNoDelay(this.fd, v); }; @@ -393,6 +465,7 @@ function Server (listener) { debug('accept: ' + JSON.stringify(peerInfo)); if (!peerInfo) return; var peer = new Socket(peerInfo); + peer.type = self.type; peer.server = self; self.emit('connection', peer); } diff --git a/src/node_net2.cc b/src/node_net2.cc index 90a05002b57..4e52cddb4b3 100644 --- a/src/node_net2.cc +++ b/src/node_net2.cc @@ -41,6 +41,9 @@ static Persistent remote_address_symbol; static Persistent remote_port_symbol; static Persistent address_symbol; static Persistent port_symbol; +static Persistent type_symbol; +static Persistent tcp_symbol; +static Persistent unix_symbol; #define FD_ARG(a) \ if (!(a)->IsInt32()) { \ @@ -181,7 +184,7 @@ static inline Handle ParseAddressArgs(Handle first, strcpy(un.sun_path, *path); addr = (struct sockaddr*)&un; - addrlen = path.length() + sizeof(un.sun_family); + addrlen = path.length() + sizeof(un.sun_family) + 1; } else { // TCP or UDP @@ -326,7 +329,6 @@ static Handle Connect(const Arguments& args) { return Undefined(); } - static Handle GetSockName(const Arguments& args) { HandleScope scope; @@ -358,6 +360,37 @@ static Handle GetSockName(const Arguments& args) { return scope.Close(info); } +static Handle GetPeerName(const Arguments& args) { + HandleScope scope; + + FD_ARG(args[0]) + + struct sockaddr_storage address_storage; + socklen_t len = sizeof(struct sockaddr_storage); + + int r = getpeername(fd, (struct sockaddr *) &address_storage, &len); + + if (r < 0) { + return ThrowException(ErrnoException(errno, "getsockname")); + } + + Local info = Object::New(); + + if (address_storage.ss_family == AF_INET6) { + struct sockaddr_in6 *a = (struct sockaddr_in6*)&address_storage; + + char ip[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, &(a->sin6_addr), ip, INET6_ADDRSTRLEN); + + int port = ntohs(a->sin6_port); + + info->Set(remote_address_symbol, String::New(ip)); + info->Set(remote_port_symbol, Integer::New(port)); + } + + return scope.Close(info); +} + static Handle Listen(const Arguments& args) { HandleScope scope; @@ -484,6 +517,78 @@ static Handle Read(const Arguments& args) { return scope.Close(Integer::New(bytes_read)); } +// bytesRead, receivedFd = t.recvMsg(fd, buffer, offset, length) +static Handle RecvMsg(const Arguments& args) { + HandleScope scope; + + if (args.Length() < 4) { + return ThrowException(Exception::TypeError( + String::New("Takes 4 parameters"))); + } + + FD_ARG(args[0]) + + if (!IsBuffer(args[1])) { + return ThrowException(Exception::TypeError( + String::New("Second argument should be a buffer"))); + } + + struct buffer * buffer = BufferUnwrap(args[1]); + + size_t off = args[2]->Int32Value(); + if (buffer_p(buffer, off) == NULL) { + return ThrowException(Exception::Error( + String::New("Offset is out of bounds"))); + } + + size_t len = args[3]->Int32Value(); + if (buffer_remaining(buffer, off) < len) { + return ThrowException(Exception::Error( + String::New("Length is extends beyond buffer"))); + } + + struct iovec iov[1]; + struct msghdr msg; + int received_fd; + char control_msg[CMSG_SPACE(sizeof(received_fd))]; + struct cmsghdr *cmsg; + + // TODO: zero out control_msg ? + + iov[0].iov_base = buffer_p(buffer, off); + iov[0].iov_len = buffer_remaining(buffer, off); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = NULL; + msg.msg_namelen = 0; + /* Set up to receive a descriptor even if one isn't in the message */ + msg.msg_control = (void *) control_msg; + msg.msg_controllen = CMSG_LEN(sizeof(received_fd)); + + ssize_t bytes_read = recvmsg(fd, &msg, 0); + + if (bytes_read < 0) { + if (errno == EAGAIN || errno == EINTR) return Null(); + return ThrowException(ErrnoException(errno, "recvMsg")); + } + + // Return array of [bytesRead, fd] with fd == -1 if there was no FD + Local a = Array::New(2); + a->Set(Integer::New(0), Integer::New(bytes_read)); + + cmsg = CMSG_FIRSTHDR(&msg); + if (cmsg->cmsg_type == SCM_RIGHTS) { + received_fd = *(int *) CMSG_DATA(cmsg); + } + else { + received_fd = -1; + } + + a->Set(Integer::New(1), Integer::New(received_fd)); + return scope.Close(a); +} + + // var bytesWritten = t.write(fd, buffer, offset, length); // returns null on EAGAIN or EINTR, raises an exception on all other errors static Handle Write(const Arguments& args) { @@ -527,6 +632,60 @@ static Handle Write(const Arguments& args) { return scope.Close(Integer::New(written)); } +// var bytesWritten = t.sendFD(self.fd) +// returns null on EAGAIN or EINTR, raises an exception on all other errors +static Handle SendFD(const Arguments& args) { + HandleScope scope; + + if (args.Length() < 2) { + return ThrowException(Exception::TypeError( + String::New("Takes 2 parameters"))); + } + + FD_ARG(args[0]) + + // TODO: make sure fd is a unix domain socket? + + if (!args[1]->IsInt32()) { + return ThrowException(Exception::TypeError( + String::New("FD to send is not an integer"))); + } + + int fd_to_send = args[1]->Int32Value(); + + struct msghdr msg; + struct iovec iov[1]; + char control_msg[CMSG_SPACE(sizeof(fd_to_send))]; + struct cmsghdr *cmsg; + char *dummy = "d"; // Need to send at least a byte of data in the message + + iov[0].iov_base = dummy; + iov[0].iov_len = 1; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = NULL; + msg.msg_namelen = 0; + msg.msg_flags = 0; + msg.msg_control = (void *) control_msg; + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(fd_to_send)); + *(int*) CMSG_DATA(cmsg) = fd_to_send; + msg.msg_controllen = cmsg->cmsg_len; + + ssize_t written = sendmsg(fd, &msg, 0); + + if (written < 0) { + if (errno == EAGAIN || errno == EINTR) return Null(); + return ThrowException(ErrnoException(errno, "sendmsg")); + } + + /* Note that the FD isn't explicitly closed here, this + * happens in the JS */ + + return scope.Close(Integer::New(written)); +} // Probably only works for Linux TCP sockets? // Returns the amount of data on the read queue. @@ -744,6 +903,9 @@ void InitNet2(Handle target) { NODE_SET_METHOD(target, "write", Write); NODE_SET_METHOD(target, "read", Read); + NODE_SET_METHOD(target, "sendFD", SendFD); + NODE_SET_METHOD(target, "recvMsg", RecvMsg); + NODE_SET_METHOD(target, "socket", Socket); NODE_SET_METHOD(target, "close", Close); NODE_SET_METHOD(target, "shutdown", Shutdown); @@ -757,7 +919,8 @@ void InitNet2(Handle target) { NODE_SET_METHOD(target, "socketError", SocketError); NODE_SET_METHOD(target, "toRead", ToRead); NODE_SET_METHOD(target, "setNoDelay", SetNoDelay); - NODE_SET_METHOD(target, "getsocksame", GetSockName); + NODE_SET_METHOD(target, "getsockname", GetSockName); + NODE_SET_METHOD(target, "getpeername", GetPeerName); NODE_SET_METHOD(target, "getaddrinfo", GetAddrInfo); NODE_SET_METHOD(target, "needsLookup", NeedsLookup); diff --git a/test/mjsunit/fixtures/net-fd-passing-receiver.js b/test/mjsunit/fixtures/net-fd-passing-receiver.js new file mode 100644 index 00000000000..be29a287233 --- /dev/null +++ b/test/mjsunit/fixtures/net-fd-passing-receiver.js @@ -0,0 +1,37 @@ +process.mixin(require("../common")); +net = require("net"); + +process.Buffer.prototype.toString = function () { + return this.utf8Slice(0, this.length); +}; + + +path = process.ARGV[2]; +greeting = process.ARGV[3]; + +receiver = net.createServer(function(socket) { + socket.addListener("fd", function(fd) { + var peerInfo = process.getpeername(fd); + peerInfo.fd = fd; + var passedSocket = new net.Socket(peerInfo); + + passedSocket.addListener("eof", function() { + passedSocket.close(); + }); + + passedSocket.addListener("data", function(data) { + passedSocket.send("[echo] " + data); + }); + passedSocket.addListener("close", function() { + receiver.close(); + }); + passedSocket.send("[greeting] " + greeting); + }); +}); + +/* To signal the test runne we're up and listening */ +receiver.addListener("listening", function() { + print("ready"); +}); + +receiver.listen(path); diff --git a/test/mjsunit/test-net-fd-passing.js b/test/mjsunit/test-net-fd-passing.js new file mode 100644 index 00000000000..a3bd0edc20a --- /dev/null +++ b/test/mjsunit/test-net-fd-passing.js @@ -0,0 +1,67 @@ +process.mixin(require("./common")); +net = require("net"); + +process.Buffer.prototype.toString = function () { + return this.utf8Slice(0, this.length); +}; + +var tests_run = 0; + +function fdPassingTest(path, port) { + var greeting = "howdy"; + var message = "beep toot"; + var expectedData = ["[greeting] " + greeting, "[echo] " + message]; + + puts(fixturesDir); + var receiverArgs = [fixturesDir + "/net-fd-passing-receiver.js", path, greeting]; + var receiver = process.createChildProcess(process.ARGV[0], receiverArgs); + + var initializeSender = function() { + var fdHighway = new net.Socket(); + fdHighway.connect(path); + + var sender = net.createServer(function(socket) { + fdHighway.sendFD(socket); + socket.flush(); + socket.forceClose(); // want to close() the fd, not shutdown() + }); + + sender.addListener("listening", function() { + var client = net.createConnection(port); + + client.addListener("connect", function() { + client.send(message); + }); + + client.addListener("data", function(data) { + assert.equal(expectedData[0], data); + if (expectedData.length > 1) { + expectedData.shift(); + } + else { + // time to shut down + fdHighway.close(); + sender.close(); + client.forceClose(); + } + }); + }); + + tests_run += 1; + sender.listen(port); + }; + + receiver.addListener("output", function(data) { + var initialized = false; + if ((! initialized) && (data == "ready")) { + initializeSender(); + initialized = true; + } + }); +} + +fdPassingTest("/tmp/passing-socket-test", 31075); + +process.addListener("exit", function () { + assert.equal(1, tests_run); +});