1 module redis.subscriber;
2 
import std.socket : InternetAddress, Socket, SocketOSException, SocketShutdown, TcpSocket,
    wouldHaveBlocked;
import std.stdio : stderr, writefln;
import std.array : empty, front, popFront;
import std.algorithm : find, any, min;
import std.conv : to;
public import redis;
9 
public:

/**
 * Handler for messages delivered through a SUBSCRIBE (channel) subscription.
 *
 * Invoked with the channel the message was published on and the message payload.
 */
alias Callback = void delegate(string channel, string message);

/**
 * Handler for messages delivered through a PSUBSCRIBE (pattern) subscription.
 *
 * Invoked with the subscribed pattern that matched, the concrete channel the message was
 * published on, and the message payload.
 */
alias PCallback = void delegate(string pattern, string channel, string message);
17 
/**
 * Whether a response is of a particular pub/sub message type.
 *
 * The type is read from the first element of the reply (e.g. "message", "subscribe", "pong").
 * A reply without any elements is reported as not matching, rather than failing with a range
 * error. This lets the check be used as a predicate over every reply read off the connection.
 */
bool isType(string type)(Response r)
{
    auto elements = r.values;
    return elements.length > 0 && elements[0].value == type;
}
25 
/**
 * Redis Pub/Sub subscriber.
 *
 * (Un)subscription commands wait until the server's confirmation for them has been read.
 * Other replies read while waiting (for example published messages) are queued.
 * processMessages() later dispatches them to the registered callbacks.
 */
class Subscriber
{
private:
    TcpSocket conn;
    Callback[string] callbacks;      // Regular subscription callbacks, keyed by channel
    PCallback[string] pCallbacks;    // Pattern subscription callbacks, keyed by pattern
    Response[][] queue;              // Responses collected but not yet processed

    /**
     * Send a redis command, encoded with toMultiBulk().
     *
     * Socket.send() may accept only part of the buffer. On a non-blocking socket it may also
     * fail with "would block". The remainder is therefore re-sent until every byte is written.
     *
     * Throws: SocketOSException if the socket reports any error other than "would block".
     */
    void send(string cmd)
    {
        const(void)[] data = toMultiBulk(cmd);

        while (data.length > 0)
        {
            const sent = conn.send(data);

            if (sent == Socket.ERROR)
            {
                // Send buffer full on a non-blocking socket: retry. This busy-waits.
                // NOTE(review): assumes commands are small enough for this to be brief;
                // consider adding a timeout.
                if (wouldHaveBlocked())
                    continue;

                throw new SocketOSException("Unable to send command to redis");
            }

            data = data[cast(size_t) sent .. $];
        }
    }

    /**
     * Poll responses from the redis server and queue for later processing unless they match the
     * predicate.
     *
     * This function is the workhorse behind all member functions of this type. It keeps reading
     * batches of responses until `expected` responses have matched `pred`. Non-matching
     * responses are appended to the queue, in arrival order, for processMessages(). Matching
     * responses beyond `expected` that arrive in the same batch are consumed and not queued.
     *
     * @param pred - The predicate function that determines whether a response is an expected one
     * @param expected - The number of responses expected to match the predicate
     * @return - The last response that matched the predicate. If `expected` is 0, this is a
     *           default-initialized Response.
     */
    Response queueUnless(bool delegate(Response) pred, size_t expected = 1)
    {
        Response resp;
        size_t matched = 0;

        // TODO - Timeout? This loop never gives up if the expected replies never arrive.
        while (matched < expected)
        {
            // A batch may hold zero or many matching responses, interleaved with unrelated ones.
            Response[] responses = receiveResponses(conn, 1);

            while (!responses.empty)
            {
                auto found = responses.find!pred;

                // Everything before the match (the whole batch if nothing matched) is kept
                // for later processing.
                queue ~= responses[0 .. $ - found.length];

                if (found.empty)
                    break;

                resp = found.front;
                responses = found[1 .. $];
                ++matched;
            }
        }

        return resp;
    }

    /**
     * Convenience wrapper for queueUnless(). It builds the predicate from a message type
     * (see isType), waiting for `expected` replies of that type.
     */
    Response queueUnlessType(string type)(size_t expected = 1)
    {
        return queueUnless(r => r.isType!type, expected);
    }

    /**
     * Dispatch a single queued response to the callback registered for it.
     *
     * "message" replies go to the callback registered for their channel. "pmessage" replies go
     * to the callback registered for their pattern. Malformed replies, replies of any other type,
     * and messages without a registered callback are reported on stderr and otherwise ignored.
     */
    private void processMessage(Response resp)
    {
        auto elements = resp.values;

        /* Nested convenience function */
        void reportBadResponse()
        {
            stderr.writefln("Unexpected subscription response: %s", resp);
        }

        /* Nested convenience function returning response element at the specified index */
        string element(size_t index)
        {
            return elements[index].value;
        }

        // Without a first element there is no message type to dispatch on. Report it rather than
        // failing with a range error.
        if (elements.length == 0)
        {
            reportBadResponse();
            return;
        }

        string type = element(0);

        switch (type)
        {
        case "message":     // [ "message", channel, payload ]
            if (elements.length != 3)
                reportBadResponse();
            else
            {
                string channel = element(1);
                const callback = (channel in callbacks);

                if (callback)
                {
                    string message = element(2);
                    (*callback)(channel, message);
                }
                else
                    stderr.writefln("No callback for message: %s", resp);
            }
            break;

        case "pmessage":    // [ "pmessage", pattern, channel, payload ]
            if (elements.length != 4)
                reportBadResponse();
            else
            {
                string pattern = element(1);
                const callback = (pattern in pCallbacks);

                if (callback)
                {
                    string channel = element(2);
                    string message = element(3);

                    (*callback)(pattern, channel, message);
                }
                else
                    stderr.writefln("No callback for pattern message: %s", resp);
            }
            break;

        default:
            reportBadResponse();
            break;
        }
    }

public:

