Einsteinってどんなもの?Heroku+PredictionIOを使って機械学習をわかった気になってみよう!(第2回)

  • Posted on
  • カテゴリ:

はじめに

みなさん、こんにちは。

第1回では、HerokuにPredictionIOの実行環境を構築しました。第2回では、Salesforceのデータを一括取得してインポートするJavaアプリケーションの開発と、インポートした学習データを利用してEngineのトレーニングを行うところまでを試してみたいと思います。

必要なもの

  • Salesforce Developer Editionなど
  • Java開発環境(JDK、Eclipseなど)

Salesforceのカスタマイズ

顧客、ゲーム、お気に入りの3つのカスタムオブジェクトを作成します。(カスタムオブジェクトをMetaデータから作成されたい方はこちらから)

顧客オブジェクト:Customer__c

No. ラベル API参照 データ型 文字数 参照先 必須 数式 補足
1 顧客名 Name string 80        
2 顧客No. CustomerNumber__c string 30     自動採番"C-{000000}"

ゲームオブジェクト:Game__c

No. ラベル API参照 データ型 文字数 参照先 必須 数式 補足
1 ジャンル Genre__c string 255        
2 メーカー Maker__c string 255        
3 画像URL ImageURL__c string 255        
4 イメージ Image__c string 1300     IMAGE( ImageURL__c , "ゲーム画像" , 400, 400)  
5 ゲームNo. GameNumber__c string 30     自動採番"G-{000000}"

お気に入りオブジェクト:Favorite__c

No. ラベル API参照 データ型 文字数 参照先 必須 数式 補足
1 お気に入りNo. Name string 80     自動採番"F-{000000}"
2 顧客 Customer__c reference 18 Customer__c      
3 ゲーム Game__c reference 18 Game__c      
4 ゲームNo. GameNumber__c string 1300     Game__r.GameNumber__c  
5 顧客No. CustomerNumber__c string 1300     Customer__r.CustomerNumber__c  
6 イメージ Image__c string 1300     Game__r.Image__c  
7 ジャンル Genre__c string 1300     Game__r.Genre__c  

作成したオブジェクトにデータを入れておきます。(サンプルデータをDataLoader等でインポートされたい方はこちらから)

Javaアプリケーション

JavaアプリケーションはSalesforceのSOAP(Partner API)を使ってSalesforceからデータを取得する機能、取得したデータをPredictionIO Java SDKを使ってEvent Serverに対してHTTPリクエストする機能(Event Serverはリクエストを受け付けるとPostgreSQLに登録します)、これらをまとめて実行するメイン機能の3つで構成します。

作成するクラス

PIODataImportCmdLineSmple.java mainメソッドを持つアプリケーション
SfdcService.java SalesforceのPartner APIを操作しデータを取得する
PIODataImportService.java PredictionIOに対してEventデータをリクエストする

Salesforceからデータを取得

Salesforceが提供するSOAP APIへのアクセスは、「Force.com Web Services Connector(WSC)」を使うことにより簡単に行えます。Javaのサンプルコードを含む手順は、Developerforceのサイトが参考になります。

1.SalesfoceからWSDLのダウンロード

Salesforceの 設定 > 開発 > API のページから、Partner WSDLをダウンロードし「partner.wsdl」という名前で保存します。

2.WSCのjarファイルをダウンロード

最新はMaven Repositoryから取得できます。(執筆時点はver 38.0.4)

3.jarファイル(SOAP APIのスタブ)をPartner WSDLとWSCを使って自動生成

force-wsc-38.0.4.jar、partner.wsdlを同一フォルダに置き、そのフォルダに移動して以下のコマンドを実行します。

$ java -classpath force-wsc-38.0.4.jar com.sforce.ws.tools.wsdlc partner.wsdl partner.jar

本家サイトでは上記コマンドでしたが、私の環境では依存ライブラリが不足しているエラーが出ましたので、String Template 4のjarファイルをダウンロードし、classpathに追加して実行しました。(執筆時点の最新ver4.0.8)

$ java -classpath ST-4.0.8.jar:force-wsc-38.0.4.jar com.sforce.ws.tools.wsdlc partner.wsdl partner.jar

ダウンロードしたWSCのjarファイルおよび出来上がったpartner.jarをBuild Pathに設定し、Salesforceからデータを取得するコードを記述します。

ソースコード:SfdcService.java

