Einsteinってどんなもの?Heroku+PredictionIOを使って機械学習をわかった気になってみよう!(第2回)

はじめに

みなさん、こんにちは。

第1回では、HerokuにPredictionIOの実行環境を構築しました。第2回では、Salesforceのデータを一括取得してインポートするJavaアプリケーションの開発と、インポートした学習データを利用してEngineのトレーニングを行うところまでを試してみたいと思います。

必要なもの

  • Salesforce Developer Editionなど
  • Java開発環境(JDK、Eclipseなど)

Salesforceのカスタマイズ

顧客、ゲーム、お気に入りの3つのカスタムオブジェクトを作成します。(カスタムオブジェクトをMetaデータから作成されたい方はこちらから)

顧客オブジェクト:Customer__c

No. ラベル API参照 データ型 文字数 参照先 必須 数式 補足
1 顧客名 Name string 80        
2 顧客No. CustomerNumber__c string 30     自動採番"C-{000000}"

ゲームオブジェクト:Game__c

No. ラベル API参照 データ型 文字数 参照先 必須 数式 補足
1 ジャンル Genre__c string 255        
2 メーカー Maker__c string 255        
3 画像URL ImageURL__c string 255        
4 イメージ Image__c string 1300     IMAGE( ImageURL__c , "ゲーム画像" , 400, 400)  
5 ゲームNo. GameNumber__c string 30     自動採番"G-{000000}"

お気に入りオブジェクト:Favorite__c

No. ラベル API参照 データ型 文字数 参照先 必須 数式 補足
1 お気に入りNo. Name string 80     自動採番"F-{000000}"
2 顧客 Customer__c reference 18 Customer__c      
3 ゲーム Game__c reference 18 Game__c      
4 ゲームNo. GameNumber__c string 1300     Game__r.GameNumber__c  
5 顧客No. CustomerNumber__c string 1300     Customer__r.CustomerNumber__c  
6 イメージ Image__c string 1300     Game__r.Image__c  
7 ジャンル Genre__c string 1300     Game__r.Genre__c  

作成したオブジェクトにデータを入れておきます。(サンプルデータをDataLoader等でインポートされたい方はこちらから)

Javaアプリケーション

JavaアプリケーションはSalesforceのSOAP(Partner API)を使ってSalesforceからデータを取得する機能、取得したデータをPredictionIO Java SDKを使ってEvent Serverに対してHTTPリクエストする機能(Event Serverはリクエストを受け付けるとPostgreSQLに登録します)、これらをまとめて実行するメイン機能の3つで構成します。

作成するクラス

PIODataImportCmdLineSmple.java mainメソッドを持つアプリケーション
SfdcService.java SalesforceのPartner APIを操作しデータを取得する
PIODataImportService.java PredictionIOに対してEventデータをリクエストする

Salesforceからデータを取得

Salesforceが提供するSOAP APIへのアクセスは、「Force.com Web Services Connector(WSC)」を使うことにより簡単に行えます。Javaのサンプルコードを含む手順は、Developerforceのサイトが参考になります。

1.SalesfoceからWSDLのダウンロード

Salesforceの 設定 > 開発 > API のページから、Partner WSDLをダウンロードし「partner.wsdl」という名前で保存します。

2.WSCのjarファイルをダウンロード

最新はMaven Repositoryから取得できます。(執筆時点はver 38.0.4)

3.jarファイル(SOAP APIのスタブ)をPartner WSDLとWSCを使って自動生成

force-wsc-38.0.4.jar、partner.wsdlを同一フォルダに置き、そのフォルダに移動して以下のコマンドを実行します。

$ java -classpath force-wsc-38.0.4.jar com.sforce.ws.tools.wsdlc partner.wsdl partner.jar

本家サイトでは上記コマンドでしたが、私の環境では依存ライブラリが不足しているエラーが出ましたので、String Template 4のjarファイルをダウンロードし、classpathに追加して実行しました。(執筆時点の最新ver4.0.8)