    /**
     * Create a new non-blocking subscriber by connecting to a Redis host and port.
     */
    this(string host = "127.0.0.1", ushort port = 6379)
    {
        conn = new TcpSocket(new InternetAddress(host, port));
        conn.blocking = false;
    }

    /**
     * Create a new subscriber using an existing socket.
     *
     * The socket's blocking mode is left as the caller configured it.
     */
    this(TcpSocket conn)
    {
        this.conn = conn;
    }

    /**
     * Subscribe to a channel and register its callback.
     *
     * Returns the subscription count from the third element of the server's confirmation.
     * Redis counts both channels and patterns in this number.
     */
    size_t subscribe(string channel, Callback callback)
    {
        auto cmd = "SUBSCRIBE " ~ channel;
        send(cmd);

        Response resp = queueUnlessType!"subscribe"();
        callbacks[channel] = callback;

        return resp.values[2].to!int;
    }

    /**
     * Subscribe to a channel pattern and register its callback.
     *
     * Returns the subscription count from the third element of the server's confirmation.
     * Redis counts both channels and patterns in this number.
     */
    size_t psubscribe(string pattern, PCallback callback)
    {
        auto cmd = "PSUBSCRIBE " ~ pattern;
        send(cmd);

        Response resp = queueUnlessType!"psubscribe"();
        pCallbacks[pattern] = callback;

        return resp.values[2].to!int;
    }

    /**
     * Unsubscribe from a channel and drop its callback.
     *
     * Returns the subscription count from the third element of the server's confirmation.
     */
    size_t unsubscribe(string channel)
    {
        auto cmd = "UNSUBSCRIBE " ~ channel;
        send(cmd);

        Response resp = queueUnlessType!"unsubscribe"();
        callbacks.remove(channel);

        return resp.values[2].to!int;
    }

    /**
     * Unsubscribe from all channels and drop every channel callback.
     *
     * Returns the subscription count from the third element of the last confirmation.
     */
    size_t unsubscribe()
    {
        auto cmd = "UNSUBSCRIBE";
        send(cmd);

        // Redis sends one confirmation per unsubscribed channel. It still sends a single
        // confirmation when no channel was subscribed. Waiting for zero replies would leave that
        // one unread and return a default-initialized Response, so always wait for at least one.
        const expected = callbacks.length ? callbacks.length : 1;
        Response resp = queueUnlessType!"unsubscribe"(expected);
        callbacks = null;

        return resp.values[2].to!int;
    }

    /**
     * Unsubscribe from a channel pattern and drop its callback.
     *
     * Returns the subscription count from the third element of the server's confirmation.
     */
    size_t punsubscribe(string pattern)
    {
        auto cmd = "PUNSUBSCRIBE " ~ pattern;
        send(cmd);

        Response resp = queueUnlessType!"punsubscribe"();
        pCallbacks.remove(pattern);

        return resp.values[2].to!int;
    }

    /**
     * Unsubscribe from all channel patterns and drop every pattern callback.
     *
     * Returns the subscription count from the third element of the last confirmation.
     */
    size_t punsubscribe()
    {
        auto cmd = "PUNSUBSCRIBE";
        send(cmd);

        // As with unsubscribe(): Redis confirms at least once even with no pattern subscribed.
        const expected = pCallbacks.length ? pCallbacks.length : 1;
        Response resp = queueUnlessType!"punsubscribe"(expected);
        pCallbacks = null;

        return resp.values[2].to!int;
    }

    /**
     * Close the redis connection.
     *
     * Sends QUIT and waits for the server's "OK". Then it shuts down and closes the local socket,
     * so the descriptor is not leaked. The subscriber cannot be used afterwards.
     */
    Response quit()
    {
        auto cmd = "QUIT";
        send(cmd);

        Response resp = queueUnless(r => r.value == "OK");

        // The server closes its side after acknowledging QUIT; release our end as well.
        conn.shutdown(SocketShutdown.BOTH);
        conn.close();

        return resp;
    }

    /**
     * Send a PING command (with an optional argument) and wait for the "pong" reply.
     */
    Response ping(string argument = null)
    {
        // Only add the separator when there is an argument, so a bare ping is sent as "PING".
        auto cmd = argument.length ? "PING " ~ argument : "PING";

        send(cmd);
        Response resp = queueUnless(r => r.isType!"pong");

        return resp;
    }

    /**
     * Poll for queued messages on the redis server and call their callbacks.
     *
     * Newly received responses are appended to the queue. Every queued response is then passed to
     * processMessage() in order, and the queue is emptied. If a callback throws, the queue is not
     * cleared, so already-dispatched responses will be dispatched again on the next call.
     */
    void processMessages()
    {
        queue ~= receiveResponses(conn, 0);

        // Index-based on purpose. A callback may call back into this object (e.g. subscribe()),
        // appending more responses to `queue` during the loop. Re-reading `queue.length` lets
        // those be dispatched here, instead of being discarded by the reset below.
        for (size_t i = 0; i < queue.length; ++i)
        {
            foreach (resp; queue[i])
                processMessage(resp);
        }

        // Keep the allocated storage for later appends.
        queue.length = 0;
        queue.assumeSafeAppend();
    }
}