/**
* SalesforceのデータをPartner APIを利用して取得するサービスクラスです
*/
public class SfdcService {
private static final String USER_NAME = "<YOUR_LOGIN_ID>";
private static final String USER_PASSWORD = "<YOUR_PASSWORD>";
private static final String AUTH_END_POINT = "https://login.salesforce.com/services/Soap/u/38.0";
private static final Integer QUERY_BATCH_SIZE = 2000;
private PartnerConnection connection = null;
/**
* コンストラクタ
* @param
* @return SfdcService
*/
private SfdcService(){
login();
}
/**
* SfdcServiceを使用します
* @param
* @return SfdcService
*/
public static SfdcService use(){
return new SfdcService();
}
/**
* Salesforceにログインします
* @param
* @return
*/
private void login(){
ConnectorConfig config = new ConnectorConfig();
config.setUsername(USER_NAME);
config.setPassword(USER_PASSWORD);
config.setAuthEndpoint(AUTH_END_POINT);
try {
connection = new PartnerConnection(config);
} catch (ConnectionException e) {
e.printStackTrace();
}
}
/**
* 顧客レコードを取得します
* @param
* @return
*/
public List<customer> getCustomers(){
List<customer> customers = new ArrayList<customer>();
try{
connection.setQueryOptions(QUERY_BATCH_SIZE);
String soql = "Select Id, Name, CustomerNumber__c From Customer__c Order by CustomerNumber__c";
QueryResult qr = connection.query(soql);
boolean done = false;
while(!done){
SObject[] records = qr.getRecords();
for(int i = 0; i < records.length; i++){
SObject sObj = records[i];
Customer customer = new Customer();
customer.setName(String.valueOf(sObj.getField("Name")));
customer.setCustomerNumber(String.valueOf(sObj.getField("CustomerNumber__c")));
customers.add(customer);
}
if(qr.isDone()){
done = true;
}else{
qr = connection.queryMore(qr.getQueryLocator());
}
}
}catch(ConnectionException ce){
ce.printStackTrace();
}
return customers;
}
/**
* ゲームレコードを取得します
* @param
* @return
*/
public List<game> getGames(){
List<game> games = new ArrayList<game>();
try{
String soql = "Select Id, Name, GameNumber__c, Genre__c, Maker__c, Image__c From Game__c Order by GameNumber__c";
QueryResult qr = connection.query(soql);
boolean done = false;
while(!done){
SObject[] records = qr.getRecords();
for(int i = 0; i  < records.length; i++){
SObject sObj = records[i];
Game game = new Game();
game.setName(String.valueOf(sObj.getField("Name")));
game.setGameNumber(String.valueOf(sObj.getField("GameNumber__c")));
game.setGenre(String.valueOf(sObj.getField("Genre__c")));
games.add(game);
}
if(qr.isDone()){
done = true;
}else{
qr = connection.queryMore(qr.getQueryLocator());
}
}
}catch(ConnectionException ce){
ce.printStackTrace();
}
return games;
}
/**
* お気に入りレコードを取得します
* @param
* @return
*/
public List<favorite> getFavorites(){
List<favorite> favorites = new ArrayList<favorite>();
try{
String soql = "Select Id, Name, Customer__c, CustomerNumber__c, Game__c, GameNumber__c From Favorite__c Order by Name";
QueryResult qr = connection.query(soql);
boolean done = false;
while(!done){
SObject[] records = qr.getRecords();
for(int i = 0; i < records.length; i++){
SObject sObj = records[i];
Favorite favorite = new Favorite();
favorite.setFavoriteNumber(String.valueOf(sObj.getField("Name")));
favorite.setCustomerNumber(String.valueOf(sObj.getField("CustomerNumber__c")));
favorite.setGameNumber(String.valueOf(sObj.getField("GameNumber__c")));
favorites.add(favorite);
}
if(qr.isDone()){
done = true;
}else{
qr = connection.queryMore(qr.getQueryLocator());
}
}
}catch(ConnectionException ce){
ce.printStackTrace();
}
return favorites;
}
/**
* 顧客インナークラス
*/
public class Customer{
private String name;
private String customerNumber;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getCustomerNumber() {
return customerNumber;
}
public void setCustomerNumber(String customerNumber) {
this.customerNumber = customerNumber;
}
}
/**
* ゲームインナークラス
*/
public class Game{
private String name;
private String genre;
private String gameNumber;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getGenre() {
return genre;
}
public void setGenre(String genre) {
this.genre = genre;
}
public String getGameNumber() {
return gameNumber;
}
public void setGameNumber(String gameNumber) {
this.gameNumber = gameNumber;
}
}
/**
* お気に入りインナークラス
*/
public class Favorite{
private String favoriteNumber;
private String customerNumber;
private String gameNumber;
public String getFavoriteNumber() {
return favoriteNumber;
}
public void setFavoriteNumber(String favoriteNumber) {
this.favoriteNumber = favoriteNumber;
}
public String getCustomerNumber() {
return customerNumber;
}
public void setCustomerNumber(String customerNumber) {
this.customerNumber = customerNumber;
}
public String getGameNumber() {
return gameNumber;
}
public void setGameNumber(String gameNumber) {
this.gameNumber = gameNumber;
} 
}
}

