module websocketd.server;

import std.socket;
import std.experimental.logger;

import websocketd.request;
import websocketd.frame;

/// Identifier the server assigns to each accepted connection (taken from an incrementing counter).
alias PeerID = size_t;

/**
 * Per-connection state: the client socket, whether the opening handshake has
 * completed, and the buffered fragments of a message that is not finished yet.
 */
class WebSocketState {
    Socket socket;
    bool handshaken;
    /// Fragments of an in-progress fragmented message; the first one carries the message opcode.
    Frame[] frames = [];
    public immutable PeerID id;
    public immutable Address address;
    /// Request path taken from the handshake request.
    public string path;

    @disable this();

    /**
     * Params:
     *   id = identifier assigned by the server
     *   socket = the accepted client socket
     */
    this(PeerID id, Socket socket) {
        this.socket = socket;
        this.handshaken = false;
        this.id = id;
        this.address = cast(immutable Address)(socket.remoteAddress);
    }

    /**
     * Try to complete the opening handshake from the raw bytes of an HTTP
     * upgrade request.
     *
     * If `Request.parse` reports the request as not done, or it has no
     * `Sec-WebSocket-Key` header, nothing is sent and `handshaken` stays false.
     * Otherwise a `101` response carrying the computed `Sec-WebSocket-Accept`
     * value is sent, `path` is recorded and `handshaken` becomes true.
     */
    public void performHandshake(ubyte[] message) {
        import std.base64 : Base64;
        import std.digest.sha : sha1Of;
        import std.conv : to;

        assert(!handshaken);
        // GUID fixed by RFC 6455 for deriving Sec-WebSocket-Accept from the client key.
        enum MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
        enum KEY = "Sec-WebSocket-Key";
        Request request = Request.parse(message);
        if (!request.done || KEY !in request.headers)
            return;
        this.path = request.path;
        string accept = Base64.encode(sha1Of(request.headers[KEY] ~ MAGIC)).to!string;
        assert(socket.isAlive);
        socket.send(
            "HTTP/1.1 101 Switching Protocols\r\n" ~ "Upgrade: websocket\r\n" ~ "Connection: Upgrade\r\n"
            ~ "Sec-WebSocket-Accept: " ~ accept ~ "\r\n\r\n");
        handshaken = true;
    }
}

/**
 * Single-threaded, select()-driven WebSocket server.
 *
 * Subclasses implement the `on*` callbacks; `run` accepts clients, performs
 * the opening handshake and dispatches decoded frames to those callbacks.
 */
abstract class WebSocketServer {

    private WebSocketState[PeerID] sockets;
    private Socket listener;
    private size_t maxConnections;

    /// Called once a peer has successfully completed the opening handshake.
    abstract void onOpen(PeerID s, string path);
    /// Called with a complete (possibly reassembled) text message.
    abstract void onTextMessage(PeerID s, string text);
    /// Called with a complete (possibly reassembled) binary message.
    abstract void onBinaryMessage(PeerID s, ubyte[] o);
    /// Called after a peer has been forgotten and its socket closed.
    abstract void onClose(PeerID s);

    private static PeerID counter = 0;

    this() {
        listener = new TcpSocket();
    }

    /// Register a freshly accepted socket, or close it if the connection limit is reached.
    private void add(Socket socket) {
        if (sockets.length >= maxConnections) {
            infof("Maximum number of connections reached (%d)", maxConnections);
            socket.close();
            return;
        }
        auto s = new WebSocketState(counter++, socket);
        infof("Accepting connection from %s (id=%s)", socket.remoteAddress, s.id);
        sockets[s.id] = s;
    }

    /// Forget a peer, close its socket if still open, then notify `onClose`.
    private void remove(WebSocketState socket) {
        sockets.remove(socket.id);
        infof("Closing connection with client id %s", socket.id);
        if (socket.socket.isAlive)
            socket.socket.close();
        onClose(socket.id);
    }