$ java -classpath ST-4.0.8.jar:force-wsc-38.0.4.jar com.sforce.ws.tools.wsdlc partner.wsdl partner.jar

ダウンロードしたWSCのjarファイルおよび出来上がったpartner.jarをBuild Pathに設定し、Salesforceからデータを取得するコードを記述します。

ソースコード:SfdcService.java

/**
 * SalesforceのデータをPartner APIを利用して取得するサービスクラスです
 */
public class SfdcService {
	private static final String USER_NAME = "<YOUR_LOGIN_ID>";
	private static final String USER_PASSWORD = "<YOUR_PASSWORD>";
	private static final String AUTH_END_POINT = "https://login.salesforce.com/services/Soap/u/38.0";
	private static final Integer QUERY_BATCH_SIZE = 2000;

	private PartnerConnection connection = null;

	/**
	 * コンストラクタ
	 * @param
	 * @return SfdcService
	 */
	private SfdcService(){
		login();
	}
	
	/**
	 * SfdcServiceを使用します
	 * @param
	 * @return SfdcService
	 */
	public static SfdcService use(){
		return new SfdcService();
	}
	
	/**
	 * Salesforceにログインします
	 * @param
	 * @return
	 */
	private void login(){
		ConnectorConfig config = new ConnectorConfig();
		config.setUsername(USER_NAME);
		config.setPassword(USER_PASSWORD);
		config.setAuthEndpoint(AUTH_END_POINT);
		try {
			connection = new PartnerConnection(config);
		} catch (ConnectionException e) {
			e.printStackTrace();
		}
	}
	
	/**
	 * 顧客レコードを取得します
	 * @param
	 * @return
	 */
	public List<customer> getCustomers(){
		List<customer> customers = new ArrayList<customer>();
		try{
			connection.setQueryOptions(QUERY_BATCH_SIZE);
			String soql = "Select Id, Name, CustomerNumber__c From Customer__c Order by CustomerNumber__c";
			QueryResult qr = connection.query(soql);
			boolean done = false;
			while(!done){
				SObject[] records = qr.getRecords();
				for(int i = 0; i < records.length; i++){
					SObject sObj = records[i];
					Customer customer = new Customer();
					customer.setName(String.valueOf(sObj.getField("Name")));
					customer.setCustomerNumber(String.valueOf(sObj.getField("CustomerNumber__c")));
					customers.add(customer);
				}
				if(qr.isDone()){
					done = true;
				}else{
					qr = connection.queryMore(qr.getQueryLocator());
				}
			}
		}catch(ConnectionException ce){
			ce.printStackTrace();
		}
		return customers;
	}
	
	/**
	 * ゲームレコードを取得します
	 * @param
	 * @return
	 */
	public List<game> getGames(){
		List<game> games = new ArrayList<game>();
		try{
			String soql = "Select Id, Name, GameNumber__c, Genre__c, Maker__c, Image__c From Game__c Order by GameNumber__c";
			QueryResult qr = connection.query(soql);
			boolean done = false;
			while(!done){
				SObject[] records = qr.getRecords();
				for(int i = 0; i  < records.length; i++){
					SObject sObj = records[i];
					Game game = new Game();
					game.setName(String.valueOf(sObj.getField("Name")));
					game.setGameNumber(String.valueOf(sObj.getField("GameNumber__c")));
					game.setGenre(String.valueOf(sObj.getField("Genre__c")));
					games.add(game);
				}
				if(qr.isDone()){
					done = true;
				}else{
					qr = connection.queryMore(qr.getQueryLocator());
				}
			}
		}catch(ConnectionException ce){
			ce.printStackTrace();
		}
		return games;
	}
	