Event Serverにデータをリクエスト

Event Serverにデータをリクストする機能の作成にあたり、PredictionIO Java SDKおよび依存するライブラリをダウンロードし、Build pathに登録しておきます。

執筆時点で使用したjarファイル

ソースコード:PIODataImportService.java

/**
* PredictionIOのEventServerにEventを送信するサービスクラスです
*/
public class PIODataImportService {
private static final String APP_URL = "https://<YOUR_EVENT_SERVER_NAME>.herokuapp.com";
private static final String ACCESS_KEY = "<ACCESS_KEY>";
/**
* 顧客レコードを追加します
* @param
* @return
*/
public void addCustomers(){
EventClient client = new EventClient(ACCESS_KEY, APP_URL);
List<FutureAPIResponse> futureAPIResponses = new ArrayList<>();
try {
for(SfdcService.Customer customer : SfdcService.use().getCustomers()){
Map<String, Object> emptyUserProperties = new HashMap<String, Object>();
FutureAPIResponse future = client.setUserAsFuture(customer.getCustomerNumber(), emptyUserProperties);
futureAPIResponses.add(future);
Futures.addCallback(future.getAPIResponse(), getFutureCallback("user " + customer.getCustomerNumber()));
}
} catch (IOException e) {
e.printStackTrace();
}finally{
client.close();
}
}
/**
* ゲームレコードを追加します
* @param
* @return
*/
public void addGames(){
EventClient client = new EventClient(ACCESS_KEY, APP_URL);
List<FutureAPIResponse> futureAPIResponses = new ArrayList<>();
try {
for(SfdcService.Game game : SfdcService.use().getGames()){
Map<String, Object> itemProperties = new HashMap<String, Object>();
List<String> genres = new ArrayList<String>();
genres.add(game.getGenre());
itemProperties.put("genre", genres);
FutureAPIResponse future = client.setItemAsFuture(game.getGameNumber(), itemProperties);
futureAPIResponses.add(future);
Futures.addCallback(future.getAPIResponse(), getFutureCallback("item " + game.getGameNumber()));
}
} catch (IOException e) {
e.printStackTrace();
}finally{
client.close();
}
}
/**
* お気に入りレコードを追加します
* @param
* @return
*/
public void addFavorites(){
EventClient client = new EventClient(ACCESS_KEY, APP_URL);
List<FutureAPIResponse> futureAPIResponses = new ArrayList<>();
try {
for(SfdcService.Favorite favorite : SfdcService.use().getFavorites()){
Map<String, Object> emptyActionItemProperties = new HashMap<String, Object>();
FutureAPIResponse future = client.userActionItemAsFuture("view", favorite.getCustomerNumber(), favorite.getGameNumber(), emptyActionItemProperties);
futureAPIResponses.add(future);
Futures.addCallback(future.getAPIResponse(), getFutureCallback("actionItem customer " + favorite.getCustomerNumber() + " game " + favorite.getGameNumber()));
}
} catch (IOException e) {
e.printStackTrace();
}finally{
client.close();
}
}
/**
* EventServerとのコールバック処理です
* @param name
* @return FutureCallback<APIResponse>
*/
private FutureCallback<APIResponse> getFutureCallback(final String name) {
return new FutureCallback<APIResponse>(){
@Override
public void onSuccess(APIResponse response) {
System.out.println(name + " added: " + response.getMessage());
}
@Override
public void onFailure(Throwable thrown) {
System.out.println("failed to add " + name + ": " + thrown.getMessage());
}
};
}
}

PIODataImportServiceでは addCustomers()、addGames()、addFavorites()の各メソッドでSalesforceの該当オブジェクトからレコードを取得し、EventClientを使用してリクエストするJsonを組み立て、非同期でEvent Serverに対してリクエストを送信しています。Event Serverからのcall backはgetFutureCallback()メソッドで処理しています。

Event Serverに対するリクエストは以下のURLに対してJsonをPOSTしています。

http://<YOUR_EVENT_SERVER_NAME>.herokuapp.com/events.json?accessKey=<ACCESS_KEY> 

Event Serverには以下のようなJsonがリクエストされます。

顧客を登録するEventデータ
{ 
"event" : "$set", 
"entityType" : "user", 
"entityId" : "<顧客No.>" 
}
ゲームを登録するEventデータ
{ 
"event" : "$set", 
"entityType" : "item", 
"entityId" : "<ゲームNo.>", 
"properties" : { "genre" : ["<ジャンル>"] } 
}
お気に入りを登録するEventデータ
{ 
"event" : "view", 
"entityType" : "user", 
"entityId" : "<顧客No.>",  
"targetEntityType" : "item", 
"targetEntityId" : "<ゲームNo.>" 
}

