MapReduce Implementation for Union-Find

A fun problem I had to solve a while back was finding all connected components in a very large graph to build a list of disjoint sets.  A simple serial solution might involve picking a node at random, running a BFS from it, partitioning out the visited nodes, and repeating until no nodes remain.  Unfortunately, that’s not feasible for a graph with billions of nodes, since it requires holding the entire graph in memory.  Luckily, I had a Hadoop cluster at my disposal, so once I could formulate a solution in MapReduce, I’d be ready to go.
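
For reference, here’s a minimal sketch of that serial approach (the class and method names and the String node ids are just illustrative).  The in-memory adjacency map it requires is exactly what breaks down at this scale:

import java.util.*;

class SerialComponents {
   // Serial connected components: run a BFS from each unvisited node and
   // collect every traversal as one disjoint set.  The entire adjacency
   // list must fit in memory, which fails at billions of nodes.
   static List<Set<String>> find(Map<String, List<String>> graph) {
      List<Set<String>> components = new ArrayList<Set<String>>();
      Set<String> visited = new HashSet<String>();
      for( String start : graph.keySet() ) {
         if( !visited.add(start) )
            continue;  // already claimed by an earlier BFS
         Set<String> component = new HashSet<String>();
         Deque<String> queue = new ArrayDeque<String>();
         queue.add(start);
         component.add(start);
         while( !queue.isEmpty() ) {
            for( String neighbor : graph.get(queue.poll()) ) {
               if( visited.add(neighbor) ) {
                  component.add(neighbor);
                  queue.add(neighbor);
               }
            }
         }
         components.add(component);
      }
      return components;
   }
}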

Before diving into the solution, here’s a better breakdown of the problem.  Imagine a graph where connected components form equivalence classes, and in the universe of nodes there are a lot of different classes.  In my case, we’re dealing with more than 10 billion nodes, which yield more than one billion disjoint sets.  The size of individual components can range from one node to hundreds of thousands.  Some clusters within the larger components might be densely connected, with many paths between any two nodes, while others could have a sparse collection of edges.  The process needed to account for all of these variations and still yield the simple disjoint sets.

The Inputs:
A long list of lists.  Each list represents nodes that we know are connected.  Initially, each list may contain only one or two nodes.
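
For example, a tiny input might look like this; the first two lists share node B, so they will eventually merge into one component:

{A, B}
{B, C}
{D}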

The Algorithm:
1) Find intersections between sets using set representatives, and merge those that overlap.
2) Determine which sets are isolated (their intersection with every other set is empty), and set them aside.
3) Repeat 1 and 2 until all sets are isolated.

A naive approach for growing the frontier of nodes could involve emitting every pair of nodes from a set as a way to connect nodes by way of a third party.  However, this would be very inefficient in later stages of the algorithm, as the number of emitted pairs from a set of size n would grow to O(n²).  A more efficient implementation is to pick a consistent representative, R, from each set (I use the smallest), and emit <R, {Nodes}>.  To keep the graph bi-directional, also emit the inverse; that is, for each N in Nodes emit <N, {R}>.  This step serves two purposes.  First, it finds intersections between sets and establishes connections between their representatives.  Second, it produces unions, as the representatives for sets improve through those intersections.  The Reduce phase simply passes the key through unaltered and shrinks the values down to only distinct elements.
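
To make that concrete, here’s Job 1 applied to the tiny example input from above:

Map:     {A, B}  =>  <A, {A, B}>  <B, {A}>
         {B, C}  =>  <B, {B, C}>  <C, {B}>
         {D}     =>  <D, {D}>
Reduce:  <A, {A, B}>  <B, {A, B, C}>  <C, {B}>  <D, {D}>

Note how B’s record now spans both of the original sets.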

Job 2 builds directly off the output from Job 1.  The goal is to determine whether a set’s representative is the only representative for each of its nodes.  The mapper reads a key and its list of representatives.  If there is only one representative, it emits <R, {K}> (a constituent swap).  If there are many representatives, it passes the pair through unaltered (ambiguous representatives, or THE representative).  Think of it as each node reporting back to its representative when there’s only one.  The reducer then looks for keys where each value appears exactly twice: once from the representative pass-through, and once from a constituent swap.  Those are flagged as disjoint sets.  Any constituents with ambiguous representatives won’t appear in that list, which means there are more iterations to perform.
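
Continuing the example, Job 2 consumes the four records above:

Map:     <A, {A, B}>     =>  <A, {A, B}>     (pass through)
         <B, {A, B, C}>  =>  <B, {A, B, C}>  (pass through)
         <C, {B}>        =>  <B, {C}>        (constituent swap)
         <D, {D}>        =>  <D, {D}>        (constituent swap)
Reduce:  key A: counts A=2, B=1       =>  OPEN {A, B}
         key B: counts A=1, B=2, C=2  =>  OPEN {A, B, C}
         key D: counts D=2            =>  DISJOINT {D}

{D} is finished immediately, and one more elect/partition round merges the two OPEN records into the single DISJOINT set {A, B, C}.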

For simplicity, the output key from the second job is either DISJOINT or OPEN, and the value is a list of all nodes that appeared in a reduce call.  This feeds the next iteration, which ignores any DISJOINT records.  Once the algorithm completes, a final job is run to collect only the DISJOINT records from each iteration’s output.

Because the frontier of each set grows out from all of its nodes simultaneously, the maximum number of iterations is O(log n), where n is the longest path between any two nodes in the set; each round roughly doubles the distance over which nodes learn about a common representative.  Also, graphs with mostly small components benefit from early isolation of disjoint sets, since they are removed from the inputs to later iterations.

Finally, here’s the code:

package chaser.hadoop;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.TreeSet;
import java.util.UUID;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.ArrayWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

public class UnionFind extends Configured implements Tool {
   
   private enum Counter{
      OPEN,
      DISJOINT
   }
   private static final Text DISJOINT = new Text("D");
   private static final Text OPEN = new Text("O");
   
   /**
    * For potentially overlapping sets, elect a representative.
    *
    * Emits <R, {Nodes}> and <N, {R}> for each N in Nodes.
    * Ignores known DISJOINT sets.
    */
   public static class ElectMap extends Mapper<Text, TextArrayWritable, Text, TextArrayWritable> {
      @Override
      protected void map(Text key, TextArrayWritable value, Context context)
throws IOException, InterruptedException {

         // If it was a disjoint output from the last iteration, then don't
         // continue to propagate it.
         if( key.equals(DISJOINT) ) {
            context.getCounter(Counter.DISJOINT).increment(1);
            return;
         }

         context.getCounter(Counter.OPEN).increment(1);

         // Use a tree set so it's easier to find the smallest while uniquifying
         TreeSet<Text> distinct = new TreeSet<Text>( Arrays.asList(value.get()) );

         TextArrayWritable all = new TextArrayWritable( distinct );

         Text representative = distinct.pollFirst();
         TextArrayWritable representative_val = new TextArrayWritable( representative );
         
         context.write(representative, all);
         for( Text other : distinct )
            context.write(other, representative_val);
      }
   }

   /**
    * Emits the union of all incoming array writables for a key.
    */
   public static class ElectReduce extends Reducer<Text, TextArrayWritable, Text, TextArrayWritable> {
      @Override
      protected void reduce(Text key, Iterable<TextArrayWritable> values, Context context)
         throws IOException, InterruptedException {

         TreeSet<Text> union = new TreeSet<Text>();
         
         for( TextArrayWritable value : values ) {
            union.addAll( Arrays.asList(value.get()) );
         }
         
         context.write(key, new TextArrayWritable(union) );            
      }
   }
   
   /**
* Performs representative pass-throughs or constituent swaps.
    */
   public static class PartitionMap extends Mapper<Text, TextArrayWritable, Text, TextArrayWritable> {
      @Override
protected void map(Text key, TextArrayWritable value, Context context) throws IOException, InterruptedException {
         // Constituent Swap
         if( value.get().length == 1 )
            context.write( value.get()[0], new TextArrayWritable(key) );
         // Representative pass through
         else
            context.write( key, value );
      }
   }
   
   /**
    * Count the number of constituents, and label the set as DISJOINT if each element appears twice.
    */
   public static class PartitionReduce extends Reducer<Text, TextArrayWritable, Text, TextArrayWritable> {
      @Override
      protected void reduce(Text key, Iterable<TextArrayWritable> values, Context context) throws IOException, InterruptedException {
         HashMap<Text, Integer> counts = new HashMap<Text, Integer>();
         
         // Inject a 1 for the key: the representative never emits a swap for
         // itself, so this makes its pass-through appearance count to two.
         counts.put(key, 1);
         for( TextArrayWritable value : values ) {
            for( Text text : value.get() )
               if( counts.containsKey(text) )
                  counts.put(text, counts.get(text)+1);
               else
                  counts.put(text, 1);
         }

         // Assume it's DISJOINT until we see a count other than two
         TextArrayWritable value = new TextArrayWritable(counts.keySet());
         key = DISJOINT;
         for( Integer count : counts.values() ) {
            if( count != 2 ) {
               key = OPEN;
               break;
            }
         }

         if( key.equals(DISJOINT) )
            context.getCounter(Counter.DISJOINT).increment(1);
         else
            context.getCounter(Counter.OPEN).increment(1);
         
         context.write(key, value);
      }
   }

   /**
    * Simple pass that tags all incoming records with the OPEN key.
    */
   public static class MarkOpenMap extends Mapper<Writable, TextArrayWritable, Text, TextArrayWritable> {
      @Override
protected void map(Writable key, TextArrayWritable value, Context context) throws IOException, InterruptedException {
         context.write( OPEN, value );
      }
   }

   /**
    * Simple pass that emits all DISJOINT records
    */
   public static class EmitDisjointMap extends Mapper<Text, TextArrayWritable, Text, TextArrayWritable> {
      @Override
protected void map(Text key, TextArrayWritable value, Context context) throws IOException, InterruptedException {
         if( key.equals(DISJOINT) )
            context.write( key, value );
      }
   }
   
   private String makeTempSpace() throws IOException {
      String temporary = "/tmp/union_find/" + UUID.randomUUID();
      Path temp_path = new Path(temporary);
      FileSystem fs = temp_path.getFileSystem(getConf());
      fs.mkdirs(temp_path);
      fs.deleteOnExit(temp_path);
      return temporary;
   }

   @Override
   public int run(String[] args) throws Exception {
      
      // Create a temporary work location that gets cleaned up on exit.
      String temporary = makeTempSpace();
      String elect_path = temporary + "/elect.";
      String partition_path = temporary + "/partition.";

      int iteration = 0;


      // This step assumes some prior data setup.  Specifically, the input
      // (read here from args[0]) must be in a sequence file of
      // <K, TextArrayWritable>.  If IO is very important, the job could be
      // optimized away by tacking the mapper onto the first iteration of the
      // loop below with a ChainMapper.
      Job setup = new Job(getConf());
      setup.setJarByClass(getClass());
      setup.setJobName("Union Find (setup)");
      setup.setMapperClass(MarkOpenMap.class);
      setup.setNumReduceTasks(0);
      setup.setInputFormatClass(SequenceFileInputFormat.class);
      setup.setOutputFormatClass(SequenceFileOutputFormat.class);
      setup.setOutputKeyClass(Text.class);
      setup.setOutputValueClass(TextArrayWritable.class);
      FileInputFormat.addInputPath(setup, new Path(args[0]));
      FileOutputFormat.setOutputPath(setup, new Path(partition_path + iteration));
      if( !setup.waitForCompletion(false) )
         throw new RuntimeException("setup job failed");
      
      while( true ) {

         Job elect = new Job(new Configuration(getConf()));
         Job partition = new Job(new Configuration(getConf()));
         elect.setJarByClass(getClass());
         partition.setJarByClass(getClass());

         // Stitch together paths
         // partition.n => elect => elect.(n+1) => partition => partition.(n+1)
         FileInputFormat.addInputPath(elect, new Path(partition_path + (iteration++)));
         FileOutputFormat.setOutputPath(elect, new Path(elect_path + iteration));
         FileInputFormat.addInputPath(partition, new Path(elect_path + iteration));
         FileOutputFormat.setOutputPath(partition, new Path(partition_path + iteration));

         elect.setJobName("Union Find (elect ["+iteration+"])");
         elect.setMapperClass(ElectMap.class);
         elect.setReducerClass(ElectReduce.class);
         elect.setInputFormatClass(SequenceFileInputFormat.class);
         elect.setOutputFormatClass(SequenceFileOutputFormat.class);
         elect.setOutputKeyClass(Text.class);
         elect.setOutputValueClass(TextArrayWritable.class);

         partition.setJobName("Union Find (partition ["+iteration+"])");
         partition.setMapperClass(PartitionMap.class);
         partition.setReducerClass(PartitionReduce.class);
         partition.setInputFormatClass(SequenceFileInputFormat.class);
         partition.setOutputFormatClass(SequenceFileOutputFormat.class);
         partition.setOutputKeyClass(Text.class);
         partition.setOutputValueClass(TextArrayWritable.class);

         if( !elect.waitForCompletion(false) )
            throw new RuntimeException("elect job failed");

         // All the sets were disjoint.  No more work to do.
         // Otherwise, run partition and repeat.
         if( elect.getCounters().findCounter(Counter.OPEN).getValue() == 0 )
            break;
         else if( !partition.waitForCompletion(false) )
            throw new RuntimeException("partition job failed");
      }

      // Collect all the disjoint values.  The final destination is assumed
      // here to come from args[1].
      Job emit = new Job(getConf());
      emit.setJarByClass(getClass());
      emit.setJobName("Union Find (emit)");
      emit.setMapperClass(EmitDisjointMap.class);
      emit.setNumReduceTasks(0);
      emit.setInputFormatClass(SequenceFileInputFormat.class);
      emit.setOutputKeyClass(Text.class);
      emit.setOutputValueClass(TextArrayWritable.class);
      FileInputFormat.addInputPath(emit, new Path(partition_path + "*"));
      FileOutputFormat.setOutputPath(emit, new Path(args[1]));
      emit.waitForCompletion(true);

      return emit.isSuccessful() ? 0 : 1;
   }
   
   public static void main(String[] args) throws Exception
   {
      int result = ToolRunner.run(new UnionFind(), args);
      System.exit(result);
      
   }

   public static class TextArrayWritable extends ArrayWritable {
      public TextArrayWritable() {
         super(Text.class);
      }
      public TextArrayWritable(Text... elements) {
         super(Text.class, elements);
      }
      public TextArrayWritable(Collection<Text> elements) {
         super(Text.class, elements.toArray(new Text[0]));
      }
      public Text[] get() {
         Writable[] writables = super.get();
         Text[] texts = new Text[writables.length];
         for(int i=0; i<writables.length; ++i)
            texts[i] = (Text)writables[i];
         return texts;
      }
   }
}


5 Responses to MapReduce Implementation for Union-Find

  1. Raj says:

    Very nice, thank you!

  2. Hey Chase,

     Great post, great algorithm. However, I had a hard time understanding it until I pictured it to myself visually, so I would like to take a stab at reiterating your explanation; maybe someone will benefit from this:

    what we would like to do is:
    for each set we define its representative R as the smallest node. E.g. in {a,b,c} it’s “a”.
    we now draw edges from each representative to everybody else in the set and also vice-versa.
    let’s also color them red if their source node points to more than one node, and green if they’re the only edge originating from it:
    a->{a,b,c} (red)
    b->a (green)
    c->a (green)

    now, if this set is disjoint we’ll have exactly 2 reds vs. 2 greens.
    if the set intersects with some other set, we’ll use the magic of MR to detect these cases:
    * there’s some other set with “a” as a rep: they will be merged cleanly, we will draw red edges from rep->everybody in both sets.
    * there’s some other set that includes “b” (either as a rep or a simple node of the other set): “b” would now have another edge originating from it. This will turn the edge b->a red, as it’s no longer the only edge originating from “b”.

    note that in the second case, “b” destroys the balance of 2:2 red vs. green edges in our set. It also creates a new set that contains “a”. We would do another iteration, with the new sets. “a” would be chosen as a rep of that new set, and we’ll probably merge them cleanly.

    now for the MR gymnastics: in the first MR job, we draw red and green edges. In the second one we detect the ones which are still green and reverse them, just so we can count how many were left.

    HTH anyone

  3. Yet another question: another search for MR connected-components algorithms dug up this one: http://blog.piccolboni.info/2010/07/map-reduce-algorithm-for-connected.html
     It looks different from the one you came up with… did you happen to try it or compare performance with your version?
     Thanks!

  4. jaehong.choi says:

    Nice implementation! Thank you!