    /// Drop a client that violated the framing protocol (RFC 6455 §5.4: "fail the connection").
    private void failConnection(WebSocketState socket, string reason) {
        warningf("Protocol error from client %s: %s", socket.id, reason);
        remove(socket);
    }

    /**
     * Process one chunk of bytes received from `socket`.
     *
     * Before the handshake the bytes are handed to `performHandshake`; afterwards
     * they are fed to the frame parser and every frame it yields is dispatched.
     */
    private void handle(WebSocketState socket, ubyte[] message) {
        import std.conv : to;
        import std.algorithm : swap;

        // Key passed to websocketd.frame's `parse`; NOTE(review): it appears to select
        // per-peer parser state inside that module — confirm there.
        string processId = typeof(this).stringof ~ socket.id.to!string;
        if (socket.handshaken) {
            Frame prevFrame = processId.parse(message);
            Frame newFrame;
            // Keep pulling frames out of the parser until it returns the same
            // frame twice in a row.
            do {
                // A Close frame or a protocol error removes the peer mid-loop; keep
                // draining the parser but stop dispatching to a connection that is gone.
                if (socket.id in sockets)
                    handleFrame(socket, prevFrame);
                newFrame = processId.parse([]);
                swap(newFrame, prevFrame);
            }
            while (newFrame != prevFrame);
        } else {
            socket.performHandshake(message);
            // Only announce the peer once the handshake really succeeded; an
            // incomplete or invalid request leaves it pending.
            if (socket.handshaken) {
                infof("Handshake with %s done (path=%s)", socket.id, socket.path);
                onOpen(socket.id, socket.path);
            }
        }
    }

    /// Dispatch a fully parsed frame by opcode; frames not yet `done` are ignored.
    private void handleFrame(WebSocketState socket, Frame frame) {
        tracef("From client %s received frame: done=%s; fin=%s; op=%s; length=%d",
            socket.id, frame.done, frame.fin, frame.op, frame.length);
        if (!frame.done)
            return;
        final switch (frame.op) {
        case Op.CONT:
            return handleCont(socket, frame);
        case Op.TEXT:
            return handleText(socket, frame);
        case Op.BINARY:
            return handleBinary(socket, frame);
        case Op.CLOSE:
            return handleClose(socket, frame);
        case Op.PING:
            return handlePing(socket, frame);
        case Op.PONG:
            return handlePong(socket, frame);
        }
    }

    /**
     * Continuation frame: buffer it, or on FIN reassemble the whole message and
     * deliver it according to the opcode of its first fragment.
     */
    private void handleCont(WebSocketState socket, Frame frame) {
        // Frame order is controlled by the client, so a stray continuation is a
        // protocol error rather than an internal invariant.
        if (socket.frames.length == 0)
            return failConnection(socket, "continuation frame without a fragmented message in progress");
        if (!frame.fin) {
            socket.frames ~= frame;
            return;
        }
        Op originalOp = socket.frames[0].op;
        ubyte[] data = [];
        foreach (fragment; socket.frames)
            data ~= fragment.data;
        data ~= frame.data;
        socket.frames = [];
        if (originalOp == Op.TEXT)
            onTextMessage(socket.id, cast(string)data);
        else if (originalOp == Op.BINARY)
            onBinaryMessage(socket.id, data);
    }

    /// Text frame: deliver it directly if final, otherwise start buffering fragments.
    private void handleText(WebSocketState socket, Frame frame) {
        if (socket.frames.length != 0)
            return failConnection(socket, "new text message while a fragmented message is in progress");
        if (frame.fin)
            onTextMessage(socket.id, cast(string)frame.data);
        else
            socket.frames ~= frame;
    }

    /// Binary frame: deliver it directly if final, otherwise start buffering fragments.
    private void handleBinary(WebSocketState socket, Frame frame) {
        if (socket.frames.length != 0)
            return failConnection(socket, "new binary message while a fragmented message is in progress");
        if (frame.fin)
            onBinaryMessage(socket.id, frame.data);
        else
            socket.frames ~= frame;
    }