メイン機能

Salesforceからのデータ一括取得&Event Serverへのリクエストの一連を動かすためのクラスを作成します。

ソースコード:PIOCmdLineSampleApp.java

/**
* PredictionIOにSalesforceのデータをインンポートするコマンドラインアプリ
*/
public class PIOCmdLineSampleApp {
/**
* mainメソッドです
* @param args
* @return
*/
public static void main(String[] args) {
PIODataImportService service = new PIODataImportService();
service.addCustomers();
System.out.println("--- User import done ---");
service.addGames();
System.out.println("--- Item import done ---");
service.addFavorites();
System.out.println("--- UserActionItem import done ---");
}
}

このクラスは単純にPIODataImportServeのインスタンスを生成し、addCustomers()、addGames()、addFavorites()メソッドを呼び出すだけです。

Eclipse等からアプリを実行し、Event Serverへのリクエストに成功すると以下のようなログが確認できます。

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
user C-000001 added: {"eventId":"1ee69b1df5274e7a9a94728715ceb607"}
user C-000002 added: {"eventId":"5a60733cc9ee4bf5a61e15b001e74ac7"}
user C-000003 added: {"eventId":"f83812910c9c4ab5a322c2e8dd7b3523"}
user C-000004 added: {"eventId":"c1fb45618e094e13bfeaf9f0b4549680"}
--- User import done ---
item G-000001 added: {"eventId":"876646148a5d494795cc765d4e2bc284"}
item G-000002 added: {"eventId":"c8dd1510edd74fd8aecf3479a9abe488"}
item G-000003 added: {"eventId":"dc71d6fca0a94cc4b8ad8840ec861577"}
item G-000004 added: {"eventId":"78cb97774dae491b9aa9537a493d374e"}
item G-000005 added: {"eventId":"1cbd5f6a106143668ccb700b696e595b"}
item G-000006 added: {"eventId":"9d6e855a82d2411b871d8307c13e1ec2"}
--- Item import done ---
actionItem customer C-000001 game G-000001 added: {"eventId":"38f483d623514df7be33a79e350d832c"}
actionItem customer C-000001 game G-000002 added: {"eventId":"8fc2c51ba45f4111b9a0dcf697a212ce"}
actionItem customer C-000001 game G-000003 added: {"eventId":"14715aed4a1a46d48c8b35806e0bf991"}
actionItem customer C-000002 game G-000004 added: {"eventId":"9e7f5dc56c11469b8628f50122629e72"}
actionItem customer C-000002 game G-000005 added: {"eventId":"c86ac59a935e4be4a8b3ad58d26f925b"}
--- UserActionItem import done ---

ブラウザからEventデータの登録結果を確認

ブラウザから以下のURLを入力して確認します。

http://<YOUR_EVENT_SERVER_NAME>.herokuapp.com/events.json?accessKey=<ACCESS_KEY>&limit=10

Eventデータが正しく登録されていると以下のような結果が確認できます。

Engineのトレーニング

Eventデータが無事に登録されましたので、いよいよEngineのトレーニングです。

Herokuのフリープランの範囲でトレーニングを実行したいので、dynoのスケールダウンを設定します。

$ heroku ps:scale web=0 train=0

Engineのトレーニングを実行します。

$ heroku run train

トレーニングを行うと大量のログが出力されますが、トレーニングに成功すると以下のようなログが確認できます。

[INFO] [Engine$] ALSModel does not support data sanity check. Skipping check.
[INFO] [Engine$] EngineWorkflow.train completed
[INFO] [Engine] engineInstanceId=95acbab2-863a-47b3-b59d-e610c0f5ff8e
[INFO] [CoreWorkflow$] Inserting persistent model
[INFO] [CoreWorkflow$] Updating engine instance
[INFO] [CoreWorkflow$] Training completed successfully.

dynoの設定を戻します。

$ heroku ps:scale web=1 train=0

ブラウザからEngineの状態を確認します。

http://<YOUR_ENGINE_NAME>.herokuapp.com

トレーニングの開始・終了日時やData SourceとしてEvent Serverに作成したアプリケーションと関連付いていることなどが確認できます。

第3回に続く

EngineにSalesforceのデータをインポートし、トレーニングを行いました。ここまでの設定により、リコメンドができる状態になっています。第3回では最終的な仕上げとして、Salesforceの画面からゲームのリコメンドを受け取り、お勧めゲームの表示を行うところを作成します。

すべてのソースコード・サンプルデータのリポジトリ