Question
I spin up a SageMaker notebook using the conda_python3 kernel and follow the example notebook for Random Cut Forest.
As of this writing, the SageMaker SDK that comes with conda_python3 is version 1.72.0, but I want to use new features, so I update my notebook to use the latest version:
%%bash
pip install -U sagemaker
I can see that it updates:
import sagemaker
print(sagemaker.__version__)
# 2.4.1
One of the changes from version 1.x to 2.x is in the serializer/deserializer classes.
Previously (in version 1.72.0), I would update my predictor to use the proper serializer/deserializer and could run inference on my model:
from sagemaker.predictor import csv_serializer, json_deserializer
rcf_inference = rcf.deploy(
initial_instance_count=1,
instance_type='ml.m4.xlarge',
)
rcf_inference.content_type = 'text/csv'
rcf_inference.serializer = csv_serializer
rcf_inference.accept = 'application/json'
rcf_inference.deserializer = json_deserializer
results = rcf_inference.predict(some_numpy_array)
(Note: all of this comes from the example notebook.)
I try to replicate this using sagemaker 2.4.1 like so:
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import CSVSerializer
rcf_inference = rcf.deploy(
initial_instance_count=1,
instance_type='ml.m5.xlarge',
serializer=CSVSerializer,
deserializer=JSONDeserializer
)
results = rcf_inference.predict(some_numpy_array)
I receive the following error:
TypeError: serialize() missing 1 required positional argument: 'data'
I know I'm using the serializer/deserializer incorrectly, but I can't find good documentation on how they should be used.
Answer 1:
To use the new serializers/deserializers, you need to instantiate them and pass the instances rather than the classes (note the parentheses), for example:
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import CSVSerializer
rcf_inference = rcf.deploy(
initial_instance_count=1,
instance_type='ml.m5.xlarge',
serializer=CSVSerializer(),
deserializer=JSONDeserializer()
)
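Why the original code fails: in 2.x the names exported by `sagemaker.serializers` and `sagemaker.deserializers` are classes, whereas the lowercase 1.x names such as `csv_serializer` were ready-made objects. The predictor calls `serializer.serialize(data)`, and when `serializer` is a class rather than an instance, Python binds `data` to `self`, leaving the real `data` parameter unfilled. The sketch below reproduces this without SageMaker; `FakeCSVSerializer` and `call_predict` are hypothetical stand-ins, not SDK names, and `call_predict` only approximates the serialization step inside the predictor.

```python
class FakeCSVSerializer:
    """Hypothetical stand-in shaped like a SageMaker 2.x serializer (not the real class)."""

    CONTENT_TYPE = "text/csv"

    def serialize(self, data):
        # Encode each row as one comma-separated line.
        return "\n".join(",".join(str(value) for value in row) for row in data)


def call_predict(serializer, data):
    # Hypothetical approximation of the predictor's serialization step.
    return serializer.serialize(data)


rows = [[1, 2], [3, 4]]

# Passing the class: `rows` is bound to `self`, so `data` is missing -> TypeError.
error_message = None
try:
    call_predict(FakeCSVSerializer, rows)
except TypeError as err:
    error_message = str(err)
print(error_message)

# Passing an instance: `self` is the instance and `rows` fills `data`.
body = call_predict(FakeCSVSerializer(), rows)
print(body)
```

Passing `CSVSerializer()` and `JSONDeserializer()` corresponds to the second call.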
Answer 2:
For a custom serializer, you can subclass one of the built-in serializers and override serialize() in SageMaker 2.x:
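import json

from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer


class FMSerializer(JSONSerializer):
    # Wraps each row in the {"instances": [{"features": [...]}]} JSON request
    # format used by the built-in FM (Factorization Machines) algorithm.
    def serialize(self, data):
        js = {'instances': []}
        for row in data:
            # Each row is assumed to be a NumPy array.
            js['instances'].append({'features': row.tolist()})
        return json.dumps(js)

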
predictor = estimator.deploy(
initial_instance_count=1,
instance_type="ml.m4.xlarge",
serializer=FMSerializer(),
deserializer=JSONDeserializer()
)
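To see the request body `FMSerializer` produces without deploying an endpoint, here is the same payload-building logic pulled out into a plain function. This is a sketch assuming each row is a plain Python list of numbers (with NumPy rows you would call `row.tolist()`, as the answer does); `build_fm_payload` is a hypothetical name, not a SageMaker API.

```python
import json


def build_fm_payload(rows):
    """Build the {"instances": [{"features": [...]}, ...]} JSON body that FMSerializer emits.

    Hypothetical helper: assumes `rows` is an iterable of plain lists of numbers.
    """
    payload = {"instances": [{"features": list(row)} for row in rows]}
    return json.dumps(payload)


body = build_fm_payload([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]])
print(body)
```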
Source: https://stackoverflow.com/questions/63568274/how-to-use-serializer-and-deserializer-in-sagemaker-2