Friday, November 21, 2008

Java MapReduce Improved

My MapReduce implementation is much improved and now fits into a single class. I find that I mostly use the pmap() method for parallel processing of list elements.

Here is an example of its use, based on the same example as my previous Java MapReduce post:

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Vector;
import java.util.Map.Entry;

import MapReduce;

/**
 * Word-count example for MapReduce: each document is mapped to a
 * HashMap of (lower-cased word -> Integer count), and the per-document
 * maps are then folded into a single HashMap of totals.
 *
 * map() and reduce() are looked up by name through reflection, so they
 * must stay public and keep their Object-typed signatures.
 */
public class WordCount {

    static final String doc1 = "This is document 1";
    static final String doc2 = "This is another document";
    static final String doc3 = "Document 3";

    /**
     * Map step: counts the words of one document.
     * Only writes to the HashMap it creates itself.
     *
     * @param data - the document text; must be a String
     * @return a new HashMap of (lower-cased word -> Integer count);
     * empty if the document contains no words
     */
    public Object map(Object data) {
        String doc = (String) data;
        String[] tokens = doc.trim().split("\\s+");
        HashMap results = new HashMap();
        for (int i = 0; i < tokens.length; i++) {
            // An empty or whitespace-only document splits into a single ""
            // token, which is not a word and must not be counted.
            if (tokens[i].length() == 0) {
                continue;
            }
            accumulate(tokens[i], results);
        }
        return results;
    }

    /**
     * Adds one occurrence of the word s to the counts in acc.
     *
     * @param s - the word to count; lower-cased before being used as a key
     * @param acc - HashMap of (word -> Integer count), updated in place
     */
    void accumulate(String s, HashMap acc) {
        // Locale.ROOT keeps the keys identical whatever the default locale is
        // (the default-locale form maps "I" differently under e.g. Turkish).
        String key = s.toLowerCase(java.util.Locale.ROOT);
        Integer count = (Integer) acc.get(key);
        int newval = (count == null) ? 1 : count.intValue() + 1;
        acc.put(key, Integer.valueOf(newval));
    }

    /**
     * Reduce step: merges the counts of one document into the accumulator.
     *
     * @param input - HashMap of (word -> Integer count) produced by map()
     * @param acc - HashMap holding the running totals; updated in place
     * @return acc, the same instance that was passed in
     */
    public Object reduce(Object input, Object acc) {
        HashMap totals = (HashMap) acc;
        Collection entries = ((HashMap) input).entrySet();
        for (Iterator j = entries.iterator(); j.hasNext();) {
            Entry e = (Entry) j.next();
            Object key = e.getKey();
            Integer val = (Integer) e.getValue();
            Integer oldval = (Integer) totals.get(key);
            if (oldval == null) {
                totals.put(key, val);
            } else {
                totals.put(key, Integer.valueOf(val.intValue() + oldval.intValue()));
            }
        }
        return totals;
    }

    /**
     * Counts the words of the three sample documents and prints the totals.
     * If mapReduce throws, the stack trace is printed and the (empty)
     * initial map is printed instead.
     */
    public static void main(String[] args) {
        Vector docs = new Vector();
        docs.add(doc1);
        docs.add(doc2);
        docs.add(doc3);

        HashMap results = new HashMap();
        WordCount wc = new WordCount();

        try {
            results = (HashMap) MapReduce.mapReduce(docs, wc, "map", wc, "reduce", results);
        } catch (Exception e) {
            e.printStackTrace();
        }

        System.out.println(results.toString());
    }
}

This is pretty similar to the earlier usage, but now you can pass in an instance and have access to its data and methods from inside your map function - just be careful about synchronization if you write to any shared variable inside the map function.

You can also use a static method to do the mapping or folding.

Here's the new MapReduce class:

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

/**
* The MapReduce class provides static methods
* for encapsulating parallel processing.
* This class cannot be instantiated.
*
* The pmap() (parallel map) method in particular makes concurrent
* processing simple by abstracting away all the threading and
* synchronization.
*
* @author mike
*/
public class MapReduce extends Thread {

