// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"fmt"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"

	"google.golang.org/api/googleapi"
	prediction "google.golang.org/api/prediction/v1.6"
)

func init() {
	scopes := []string{
		prediction.DevstorageFull_controlScope,
		prediction.DevstorageRead_onlyScope,
		prediction.DevstorageRead_writeScope,
		prediction.PredictionScope,
	}
	registerDemo("prediction", strings.Join(scopes, " "), predictionMain)
}

type predictionType struct {
	api              *prediction.Service
	projectNumber    string
	bucketName       string
	trainingFileName string
	modelName        string
}

// This example demonstrates calling the Prediction API.
// Training data is uploaded to a pre-created Google Cloud Storage Bucket and
// then the Prediction API is called to train a model based on that data.
// After a few minutes, the model should be completely trained and ready
// for prediction. At that point, text is sent to the model and the Prediction
// API attempts to classify the data, and the results are printed out.
//
// To get started, follow the instructions found in the "Hello Prediction!"
// Getting Started Guide located here:
// https://developers.google.com/prediction/docs/hello_world
//
// Example usage:
//   go-api-demo -clientid="my-clientid" -secret="my-secret" prediction
//       my-project-number my-bucket-name my-training-filename my-model-name
//
// Example output:
//   Predict result: language=Spanish
//   English Score: 0.000000
//   French Score: 0.000000
//   Spanish Score: 1.000000
//   analyze: output feature text=&{157 English}
//   analyze: output feature text=&{149 French}
//   analyze: output feature text=&{100 Spanish}
//   feature text count=406
func predictionMain(client *http.Client, argv []string) {
	if len(argv) != 4 {
		fmt.Fprintln(os.Stderr,
			"Usage: prediction project_number bucket training_data model_name")
		return
	}

	api, err := prediction.New(client)
	if err != nil {
		log.Fatalf("unable to create prediction API client: %v", err)
	}

	t := &predictionType{
		api:              api,
		projectNumber:    argv[0],
		bucketName:       argv[1],
		trainingFileName: argv[2],
		modelName:        argv[3],
	}

	t.trainModel()
	t.predictModel()
}

func (t *predictionType) trainModel() {
	// First, check to see if our trained model already exists.
	res, err := t.api.Trainedmodels.Get(t.projectNumber, t.modelName).Do()
	if err != nil {
		if ae, ok := err.(*googleapi.Error); ok && ae.Code != http.StatusNotFound {
			log.Fatalf("error getting trained model: %v", err)
		}
		log.Printf("Training model not found, creating new model.")
		res, err = t.api.Trainedmodels.Insert(t.projectNumber, &prediction.Insert{
			Id:                  t.modelName,
			StorageDataLocation: filepath.Join(t.bucketName, t.trainingFileName),
		}).Do()
		if err != nil {
			log.Fatalf("unable to create trained model: %v", err)
		}
	}
	if res.TrainingStatus != "DONE" {
		// Wait for the trained model to finish training.
		fmt.Printf("Training model. Please wait and re-run program after a few minutes.")
		os.Exit(0)
	}
}

func (t *predictionType) predictModel() {
	// Model has now been trained.  Predict with it.
	input := &prediction.Input{
		&prediction.InputInput{
			[]interface{}{
				"Hola, con quien hablo",
			},
		},
	}
	res, err := t.api.Trainedmodels.Predict(t.projectNumber, t.modelName, input).Do()
	if err != nil {
		log.Fatalf("unable to get trained prediction: %v", err)
	}
	fmt.Printf("Predict result: language=%v\n", res.OutputLabel)
	for _, m := range res.OutputMulti {
		fmt.Printf("%v Score: %v\n", m.Label, m.Score)
	}

	// Now analyze the model.
	an, err := t.api.Trainedmodels.Analyze(t.projectNumber, t.modelName).Do()
	if err != nil {
		log.Fatalf("unable to analyze trained model: %v", err)
	}
	for _, f := range an.DataDescription.OutputFeature.Text {
		fmt.Printf("analyze: output feature text=%v\n", f)
	}
	for _, f := range an.DataDescription.Features {
		fmt.Printf("feature text count=%v\n", f.Text.Count)
	}
}