	/**
	 * お気に入りレコードを取得します
	 * @param
	 * @return
	 */
	public List<favorite> getFavorites(){
		List<favorite> favorites = new ArrayList<favorite>();
		try{
			String soql = "Select Id, Name, Customer__c, CustomerNumber__c, Game__c, GameNumber__c From Favorite__c Order by Name";
			QueryResult qr = connection.query(soql);
			boolean done = false;
			while(!done){
				SObject[] records = qr.getRecords();
				for(int i = 0; i < records.length; i++){
					SObject sObj = records[i];
					Favorite favorite = new Favorite();
					favorite.setFavoriteNumber(String.valueOf(sObj.getField("Name")));
					favorite.setCustomerNumber(String.valueOf(sObj.getField("CustomerNumber__c")));
					favorite.setGameNumber(String.valueOf(sObj.getField("GameNumber__c")));
					favorites.add(favorite);
				}
				if(qr.isDone()){
					done = true;
				}else{
					qr = connection.queryMore(qr.getQueryLocator());
				}
			}
		}catch(ConnectionException ce){
			ce.printStackTrace();
		}
		return favorites;
	}

	/**
	 * 顧客インナークラス
	 */
	public class Customer{
		private String name;
		private String customerNumber;
		public String getName() {
			return name;
		}
		public void setName(String name) {
			this.name = name;
		}
		public String getCustomerNumber() {
			return customerNumber;
		}
		public void setCustomerNumber(String customerNumber) {
			this.customerNumber = customerNumber;
		}
	}
	
	/**
	 * ゲームインナークラス
	 */
	public class Game{
		private String name;
		private String genre;
		private String gameNumber;
		public String getName() {
			return name;
		}
		public void setName(String name) {
			this.name = name;
		}
		public String getGenre() {
			return genre;
		}
		public void setGenre(String genre) {
			this.genre = genre;
		}
		public String getGameNumber() {
			return gameNumber;
		}
		public void setGameNumber(String gameNumber) {
			this.gameNumber = gameNumber;
		}
	}

	/**
	 * お気に入りインナークラス
	 */
	public class Favorite{
		private String favoriteNumber;
		private String customerNumber;
		private String gameNumber;
		public String getFavoriteNumber() {
			return favoriteNumber;
		}
		public void setFavoriteNumber(String favoriteNumber) {
			this.favoriteNumber = favoriteNumber;
		}
		public String getCustomerNumber() {
			return customerNumber;
		}
		public void setCustomerNumber(String customerNumber) {
			this.customerNumber = customerNumber;
		}
		public String getGameNumber() {
			return gameNumber;
		}
		public void setGameNumber(String gameNumber) {
			this.gameNumber = gameNumber;
		} 
	}
}

Event Serverにデータをリクエスト

Event Serverにデータをリクストする機能の作成にあたり、PredictionIO Java SDKおよび依存するライブラリをダウンロードし、Build pathに登録しておきます。

執筆時点で使用したjarファイル

ソースコード:PIODataImportService.java

/**
 * PredictionIOのEventServerにEventを送信するサービスクラスです
 */
public class PIODataImportService {
	private static final String APP_URL = "https://<YOUR_EVENT_SERVER_NAME>.herokuapp.com";
	private static final String ACCESS_KEY = "<ACCESS_KEY>";

	/**
	 * 顧客レコードを追加します
	 * @param
	 * @return
	 */
	public void addCustomers(){
		EventClient client = new EventClient(ACCESS_KEY, APP_URL);
		List<FutureAPIResponse> futureAPIResponses = new ArrayList<>();
		
		try {
			for(SfdcService.Customer customer : SfdcService.use().getCustomers()){
				Map<String, Object> emptyUserProperties = new HashMap<String, Object>();
				FutureAPIResponse future = client.setUserAsFuture(customer.getCustomerNumber(), emptyUserProperties);
				futureAPIResponses.add(future);
				Futures.addCallback(future.getAPIResponse(), getFutureCallback("user " + customer.getCustomerNumber()));
			}
		} catch (IOException e) {
			e.printStackTrace();
		}finally{
			client.close();
		}
	}
	