    /**
     * Concurrently maps each object A in the List inputList to a new object B by applying
     * the method meth to every element in the list. Equivalent to calling
     * pmap(inputList, obj, meth, 0), i.e. with no limit on the number of threads.
     * @return ArrayList
     * @param inputList - List of objects to be mapped
     * @param obj - An instance of the class that defines meth
     * @param meth - The name of the method to be run on each object in the data list.
     * The method must be public and have the prototype:
     * Object method_name(Object input)
     * @throws Exception if meth cannot be found or if any mapping call fails
     */
    public static List pmap(List inputList, Object obj, String meth) throws Exception {
        return MapReduce.pmap(inputList, obj, meth, 0);
    }

    /**
     * Concurrently maps each object A in the List inputList to a new object B by applying
     * the method meth to every element in the list. Returns the new B objects in a
     * new list. Mappings in the new list are in the same order and correspond to the
     * objects in the original list but since each mapping is done in parallel,
     * the evaluation order is undefined.
     *
     * One thread is started per element. With a thread limit the list is processed
     * in consecutive batches of at most maxThreads elements, and each batch is
     * waited for before the next one is started.
     *
     * @return ArrayList
     * @param inputList - List of objects to be mapped
     * @param obj - An instance of the class that defines meth
     * @param meth - The name of the method to be run on each object in the data list.
     * The method must be public and have the prototype:
     * Object method_name(Object input)
     * @param maxThreads - The maximum number of threads to run at once.
     * If 0 (or negative), no limit. Use this limit to prevent OutOfMemoryErrors when
     * processing large lists.
     * @throws Exception if meth cannot be found, or the exception raised by the
     * reflective call of the first (in list order) mapping that failed in a batch;
     * an exception thrown by meth itself arrives wrapped in an
     * InvocationTargetException, as it does from fold()
     */
    public static List pmap(List inputList, Object obj, String meth, int maxThreads) throws Exception {
        int size = inputList.size();
        int inc = maxThreads <= 0 ? size : maxThreads;
        ArrayList retval = new ArrayList(size);
        int begin = 0;
        while (begin < size) {
            // Compare against the remaining count so begin + inc cannot overflow.
            int end = (inc < size - begin) ? begin + inc : size;
            List threads = createThreads(inputList, begin, end, obj, meth);
            waitForThreads(threads);
            for (int j = 0; j < threads.size(); j++) {
                MapReduce worker = (MapReduce) threads.get(j);
                // A failed worker has no output; report the failure to the caller
                // instead of silently putting null into the result list.
                if (worker.failure != null) {
                    throw worker.failure;
                }
                retval.add(worker.output);
            }
            begin = end;
        }
        return retval;
    }


    /**
     * Calls meth(elem, accIn) on successive elements of list, starting with accIn == acc0.
     * meth must return an accumulator which is passed to the next call.
     * The function returns the final value of the accumulator.
     * acc0 is returned if the list is empty.
     * Runs sequentially on the calling thread.
     * @param list - The list to be folded into a single object.
     * @param obj - The instance of the class that defines meth.
     * @param meth - The name of the accumulating function.
     * The method must be public and have the prototype
     * Object method_name(Object input, Object accIn)
     * @param acc0 - Initial accumulator
     * @return Object
     * @throws Exception if meth cannot be found or the reflective call fails
     */
    public static Object fold(List list, Object obj, String meth, Object acc0) throws Exception {
        Class[] types = {Object.class, Object.class};
        Method m = obj.getClass().getMethod(meth, types);
        Object acc = acc0;
        for (int i = 0; i < list.size(); i++) {
            Object[] args = {list.get(i), acc};
            acc = m.invoke(obj, args);
        }
        return acc;
    }

    /**
     * Combines the operations of pmap and fold with no limit on the number
     * of concurrent threads.
     */
    public static Object mapReduce(List list, Object mapObj, String mapMeth, Object foldObj, String foldMeth, Object foldAcc) throws Exception {
        return mapReduce(list, mapObj, mapMeth, foldObj, foldMeth, foldAcc, 0);
    }