    /**
     * Close frame: answer with a Close frame (RFC 6455 §5.5.1 requires it),
     * echoing the status code if the client sent one, then drop the peer.
     */
    private void handleClose(WebSocketState socket, Frame frame) {
        ubyte[] payload;
        if (frame.data.length >= 2)
            payload = frame.data[0 .. 2];
        if (socket.socket.isAlive)
            socket.socket.send(Frame(true, Op.CLOSE, false, payload.length, [0, 0, 0, 0], true, payload).serialize);
        remove(socket);
    }

    /// Ping: reply with a Pong carrying the same application data (RFC 6455 §5.5.3).
    private void handlePing(WebSocketState socket, Frame frame) {
        socket.socket.send(Frame(true, Op.PONG, false, frame.data.length, [0, 0, 0, 0], true, frame.data).serialize);
    }

    /// Pong: only traced.
    private void handlePong(WebSocketState socket, Frame frame) {
        tracef("Received pong from %s", socket.id);
    }

    /// Send `message` to `dest` as a single unfragmented text frame; unknown peers are logged and ignored.
    public void sendText(PeerID dest, string message) {
        if (dest !in sockets) {
            warningf("Tried to send a message to %s which is not connected", dest);
            return;
        }
        import std.string : representation;

        auto bytes = message.representation.dup;
        auto frame = Frame(true, Op.TEXT, false, message.length, [0, 0, 0, 0], true, bytes);
        auto serial = frame.serialize;
        tracef("Sending %d bytes to %s in one frame of %d bytes long", bytes.length, dest, serial.length);
        sockets[dest].socket.send(serial);
    }

    /// Send `message` to `dest` as a single unfragmented binary frame; unknown peers are logged and ignored.
    public void sendBinary(PeerID dest, ubyte[] message) {
        if (dest !in sockets) {
            warningf("Tried to send a message to %s which is not connected", dest);
            return;
        }
        auto frame = Frame(true, Op.BINARY, false, message.length, [0, 0, 0, 0], true, message);
        auto serial = frame.serialize;
        tracef("Sending %d bytes to %s in one frame of %d bytes long", message.length, dest, serial.length);
        sockets[dest].socket.send(serial);
    }

    /**
     * Bind to `port` and serve forever on the calling thread.
     *
     * Params:
     *   port = TCP port to listen on
     *   maxConnections = connections beyond this count are closed right after accept
     *   bufferSize = size of the stack buffer used for each receive call
     */
    public void run(ushort port, size_t maxConnections, size_t bufferSize = 1024)() {
        this.maxConnections = maxConnections;

        listener.blocking = false;
        listener.bind(new InternetAddress(port));
        listener.listen(10);

        infof("Listening on port: %d", port);
        infof("Maximum allowed connections: %d", maxConnections);

        auto set = new SocketSet(maxConnections + 1);
        while (true) {
            set.add(listener);
            foreach (s; sockets)
                set.add(s.socket);
            Socket.select(set, null, null);

            // Iterate over a snapshot: handle() and remove() may delete entries, and an
            // associative array must not be modified while it is being foreach-ed.
            foreach (socket; sockets.values) {
                if (!set.isSet(socket.socket))
                    continue;
                ubyte[bufferSize] buffer;
                auto receivedLength = socket.socket.receive(buffer[]);
                tracef("Received %d bytes from %s", receivedLength, socket.id);
                if (receivedLength > 0) {
                    handle(socket, buffer[0 .. receivedLength]);
                    continue;
                }
                // 0: orderly shutdown by the peer; Socket.ERROR: receive failed.
                remove(socket);
            }

            if (set.isSet(listener)) {
                try {
                    add(listener.accept());
                } catch (SocketAcceptException e) {
                    // The listener is non-blocking; a failed accept must not stop the server loop.
                    warningf("Failed to accept a connection: %s", e.msg);
                }
            }

            set.reset();
        }
    }
}