	/**
	 * ゲームレコードを追加します
	 * @param
	 * @return
	 */
	public void addGames(){
		EventClient client = new EventClient(ACCESS_KEY, APP_URL);
		List<FutureAPIResponse> futureAPIResponses = new ArrayList<>();
		
		try {
			for(SfdcService.Game game : SfdcService.use().getGames()){
				Map<String, Object> itemProperties = new HashMap<String, Object>();
				List<String> genres = new ArrayList<String>();
				genres.add(game.getGenre());
				itemProperties.put("genre", genres);
				FutureAPIResponse future = client.setItemAsFuture(game.getGameNumber(), itemProperties);
				futureAPIResponses.add(future);
				Futures.addCallback(future.getAPIResponse(), getFutureCallback("item " + game.getGameNumber()));
			}
		} catch (IOException e) {
			e.printStackTrace();
		}finally{
			client.close();
		}
	}
	
	/**
	 * お気に入りレコードを追加します
	 * @param
	 * @return
	 */
	public void addFavorites(){
		EventClient client = new EventClient(ACCESS_KEY, APP_URL);
		List<FutureAPIResponse> futureAPIResponses = new ArrayList<>();
		try {
			for(SfdcService.Favorite favorite : SfdcService.use().getFavorites()){
				Map<String, Object> emptyActionItemProperties = new HashMap<String, Object>();
				FutureAPIResponse future = client.userActionItemAsFuture("view", favorite.getCustomerNumber(), favorite.getGameNumber(), emptyActionItemProperties);
				futureAPIResponses.add(future);
				Futures.addCallback(future.getAPIResponse(), getFutureCallback("actionItem customer " + favorite.getCustomerNumber() + " game " + favorite.getGameNumber()));
			}
		} catch (IOException e) {
			e.printStackTrace();
		}finally{
			client.close();
		}
	}

	/**
	 * EventServerとのコールバック処理です
	 * @param name
	 * @return FutureCallback<APIResponse>
	 */
	private FutureCallback<APIResponse> getFutureCallback(final String name) {
		return new FutureCallback<APIResponse>(){
			@Override
			public void onSuccess(APIResponse response) {
				System.out.println(name + " added: " + response.getMessage());
			}

			@Override
			public void onFailure(Throwable thrown) {
				System.out.println("failed to add " + name + ": " + thrown.getMessage());
			}
		};
	}
}

PIODataImportServiceでは addCustomers()、addGames()、addFavorites()の各メソッドでSalesforceの該当オブジェクトからレコードを取得し、EventClientを使用してリクエストするJsonを組み立て、非同期でEvent Serverに対してリクエストを送信しています。Event Serverからのcall backはgetFutureCallback()メソッドで処理しています。

Event Serverに対するリクエストは以下のURLに対してJsonをPOSTしています。

http://<YOUR_EVENT_SERVER_NAME>.herokuapp.com/events.json?accessKey=<ACCESS_KEY> 

Event Serverには以下のようなJsonがリクエストされます。

顧客を登録するEventデータ
{ 
  "event" : "$set", 
  "entityType" : "user", 
  "entityId" : "<顧客No.>" 
}

ゲームを登録するEventデータ
{ 
  "event" : "$set", 
  "entityType" : "item", 
  "entityId" : "<ゲームNo.>", 
  "properties" : { "genre" : ["<ジャンル>"] } 
}

お気に入りを登録するEventデータ
{ 
  "event" : "view", 
  "entityType" : "user", 
  "entityId" : "<顧客No.>",  
  "targetEntityType" : "item", 
  "targetEntityId" : "<ゲームNo.>" 
}

メイン機能

Salesforceからのデータ一括取得&Event Serverへのリクエストの一連を動かすためのクラスを作成します。

ソースコード:PIOCmdLineSampleApp.java

/**
 * PredictionIOにSalesforceのデータをインンポートするコマンドラインアプリ
 */
public class PIOCmdLineSampleApp {
	/**
	 * mainメソッドです
	 * @param args
	 * @return
	 */
	public static void main(String[] args) {
		PIODataImportService service = new PIODataImportService();

		service.addCustomers();
		System.out.println("--- User import done ---");
		
		service.addGames();
		System.out.println("--- Item import done ---");
		
		service.addFavorites();
		System.out.println("--- UserActionItem import done ---");
	}
}

このクラスは単純にPIODataImportServeのインスタンスを生成し、addCustomers()、addGames()、addFavorites()メソッドを呼び出すだけです。

Eclipse等からアプリを実行し、Event Serverへのリクエストに成功すると以下のようなログが確認できます。

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
user C-000001 added: {"eventId":"1ee69b1df5274e7a9a94728715ceb607"}
user C-000002 added: {"eventId":"5a60733cc9ee4bf5a61e15b001e74ac7"}
user C-000003 added: {"eventId":"f83812910c9c4ab5a322c2e8dd7b3523"}
user C-000004 added: {"eventId":"c1fb45618e094e13bfeaf9f0b4549680"}
--- User import done ---
item G-000001 added: {"eventId":"876646148a5d494795cc765d4e2bc284"}
item G-000002 added: {"eventId":"c8dd1510edd74fd8aecf3479a9abe488"}
item G-000003 added: {"eventId":"dc71d6fca0a94cc4b8ad8840ec861577"}
item G-000004 added: {"eventId":"78cb97774dae491b9aa9537a493d374e"}
item G-000005 added: {"eventId":"1cbd5f6a106143668ccb700b696e595b"}
item G-000006 added: {"eventId":"9d6e855a82d2411b871d8307c13e1ec2"}
--- Item import done ---
actionItem customer C-000001 game G-000001 added: {"eventId":"38f483d623514df7be33a79e350d832c"}
actionItem customer C-000001 game G-000002 added: {"eventId":"8fc2c51ba45f4111b9a0dcf697a212ce"}
actionItem customer C-000001 game G-000003 added: {"eventId":"14715aed4a1a46d48c8b35806e0bf991"}
actionItem customer C-000002 game G-000004 added: {"eventId":"9e7f5dc56c11469b8628f50122629e72"}
actionItem customer C-000002 game G-000005 added: {"eventId":"c86ac59a935e4be4a8b3ad58d26f925b"}
--- UserActionItem import done ---

ブラウザからEventデータの登録結果を確認

ブラウザから以下のURLを入力して確認します。

http://<YOUR_EVENT_SERVER_NAME>.herokuapp.com/events.json?accessKey=<ACCESS_KEY>&limit=10

Eventデータが正しく登録されていると以下のような結果が確認できます。

Engineのトレーニング

Eventデータが無事に登録されましたので、いよいよEngineのトレーニングです。

Herokuのフリープランの範囲でトレーニングを実行したいので、dynoのスケールダウンを設定します。

$ heroku ps:scale web=0 train=0

Engineのトレーニングを実行します。

$ heroku run train

トレーニングを行うと大量のログが出力されますが、トレーニングに成功すると以下のようなログが確認できます。

[INFO] [Engine$] ALSModel does not support data sanity check. Skipping check.
[INFO] [Engine$] EngineWorkflow.train completed
[INFO] [Engine] engineInstanceId=95acbab2-863a-47b3-b59d-e610c0f5ff8e
[INFO] [CoreWorkflow$] Inserting persistent model
[INFO] [CoreWorkflow$] Updating engine instance
[INFO] [CoreWorkflow$] Training completed successfully.

dynoの設定を戻します。

$ heroku ps:scale web=1 train=0

ブラウザからEngineの状態を確認します。

http://<YOUR_ENGINE_NAME>.herokuapp.com

トレーニングの開始・終了日時やData SourceとしてEvent Serverに作成したアプリケーションと関連付いていることなどが確認できます。

第3回に続く

EngineにSalesforceのデータをインポートし、トレーニングを行いました。ここまでの設定により、リコメンドができる状態になっています。第3回では最終的な仕上げとして、Salesforceの画面からゲームのリコメンドを受け取り、お勧めゲームの表示を行うところを作成します。

すべてのソースコード・サンプルデータのリポジトリ