How do I get online predictions in C# for my model on Cloud Machine Learning Engine?

前端 未结 1 1083
梦毁少年i
梦毁少年i 2021-01-26 10:49

I have successfully deployed a model on Cloud ML Engine and verified that it is working with gcloud ml-engine models predict by following the instructions. Now I want to get online predictions for that model from C#.

1条回答
  •  深忆病人
    2021-01-26 11:25

    The online prediction API is a REST API, so you can use any library for sending HTTPS requests, although you will need to use Google's OAuth library to get your credentials.

    The format of the request is JSON, as described in the docs.

    To exemplify, consider the Census example. A client for that might look like:

    using System;
    using System.Collections.Generic;
    using System.Net.Http;
    using System.Net.Http.Headers;
    using System.Text;
    using System.Threading.Tasks;
    using Google.Apis.Auth.OAuth2;
    using Newtonsoft.Json;
    
    namespace prediction_client
    {
        /// <summary>
        /// One input instance for the Census example model.
        /// </summary>
        /// <remarks>
        /// The property names are emitted verbatim as JSON keys when the request is
        /// serialized with <see cref="JsonConvert"/>, so they intentionally keep their
        /// lower-case, underscore-separated spelling instead of PascalCase.
        /// </remarks>
        class Person
        {
            public int age { get; set; }
            public string workclass { get; set; }
            public string education { get; set; }
            public int education_num { get; set; }
            public string marital_status { get; set; }
            public string occupation { get; set; }
            public string relationship { get; set; }
            public string race { get; set; }
            public string gender { get; set; }
            public int capital_gain { get; set; }
            public int capital_loss { get; set; }
            public int hours_per_week { get; set; }
            public string native_country { get; set; }
        }

        /// <summary>
        /// One element of the "predictions" array returned by the service.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the list element type is assumed to be <c>double</c>
        /// (probabilities / logits are floating-point values) - confirm against the
        /// output signature of the deployed model.
        /// </remarks>
        class Prediction
        {
            public List<double> probabilities { get; set; }
            public List<double> logits { get; set; }
            public int classes { get; set; }
            public List<double> logistic { get; set; }

            /// <summary>Renders this prediction as a JSON string.</summary>
            public override string ToString()
            {
                return JsonConvert.SerializeObject(this);
            }
        }

        /// <summary>
        /// Console entry point: sends a single <see cref="Person"/> to the deployed
        /// model and prints each returned <see cref="Prediction"/> on its own line.
        /// </summary>
        class MainClass
        {
            static readonly PredictClient client = new PredictClient();
            static readonly string project = "MY_PROJECT";
            static readonly string model = "census";  // Whatever you deployed your model as

            public static void Main(string[] args)
            {
                // GetAwaiter().GetResult() rethrows the original exception instead of
                // wrapping it in an AggregateException the way Wait() does.
                RunAsync().GetAwaiter().GetResult();
            }

            /// <summary>
            /// Builds the sample instance, requests a prediction and writes the result
            /// (or the error message) to the console.
            /// </summary>
            static async Task RunAsync()
            {
                try
                {
                    // NOTE(review): categorical values are written with a leading space
                    // and unspaced hyphens (" Never-married"); the previous literals
                    // (" Never - married", " United - Stats") looked like formatter
                    // damage. Confirm against the vocabulary the model was trained on.
                    Person person = new Person
                    {
                        age = 25,
                        workclass = " Private",
                        education = " 11th",
                        education_num = 7,
                        marital_status = " Never-married",
                        occupation = " Machine-op-inspct",
                        relationship = " Own-child",
                        race = " Black",
                        gender = " Male",
                        capital_gain = 0,
                        capital_loss = 0,
                        hours_per_week = 40,
                        native_country = " United-States"
                    };
                    var instances = new List<Person> { person };

                    List<Prediction> predictions =
                        await client.Predict<Person, Prediction>(project, model, instances);
                    Console.WriteLine(String.Join("\n", predictions));
                }
                catch (Exception e)
                {
                    Console.WriteLine(e.Message);
                }
            }
        }

        /// <summary>
        /// Minimal client for the online prediction REST endpoint rooted at
        /// https://ml.googleapis.com/v1/.
        /// </summary>
        class PredictClient
        {
            // OAuth scope requested when the default credential has to be scoped.
            private const string CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform";

            // A single HttpClient is reused for every call made through this instance.
            private readonly HttpClient client;

            public PredictClient()
            {
                this.client = new HttpClient();
                client.BaseAddress = new Uri("https://ml.googleapis.com/v1/");
                client.DefaultRequestHeaders.Accept.Clear();
                client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
            }

            /// <summary>
            /// Posts <paramref name="instances"/> to the model's <c>:predict</c> method
            /// and returns the deserialized contents of the response's "predictions" array.
            /// </summary>
            /// <typeparam name="I">Type of one input instance; serialized to JSON.</typeparam>
            /// <typeparam name="O">Type one prediction is deserialized into.</typeparam>
            /// <param name="project">Project identifier used in the request path.</param>
            /// <param name="model">Model name used in the request path.</param>
            /// <param name="instances">Instances sent under the "instances" key of the request body.</param>
            /// <param name="version">
            /// Optional model version; when <c>null</c> no version segment is added to the path.
            /// </param>
            /// <returns>The predictions, one per element of the response array.</returns>
            /// <exception cref="ArgumentNullException">
            /// <paramref name="project"/>, <paramref name="model"/> or <paramref name="instances"/> is null.
            /// </exception>
            /// <exception cref="HttpRequestException">The service answered with a non-success status code.</exception>
            /// <exception cref="InvalidOperationException">The response body has no "predictions" field.</exception>
            public async Task<List<O>> Predict<I, O>(String project, String model, List<I> instances, String version = null)
            {
                if (project == null) throw new ArgumentNullException(nameof(project));
                if (model == null) throw new ArgumentNullException(nameof(model));
                if (instances == null) throw new ArgumentNullException(nameof(instances));

                // The version resource segment is plural: .../models/{model}/versions/{version}:predict
                var versionSuffix = version == null ? "" : $"/versions/{version}";
                var predictUri = $"projects/{project}/models/{model}{versionSuffix}:predict";

                GoogleCredential credential = await GoogleCredential.GetApplicationDefaultAsync();
                if (credential.IsCreateScopedRequired)
                {
                    // Some credential kinds cannot mint a token until scopes are attached.
                    credential = credential.CreateScoped(CloudPlatformScope);
                }
                var bearerToken = await credential.UnderlyingCredential.GetAccessTokenForRequestAsync();

                var payload = JsonConvert.SerializeObject(new { instances = instances });

                // The bearer token is attached to this request only, rather than to the
                // shared client's default headers, so overlapping calls cannot overwrite
                // each other's Authorization header.
                using (var request = new HttpRequestMessage(HttpMethod.Post, predictUri))
                {
                    request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken);
                    request.Content = new StringContent(payload, Encoding.UTF8, "application/json");

                    using (var responseMessage = await client.SendAsync(request))
                    {
                        responseMessage.EnsureSuccessStatusCode();

                        var responseBody = await responseMessage.Content.ReadAsStringAsync();
                        var response = JsonConvert.DeserializeObject<PredictResponse<O>>(responseBody);
                        if (response == null || response.predictions == null)
                        {
                            throw new InvalidOperationException(
                                $"Prediction response did not contain a 'predictions' field: {responseBody}");
                        }

                        return response.predictions;
                    }
                }
            }

            /// <summary>Typed envelope for the response body: <c>{ "predictions": [...] }</c>.</summary>
            private class PredictResponse<T>
            {
                public List<T> predictions { get; set; }
            }
        }
    }
    

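    You may have to run gcloud auth application-default login (on older Cloud SDK releases, gcloud auth login) to initialize the Application Default Credentials that GoogleCredential.GetApplicationDefaultAsync() reads before running locally, if you haven't already.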

    0 讨论(0)
提交回复
热议问题