    /**
     * Combines the operations of pmap and fold with a thread limit.
     * The map results are folded in list order once all mappings are done.
     */
    public static Object mapReduce(List list, Object mapObj, String mapMeth, Object foldObj, String foldMeth, Object foldAcc, int maxThreads) throws Exception {
        List mapResult = pmap(list, mapObj, mapMeth, maxThreads);
        return fold(mapResult, foldObj, foldMeth, foldAcc);
    }


    /**
     * Starts one worker per element of list[begin, end) that calls the static
     * method meth of the class named by obj.
     * This overload is chosen only when the argument's compile-time type is String;
     * pmap passes its obj typed as Object and therefore never selects it.
     */
    static List createThreads(List list, int begin, int end, String obj, String meth) throws Exception {
        return createThreads(list, begin, end, obj, meth, true);
    }

    /**
     * Starts one worker per element of list[begin, end) that calls
     * the method meth on the instance obj.
     */
    static List createThreads(List list, int begin, int end, Object obj, String meth) throws Exception {
        return createThreads(list, begin, end, obj, meth, false);
    }

    /**
     * Creates and starts one worker thread per element of list[begin, end).
     * @param obj - a class name (String) if isStaticMethod, otherwise the
     * instance on which meth is invoked
     * @return the started workers, in the same order as the elements
     * @throws Exception if meth (or the named class) cannot be found
     */
    static List createThreads(List list, int begin, int end, Object obj, String meth, boolean isStaticMethod) throws Exception {
        ArrayList threads = new ArrayList(end - begin);
        for (int i = begin; i < end; i++) {
            try {
                MapReduce p = isStaticMethod ? new MapReduce((String)obj, meth, list.get(i)) : new MapReduce(obj, meth, list.get(i));
                threads.add(p);
                p.start();
            } catch (java.lang.OutOfMemoryError e) {
                // Report which element could not get a thread, then let the error propagate.
                System.err.println("Error: thread " + i);
                throw e;
            }
        }
        return threads;
    }


    /**
     * Blocks until every thread in the list has finished.
     * If the calling thread is interrupted while waiting, the wait is resumed
     * so that no worker is left unjoined, and the interrupt status is restored
     * before returning.
     */
    static void waitForThreads(List threads) {
        boolean interrupted = false;
        for (int i = 0; i < threads.size(); i++) {
            Thread thread = (Thread) threads.get(i);
            boolean joined = false;
            while (!joined) {
                try {
                    thread.join();
                    joined = true;
                } catch (InterruptedException e) {
                    // join() cleared the interrupt status; remember it and keep waiting.
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    //
    // Non-static instance methods and fields
    //
    // Target of the reflective call; stays null for the class-name constructor.
    Object obj;
    // The mapping method, resolved in the constructor.
    Method meth;
    // The single list element this worker maps.
    Object input;
    // Result of the call; set by run(), read by pmap after the thread is joined.
    Object output;
    // Exception raised by the reflective call in run(), or null if it succeeded.
    private Exception failure;

    // This class should never be instantiated except by its own static methods.
    // Resolves meth on the class named classname; no target instance is set.
    private MapReduce(String classname, String meth, Object in) throws Exception {
        Class[] types = {Object.class};
        this.meth = Class.forName(classname).getMethod(meth, types);
        this.input = in;
    }

    // This class should never be instantiated except by its own static methods.
    // Resolves meth on obj's runtime class and invokes it on obj.
    private MapReduce(Object obj, String meth, Object in) throws Exception {
        this.obj = obj;
        Class[] types = {Object.class};
        this.meth = obj.getClass().getMethod(meth, types);
        this.input = in;
    }

    /**
     * Invokes the mapping method on this worker's input and stores the result
     * in output. A failure is stored instead of thrown, because an exception
     * thrown here would end this thread without ever reaching the caller of pmap.
     */
    public void run() {
        Object[] args = {this.input};
        try {
            this.output = this.meth.invoke(this.obj, args);
        } catch (Exception e) {
            this.failure = e;
        }
    }
}

Enjoy!

No comments: