Friday, July 18, 2008

Java MapReduce

OK, this isn't quite MapReduce in its entirety. It lacks any facility for distributed computing across multiple machines. What it does do is simplify concurrent programming by spreading the map step of your MapReduce across multiple threads, which on most SMP machines means multiple processors.

It pretty nicely abstracts away all the threading and synchronizing issues so you can focus only on the problem at hand.

To use it, you create a class to encapsulate your mapping and reducing functions. This class must extend the MapperThread class and implement the Reducer interface, which means you must implement two methods:
  • public Object map(Object data)
  • public Object getReduction(Collection results)
Then you call the static method MapReduce.mapReduce(), passing in your data as a Vector along with the fully qualified names of your mapper and reducer classes (the same name twice if one class does both jobs), and it returns your results inside an object of your choice.

It's time for an example. Let's use a classic MapReduce example, the word counter. We'll put some "documents" as Strings into a Vector and pass it to mapReduce(), which will concurrently process each document and return a HashMap where each key is a word and each value is the total count of that word across all our documents.

For example, if our documents are {"This is document 1", "This is another document", "Document 3"}
then our output should be {this=2, is=2, document=3, 1=1, another=1, 3=1} (assuming case-insensitivity; a HashMap may print the entries in a different order).
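
As a sanity check on that expected output, here is a minimal single-threaded sketch of the same count. It is not part of the MapReduce classes in this post; the class name SequentialWordCount and its count() method are made up for illustration, and it uses generics and Map.merge from newer Java versions rather than the raw collections used in the rest of the code.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not one of the classes from this post.
class SequentialWordCount {

    // Counts case-insensitive, whitespace-separated tokens across all documents.
    static Map<String, Integer> count(List<String> docs) {
        // LinkedHashMap keeps first-seen order, so the printed result is easy to compare.
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (String doc : docs) {
            for (String token : doc.trim().split("\\s+")) {
                counts.merge(token.toLowerCase(), 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        List<String> docs = Arrays.asList(
                "This is document 1", "This is another document", "Document 3");
        System.out.println(count(docs));
        // prints {this=2, is=2, document=3, 1=1, another=1, 3=1}
    }
}
```

The MapReduce version should produce the same counts, with the per-document work done in parallel.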

Here's the code to do it:

WordCount.java


package test.misc;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Vector;
import java.util.Map.Entry;

// The helper classes below are declared in package test.misc.util.
import test.misc.util.MapReduce;
import test.misc.util.MapperThread;
import test.misc.util.Reducer;

public class WordCount extends MapperThread implements Reducer {

    static final String doc1 = "This is document 1";
    static final String doc2 = "This is another document";
    static final String doc3 = "Document 3";

    // Map step: count the words in a single document.
    public Object map(Object data) {
        String doc = (String) data;
        String[] tokens = doc.trim().split("\\s+");
        HashMap results = new HashMap();
        for (int i = 0; i < tokens.length; i++) {
            accumulate(tokens[i], results);
        }
        return results;
    }

    // Adds one occurrence of s (lower-cased) to the running counts in acc.
    void accumulate(String s, HashMap acc) {
        String key = s.toLowerCase();
        if (acc.containsKey(key)) {
            Integer I = (Integer) acc.get(key);
            int newval = I.intValue() + 1;
            acc.put(key, Integer.valueOf(newval));
        } else {
            acc.put(key, Integer.valueOf(1));
        }
    }

    // Reduce step: merge the per-document counts into one total per word.
    public Object getReduction(Collection c) {
        HashMap h = new HashMap();
        for (Iterator i = c.iterator(); i.hasNext();) {
            Collection entries = ((HashMap) i.next()).entrySet();
            for (Iterator j = entries.iterator(); j.hasNext();) {
                Entry e = (Entry) j.next();
                Object key = e.getKey();
                Integer val = (Integer) e.getValue();
                if (h.containsKey(key)) {
                    Integer oldval = (Integer) h.get(key);
                    h.put(key, Integer.valueOf(val.intValue() + oldval.intValue()));
                } else {
                    h.put(key, val);
                }
            }
        }
        return h;
    }

    public static void main(String[] args) {
        Vector docs = new Vector();
        docs.add(doc1);
        docs.add(doc2);
        docs.add(doc3);

        HashMap results = null;

        try {
            results = (HashMap) MapReduce.mapReduce(docs, "test.misc.WordCount", "test.misc.WordCount");
        } catch (Exception e) {
            e.printStackTrace();
            return; // nothing to print if the run failed
        }

        System.out.println(results.toString());
    }
}


Notice what we're NOT doing. We're not starting any threads or worrying about synchronization.

We are processing documents concurrently but all the threading is abstracted away to the MapReduce and MapperThread classes.
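
To make concrete what is being abstracted away, here is a rough sketch of the by-hand alternative: start one thread per document, give each a slot for its result, and join them all before reading anything. The class name ManualThreads and its tokenCounts() method are invented for this illustration, and it uses a lambda from newer Java versions. It only counts tokens per document to keep the sketch short.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical example, not one of the classes from this post.
class ManualThreads {

    // Returns the number of whitespace-separated tokens in each document,
    // using one hand-managed thread per document.
    static int[] tokenCounts(List<String> docs) {
        final int[] counts = new int[docs.size()];
        List<Thread> threads = new ArrayList<>();
        for (int i = 0; i < docs.size(); i++) {
            final int slot = i; // each thread writes only to its own slot
            final String doc = docs.get(i);
            Thread t = new Thread(() -> counts[slot] = doc.trim().split("\\s+").length);
            threads.add(t);
            t.start();
        }
        for (Thread t : threads) {
            try {
                t.join(); // after join() returns, this thread's write is visible here
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting", e);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = tokenCounts(Arrays.asList(
                "This is document 1", "This is another document", "Document 3"));
        System.out.println(Arrays.toString(counts)); // prints [4, 4, 2]
    }
}
```

This is the kind of bookkeeping that MapReduce and MapperThread take over for you.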

We can re-use these classes again and again to perform any number of tasks in parallel. We could use them to make concurrent HTTP requests, for instance, and all of a sudden we are doing concurrent DISTRIBUTED programming.

There's also an optional argument to mapReduce, maxThreads, that limits the number of concurrently active threads. The default is 1,000,000, which in practice means one thread per element of your data, so you'll need to set a lower limit when working with large data sets.
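
The limit works by batching: mapReduce splits the data into lists of at most maxThreads items and runs one batch of threads to completion before starting the next. Here is a small sketch of just that splitting step, written iteratively and with generics. The class name Batches and its split() method are made up for illustration and are not the code used by the classes below.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not one of the classes from this post.
class Batches {

    // Splits items into consecutive batches of at most max elements each.
    static <T> List<List<T>> split(List<T> items, int max) {
        if (max < 1) {
            throw new IllegalArgumentException("max must be at least 1");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < items.size(); start += max) {
            int end = Math.min(start + max, items.size());
            batches.add(new ArrayList<>(items.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        System.out.println(split(Arrays.asList("a", "b", "c", "d", "e", "f", "g"), 3));
        // prints [[a, b, c], [d, e, f], [g]]
    }
}
```

With seven items and a limit of three, no more than three mapper threads would be alive at once, at the cost of each batch waiting for its slowest thread.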

Here are the three classes that handle the dirty work. You'll have to configure your own package and class paths of course. Enjoy!

MapReduce.java


package test.misc.util;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Vector;

public class MapReduce {

    //
    // If you get OutOfMemoryError, try setting maxThreads to something like 500 or 1000, or even lower if you're
    // manipulating very large objects.
    //
    public static Object mapReduce(Vector datalist, String mapperClass, String reducerClass, int maxThreads) throws Exception {
        if (maxThreads < 1) {
            throw new IllegalArgumentException("maxThreads must be at least 1");
        }
        Vector lists = new Vector();
        splitList(datalist, maxThreads, lists);
        HashMap map = new HashMap();
        // Process one batch of at most maxThreads items at a time.
        for (Iterator iter = lists.iterator(); iter.hasNext();) {
            Vector list = (Vector) iter.next();
            Vector threads = createThreads(list.size(), mapperClass);
            initializeMap(map, threads);
            startThreads(threads, list, map);
            waitForThreads(threads);
        }
        Object results = runReducer(map, reducerClass);
        return results;
    }

    public static Object mapReduce(Vector datalist, String mapperClass, String reducerClass) throws Exception {
        return mapReduce(datalist, mapperClass, reducerClass, 1000000);
    }

    // Registers every thread as a key before any thread starts. The workers then
    // only replace the value for an existing key, which is not a structural
    // modification of the HashMap.
    static void initializeMap(HashMap results, Vector threads) {
        for (int i = 0; i < threads.size(); i++) {
            results.put(threads.get(i), null);
        }
    }

    static Vector createThreads(int numThreads, String mapperClass) throws Exception {
        Vector threads = new Vector(numThreads);
        for (int i = 0; i < numThreads; i++) {
            try {
                MapperThread thread = (MapperThread) Class.forName(mapperClass).newInstance();
                threads.add(thread);
            } catch (java.lang.OutOfMemoryError e) {
                System.err.println("Error: thread " + i);
                throw e;
            }
        }
        return threads;
    }

    static Vector startThreads(Vector threads, Vector data, HashMap map) throws Exception {
        for (int i = 0; i < data.size(); i++) {
            try {
                Object item = data.get(i);
                MapperThread thread = (MapperThread) threads.get(i);
                thread.setParams(item, map);
                thread.start();
            } catch (java.lang.OutOfMemoryError e) {
                System.err.println("Error: thread " + i);
                throw e;
            }
        }
        return threads;
    }

    static void waitForThreads(Vector threads) throws InterruptedException {
        for (Iterator i = threads.iterator(); i.hasNext();) {
            Thread thread = (Thread) i.next();
            // Let an interrupt propagate instead of reducing over unfinished results.
            thread.join();
        }
    }

    static Object runReducer(HashMap map, String reducerClass) throws Exception {
        Reducer reducer = (Reducer) Class.forName(reducerClass).newInstance();
        return reducer.getReduction(map.values());
    }

    // Splits in into consecutive lists of at most max elements and appends them to acc.
    // Iterative, so a long list with a small max cannot overflow the stack.
    static Vector splitList(Vector in, int max, Vector acc) {
        int start = 0;
        while (in.size() - start > max) {
            acc.add(new Vector(in.subList(start, start + max)));
            start += max;
        }
        acc.add(start == 0 ? in : new Vector(in.subList(start, in.size())));
        return acc;
    }
}

MapperThread.java


package test.misc.util;

import java.util.HashMap;

public class MapperThread extends Thread {
    protected Object data;
    protected HashMap mappings;

    public void setParams(Object data, HashMap map) {
        this.data = data;
        this.mappings = map;
    }

    /**
     * This is the method your subclass must override.
     * @param data - This is taken from successive elements of the list
     * you pass in to mapReduce.
     * @return - Whatever is returned here will be added to the collection passed in to
     * your Reducer.getReduction() method.
     */
    public Object map(Object data) {
        return data;
    }

    public void run() {
        // The key (this thread) was already registered by MapReduce.initializeMap(),
        // so this only replaces the placeholder null with the mapped result.
        mappings.put(this, map(data));
    }
}


Reducer.java


package test.misc.util;

import java.util.Collection;

public interface Reducer {

    /**
     * You'll implement this in your own class.
     * @param results - A collection of objects returned from your MapperThread subclass.
     * @return - The return from this method is what is returned by mapReduce().
     */
    public Object getReduction(Collection results);
}