
Saturday, July 23, 2011

Decision trees with scala-recog: a machine learning library in Scala

scala-recog is a project I created on Google Code: a library of machine learning algorithms written in Scala. One of its features is an implementation of the ID3 algorithm, which builds a decision tree from a training set of already-classified examples.

Example: classifying CVs for a job site

On a job site you need to classify CVs to show to logged-in users. You could extract a huge quantity of data from each CV, but let's extract just a few features:
  • whether the candidate has any certifications
  • whether the candidate is talkative
  • whether (s)he has joined a golf club
  • whether (s)he has earned a master's degree
and put them all in a class:
  // Case class parameters are already public vals, so no explicit `val` is needed.
  case class Person(
                hasCertifications : Boolean,
                isTalkative : Boolean,
                golfClub : Boolean,
                hasMasterDegree : Boolean,
                job : String
               )

  val persons = Person(hasCertifications = true, isTalkative = false, 
                       golfClub = false, hasMasterDegree = true, 
                       job = "Programmer") ::
                Person(hasCertifications = false, isTalkative = false, 
                       golfClub = false, hasMasterDegree = true, 
                       job = "Junior Programmer") ::
                Person(hasCertifications = true, isTalkative = false, 
                       golfClub = false, hasMasterDegree = false, 
                       job = "Programmer") ::
                Person(hasCertifications = false, isTalkative = true, 
                       golfClub = false, hasMasterDegree = true, 
                       job = "Seller") ::
                Person(hasCertifications = false, isTalkative = true, 
                       golfClub = false, hasMasterDegree = false, 
                       job = "Seller") ::
                Person(hasCertifications = true, isTalkative = true, 
                       golfClub = false, hasMasterDegree = false, 
                       job = "Seller") ::
                Person(hasCertifications = false, isTalkative = true, 
                       golfClub = true, hasMasterDegree = true, 
                       job = "CEO") ::
                Person(hasCertifications = false, isTalkative = false, 
                       golfClub = true, hasMasterDegree = false, 
                       job = "CEO") ::
                Person(hasCertifications = false, isTalkative = false, 
                       golfClub = true, hasMasterDegree = false, 
                       job = "CEO") ::
                Nil
The list persons is my training set. To use the ID3 algorithm, just import the decision tree classes and use the list for training:
import org.scalarecog.decisiontree._

def toVector(p : Person) = Vector(p.hasCertifications, p.isTalkative, p.golfClub, p.hasMasterDegree)
val dataset = persons map (p => (toVector(p), p.job))

val tree = new ID3[Boolean,String] buildTree dataset
toVector is needed because the ID3 class expects the features of each sample as a Vector.
Now tree can classify a person:

val newPerson = Person(false, false, false, false, "?")
assert(
   tree.classify(toVector(newPerson)) == "Junior Programmer"
)
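
ID3 chooses, at each node, the attribute whose split gives the highest information gain (the largest drop in entropy of the labels). The following is a minimal sketch of that single step on the training set above. It is written in Java for illustration and is not scala-recog's code: JobTrainingSet, InformationGain and their methods are hypothetical names introduced here.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// The training set from the post, one row per person (hypothetical holder class).
// Columns: hasCertifications, isTalkative, golfClub, hasMasterDegree.
class JobTrainingSet {
  final boolean[][] rows = {
    {true,  false, false, true},   // Programmer
    {false, false, false, true},   // Junior Programmer
    {true,  false, false, false},  // Programmer
    {false, true,  false, true},   // Seller
    {false, true,  false, false},  // Seller
    {true,  true,  false, false},  // Seller
    {false, true,  true,  true},   // CEO
    {false, false, true,  false},  // CEO
    {false, false, true,  false}   // CEO
  };
  final String[] labels = {
    "Programmer", "Junior Programmer", "Programmer",
    "Seller", "Seller", "Seller",
    "CEO", "CEO", "CEO"
  };
}

// Hypothetical helper, not part of scala-recog.
class InformationGain {

  // Shannon entropy, in bits, of a list of class labels.
  double entropy(List<String> labels) {
    Map<String, Integer> counts = new HashMap<>();
    for (String label : labels) {
      counts.merge(label, 1, Integer::sum);
    }
    double h = 0.0;
    for (int n : counts.values()) {
      double p = (double) n / labels.size();
      h -= p * Math.log(p) / Math.log(2);
    }
    return h;
  }

  // Entropy of the whole set minus the weighted entropy of the two
  // subsets obtained by splitting on the boolean attribute at index attr.
  double gain(boolean[][] rows, String[] labels, int attr) {
    List<String> all = new ArrayList<>();
    List<String> yes = new ArrayList<>();
    List<String> no = new ArrayList<>();
    for (int i = 0; i < rows.length; i++) {
      all.add(labels[i]);
      (rows[i][attr] ? yes : no).add(labels[i]);
    }
    return entropy(all) - weighted(yes, all.size()) - weighted(no, all.size());
  }

  // Entropy of one subset, scaled by the fraction of samples it holds.
  private double weighted(List<String> part, int total) {
    return part.isEmpty() ? 0.0 : (double) part.size() / total * entropy(part);
  }

  // Index of the attribute with the highest gain: the one ID3 would test first.
  int bestAttribute(boolean[][] rows, String[] labels) {
    int best = 0;
    for (int attr = 1; attr < rows[0].length; attr++) {
      if (gain(rows, labels, attr) > gain(rows, labels, best)) {
        best = attr;
      }
    }
    return best;
  }
}

class InformationGainDemo {
  public static void main(String[] args) {
    JobTrainingSet data = new JobTrainingSet();
    InformationGain ig = new InformationGain();
    String[] names = {"hasCertifications", "isTalkative", "golfClub", "hasMasterDegree"};
    for (int attr = 0; attr < names.length; attr++) {
      System.out.printf("%-18s gain = %.3f%n", names[attr], ig.gain(data.rows, data.labels, attr));
    }
    System.out.println("first split: " + names[ig.bestAttribute(data.rows, data.labels)]);
  }
}
```

On this training set the sketch picks golfClub for the first split, because everybody in the golf club is a CEO and nobody outside it is. A full ID3 would then repeat the same step on each subset until the labels in a subset are all equal.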

But it would be nice to see the decision tree created by ID3. With JGraph it's straightforward, and I get this:
[Image: the decision tree drawn with JGraph]
Hey, it's real life! ^_^
Here is the full code:

package scalarecoggraph

import org.scalarecog.decisiontree._
import javax.swing.JFrame
import com.mxgraph.swing.mxGraphComponent
import com.mxgraph.view.mxGraph

class Program(tree : DecisionTree[Vector[Boolean], String], propertyNames : Vector[String]) extends JFrame("ScalaRecog") {
  type Tree = DecisionTree[Vector[Boolean], String]
  type Vertex = (AnyRef, (Double, Double))

  draw()

  def draw(): Unit = {
    val graph: mxGraph = new mxGraph
    val root = graph.getDefaultParent

    def draw(t : Tree, parentPos : (Double, Double), offset : (Int, Int)) : Vertex = {
      def createVertex(label : String, action : Vertex => Unit = v => {}) : Vertex = {
        val vertexSize = (100, 30)
        val newPos = (parentPos._1 + offset._1, parentPos._2 + offset._2)
        val created = (graph.insertVertex(root, null, label, newPos._1, newPos._2 , vertexSize._1, vertexSize._2), newPos)
        action(created)
        created
      }
      def createEdge(label : String, from : Vertex, to : Vertex) = graph.insertEdge(root, null, label, from._1, to._1)

      t match {
        case a : DecisionLeaf[Vector[Boolean],String] => createVertex(a.label)
        case a : DecisionBranchVector[String,Boolean] =>
          createVertex(propertyNames(a.index), n => {
            for (  ((label, child), index) <- a.branches.zipWithIndex  )
              createEdge(label.toString, n, draw(child, n._2, (120*index, offset._2)))
            })
      }
    }

    graph.getModel.beginUpdate
    try {
      draw(tree, (0, 0), (120, 120))
    }
    finally {
      graph.getModel.endUpdate
    }
    getContentPane.add(new mxGraphComponent(graph))
  }
}

object Program  {

  // Case class parameters are already public vals, so no explicit `val` is needed.
  case class Person(
                hasCertifications : Boolean,
                isTalkative : Boolean,
                golfClub : Boolean,
                hasMasterDegree : Boolean,
                job : String
               )

  def main(args : Array[String]) : Unit = {

    val persons = Person(hasCertifications = true, isTalkative = false, golfClub = false, hasMasterDegree = true, job = "Programmer") ::
                  Person(hasCertifications = false, isTalkative = false, golfClub = false, hasMasterDegree = true, job = "Junior Programmer") ::
                  Person(hasCertifications = true, isTalkative = false, golfClub = false, hasMasterDegree = false, job = "Programmer") ::
                  Person(hasCertifications = false, isTalkative = true, golfClub = false, hasMasterDegree = true, job = "Seller") ::
                  Person(hasCertifications = false, isTalkative = true, golfClub = false, hasMasterDegree = false, job = "Seller") ::
                  Person(hasCertifications = true, isTalkative = true, golfClub = false, hasMasterDegree = false, job = "Seller") ::
                  Person(hasCertifications = false, isTalkative = true, golfClub = true, hasMasterDegree = true, job = "CEO") ::
                  Person(hasCertifications = false, isTalkative = false, golfClub = true, hasMasterDegree = false, job = "CEO") ::
                  Person(hasCertifications = false, isTalkative = false, golfClub = true, hasMasterDegree = false, job = "CEO") ::
                  Nil

    def toVector(p : Person) = Vector(p.hasCertifications, p.isTalkative, p.golfClub, p.hasMasterDegree)
    val dataset = persons map (p => (toVector(p), p.job))

    val tree = new ID3[Boolean,String] buildTree dataset

    val newPerson = Person(false, false, false, false, "?")
    assert(
      tree.classify(toVector(newPerson)) == "Junior Programmer"
    )

    val frame = new Program(tree, Vector("Has certifications?", "Is talkative?", "Likes playing golf?", "Has a master degree?"))
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE)
    frame.setSize(400, 320)
    frame.setVisible(true)
  }
}

Enjoy!

Friday, July 15, 2011

static should be deprecated

In a tweet I jokingly proposed abolishing the static keyword in all (garbage-collected) languages. At first it was just a joke, but then I wondered:

Can I get rid of static?
First of all: why get rid of static? Because it's a source of trouble.

Factory

We have a class and a seemingly harmless static method to build it:

public class DumbClass {

  private String foo;
  public String getFoo() { return foo; }
  public void setFoo(String foo) { this.foo = foo; }

  private String initializedStuff;
  public String getInitializedStuff() { return initializedStuff; }

  private DumbClass(String foo, String initializedStuff) {
    this.initializedStuff = initializedStuff;
    this.foo = foo;
  }

  public static DumbClass create(Context c) {
    return new DumbClass( c.doSomethingToGetFoo(), c.doSomethingToGetStuff() );
  }
}

Requirements always change, as Alan Shalloway teaches, so one day a customer wants foo in CamelCase rather than UPPERCASE. The case must not change for any other customer, so I would say:
Easy: I'll create a new library for that customer and put in it a class extending... Oops, I can't! It's a static method. What if I had an abstract factory instead?

public class DumbClass {

  private String foo;
  public String getFoo() { return foo; }
  public void setFoo(String foo) { this.foo = foo; }

  private String initializedStuff;
  public String getInitializedStuff() { return initializedStuff; }

  DumbClass(String foo, String initializedStuff) {
    this.initializedStuff = initializedStuff;
    this.foo = foo;
  }
}

public interface IDumbClassFactory {
  DumbClass create(Context c);
}

public class DefaultDumbClassFactory implements IDumbClassFactory {
  public DumbClass create(Context c) {
    return new DumbClass( c.doSomethingToGetFoo(), c.doSomethingToGetStuff() );
  }
}

Then I could write a wrapper around DefaultDumbClassFactory, or another implementation of IDumbClassFactory: it would work!
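
Here is a minimal sketch of the wrapper option for the customer who wants CamelCase. Context is reduced to a hypothetical two-method interface, CamelCaseDumbClassFactory and toCamelCase are names invented for the example, and the conversion assumes that the words in foo are separated by underscores or spaces. The types are package-private only so that the sketch fits in a single file.

```java
import java.util.Locale;

// Hypothetical stand-in for the Context type used in the post.
interface Context {
  String doSomethingToGetFoo();
  String doSomethingToGetStuff();
}

class DumbClass {
  private String foo;
  private final String initializedStuff;

  DumbClass(String foo, String initializedStuff) {
    this.foo = foo;
    this.initializedStuff = initializedStuff;
  }

  public String getFoo() { return foo; }
  public void setFoo(String foo) { this.foo = foo; }
  public String getInitializedStuff() { return initializedStuff; }
}

interface IDumbClassFactory {
  DumbClass create(Context c);
}

// Behavior for all the other customers: foo is left untouched.
class DefaultDumbClassFactory implements IDumbClassFactory {
  @Override
  public DumbClass create(Context c) {
    return new DumbClass(c.doSomethingToGetFoo(), c.doSomethingToGetStuff());
  }
}

// Wrapper for the one customer who wants CamelCase (hypothetical name).
class CamelCaseDumbClassFactory implements IDumbClassFactory {
  private final IDumbClassFactory inner;

  CamelCaseDumbClassFactory(IDumbClassFactory inner) {
    this.inner = inner;
  }

  @Override
  public DumbClass create(Context c) {
    // Delegate the construction, then adjust only the case of foo.
    DumbClass created = inner.create(c);
    created.setFoo(toCamelCase(created.getFoo()));
    return created;
  }

  // Assumption: words are separated by underscores or whitespace.
  private String toCamelCase(String text) {
    StringBuilder result = new StringBuilder();
    for (String word : text.split("[_\\s]+")) {
      if (word.isEmpty()) continue;
      result.append(word.substring(0, 1).toUpperCase(Locale.ROOT));
      result.append(word.substring(1).toLowerCase(Locale.ROOT));
    }
    return result.toString();
  }
}
```

The special customer's library would be handed new CamelCaseDumbClassFactory(new DefaultDumbClassFactory()), while everybody else keeps using the default factory: neither DumbClass nor the existing callers need to change.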
The moral of the story: don't use static methods, especially public ones!

Private fields

Sometimes you want to keep a counter of the instances of a certain class, to know how many of them you have in memory. OK, I would do it with a static field:

public class Foo {

  private static int howMany;

  public Foo(Something s) {
    howMany++;
  }
  ...
}

A year later you find out that nobody reads that counter anymore, and that once your library is deployed on a web server running continuously for two years, the counter could go negative (integer overflow) or cause an exception (unhandled, of course). Or you will need it not to count derived classes, or you will need the counter to be saved in a database. And then? You find out that you need an abstract factory.
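
As a rough sketch of where that leads, the counter can live in a factory instead of a static field. FooFactory, DefaultFooFactory, CountingFooFactory and the empty Something class are hypothetical names for this example, and the counter tracks instances created through one factory, not instances still in memory.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical stand-in for the Something type used in the post.
class Something { }

class Foo {
  Foo(Something s) {
    // No static counter here: Foo knows nothing about counting.
  }
}

interface FooFactory {
  Foo create(Something s);
}

// Plain factory for callers who do not care about the counter.
class DefaultFooFactory implements FooFactory {
  @Override
  public Foo create(Something s) {
    return new Foo(s);
  }
}

// Counts the instances created through this factory object only.
class CountingFooFactory implements FooFactory {
  private final FooFactory inner;
  // A long makes overflow a much more distant problem than an int.
  private final AtomicLong created = new AtomicLong();

  CountingFooFactory(FooFactory inner) {
    this.inner = inner;
  }

  @Override
  public Foo create(Something s) {
    Foo foo = inner.create(s);
    created.incrementAndGet();
    return foo;
  }

  long howMany() {
    return created.get();
  }
}
```

If the counter later has to be saved in a database, or has to ignore derived classes, that becomes one more FooFactory implementation rather than a change to Foo.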
Now, whenever I'm about to write static, I stop, think it over, and nearly always end up writing: interface
