How AWS Lambda SnapStart drastically reduces cold starts for Serverless Machine Learning Inference

Rustem Feyzkhanov · Published in AWS Tip · Nov 29, 2022


Challenges with cold start for ML inference

Cold starts have always been one of the main challenges of serverless machine learning inference. For ML inference, several things contribute to a cold start:

  • runtime initialization
  • loading libraries and dependencies
  • loading the model itself (from S3 or package)
  • initializing the model

Some of these costs can be reduced, either by using provisioned concurrency to keep execution environments initialized or by reducing the model size. Even so, for some ML applications the cold start made AWS Lambda prohibitive for inference.
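To see which of these steps dominates for a given model, one option is to time each stage inside the initialization code. The sketch below is a hypothetical helper, not part of the example repository. Runtime initialization happens before any user code runs, so it can only be seen in the Init Duration that Lambda reports, not with a helper like this.

```java
import java.util.function.Supplier;

// Hypothetical helper (not from the example repository): times one init stage.
final class InitTimer {
    private InitTimer() {}

    // Runs one stage, prints its wall-clock duration and returns the stage's result unchanged.
    static <T> T timed(String stage, Supplier<T> work) {
        long start = System.nanoTime();
        T result = work.get();
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;
        System.out.println(stage + " took " + elapsedMs + " ms");
        return result;
    }
}
```

In a static block, each step (loading dependencies, reading the model file, creating the session) could be wrapped in its own `InitTimer.timed(...)` call.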

SnapStart feature

With the newly announced SnapStart feature for AWS Lambda, the cold start is replaced by a restore from a snapshot. When a function version is published, AWS Lambda runs the initialization code, creates an immutable, encrypted snapshot of the memory and disk state, and caches it for reuse. New execution environments are resumed from this snapshot, which already has the ML model loaded in memory and ready to use.
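The code pattern that benefits from this is simple: do the expensive work once, during class initialization, and keep the per-invocation path light. Below is a minimal, dependency-free sketch of that pattern. `SnapshotFriendlyHandler`, `loadModel` and the two-element "model" are hypothetical stand-ins for a real model load, not Lambda or ONNX APIs.

```java
// Sketch: expensive state built once at class initialization and reused by every call.
final class SnapshotFriendlyHandler {
    // Counts how often the expensive load ran (for illustration only).
    static int loadCount;

    // Built when the class is initialized. On Lambda this happens in the init phase,
    // which is the state SnapStart captures in the snapshot.
    private static final float[] MODEL_WEIGHTS = loadModel();

    // Hypothetical stand-in for reading and parsing a real model file.
    private static float[] loadModel() {
        loadCount++;
        return new float[] {0.5f, 2.0f};
    }

    // Per-invocation work only reads the already loaded state.
    float handle(float input) {
        return MODEL_WEIGHTS[0] * input + MODEL_WEIGHTS[1];
    }
}
```

However many times `handle` is called, `loadModel` runs only once per class initialization, which is exactly the work SnapStart moves out of the request path.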

Things to keep in mind:

  • [Java runtime] SnapStart is currently supported only for the Java runtime. That is a limitation, but ONNX Runtime provides Java bindings, so it is possible to run ONNX models with SnapStart.
  • [Model load] The model has to be loaded during the initialization phase, not in the handler, and then reused between invocations. In Java, this means loading it in a static block. A nice side effect is that model loading is not limited by the function timeout: initialization can take up to 15 minutes.
  • [Snap-Resilient] SnapStart has specific limitations around uniqueness, because every execution environment is resumed from the same snapshot. For example, if a random seed is set during the init phase, all invocations served from that snapshot will start with a generator in the same state. Read more on how to make Lambda functions snap-resilient here.
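The uniqueness pitfall can be reproduced without Lambda at all: two `java.util.Random` instances with the same seed always produce the same sequence, which is what every environment restored from one snapshot would see. In the sketch below, each call to `firstIdFromSnapshot` simulates one restored environment. The safer variant creates the generator inside the handler, after restore; it assumes that a `SecureRandom` created at that point draws fresh entropy. The AWS uniqueness documentation linked above is the authority on which sources are snap-resilient.

```java
import java.security.SecureRandom;
import java.util.Random;

// Sketch of the SnapStart uniqueness pitfall; class and method names are hypothetical.
final class SnapResilienceDemo {
    // Anti-pattern: a seed fixed during init. Every environment restored from the
    // same snapshot starts from exactly this generator state.
    private static final long INIT_SEED = 42L;

    // Simulates one restored environment drawing its first "random" id.
    static long firstIdFromSnapshot() {
        Random restoredCopy = new Random(INIT_SEED);
        return restoredCopy.nextLong();
    }

    // Safer: create the generator per invocation, after the snapshot has been restored.
    static long idPerInvocation() {
        return new SecureRandom().nextLong();
    }

    public static void main(String[] args) {
        // Two "restored environments" produce the same id: prints true.
        System.out.println(firstIdFromSnapshot() == firstIdFromSnapshot());
        System.out.println(idPerInvocation());
    }
}
```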

ONNX Example

An example with ONNX and SnapStart is publicly available here. It can be deployed with AWS SAM to create an ONNX Inception V3 endpoint and test it.

The key parts of the SnapStart architecture for ONNX are:

  • onnxSession: holds the preloaded model and is reused between invocations.
  • getOnnxSession: loads the model on first use and returns the already loaded session on later calls.
  • static block: runs during snapshot creation. This is the important part, because code in the handler is not run when the snapshot is created.
```java
package onnxsnapstart;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Handler for ONNX predictions on a Lambda function.
 */
public class App implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {

    // ONNX session with the preloaded model. It is created during snapshot creation
    // and reused between invocations.
    private static OrtSession onnxSession;

    // Returns the ONNX session with the preloaded model, loading it only on first use.
    private static OrtSession getOnnxSession() {
        String modelPath = "inception_v3.onnx";
        if (onnxSession == null) {
            System.out.println("Start model load");
            // The session options are only needed while the session is being created.
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                // The environment is a shared singleton, so it is deliberately not closed here.
                OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
                OrtSession session = env.createSession(modelPath, options);
                Map<String, NodeInfo> inputInfoList = session.getInputInfo();
                Map<String, NodeInfo> outputInfoList = session.getOutputInfo();
                System.out.println(inputInfoList);
                System.out.println(outputInfoList);
                onnxSession = session;
            } catch (OrtException exc) {
                // Fail the init phase loudly instead of continuing with a null session.
                throw new IllegalStateException("Failed to load model " + modelPath, exc);
            }
        }
        return onnxSession;
    }

    // This code runs during the init phase. With SnapStart, that happens once when a version
    // is published, and the loaded model becomes part of the snapshot.
    static {
        System.out.println("Start model init");
        getOnnxSession();
        System.out.println("Finished model init");
    }

    // Main handler for the Lambda
    @Override
    public APIGatewayProxyResponseEvent handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
        Map<String, String> headers = new HashMap<>();
        headers.put("Content-Type", "application/json");
        headers.put("X-Custom-Header", "application/json");

        // Dummy all-zeros input with the Inception V3 shape [batch, channels, height, width].
        float[][][][] testData = new float[1][3][299][299];

        OrtSession session = getOnnxSession();
        String inputName = session.getInputNames().iterator().next();
        // The tensor and the result hold native memory, so both are closed after each run.
        try (OnnxTensor test = OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), testData);
             OrtSession.Result output = session.run(Collections.singletonMap(inputName, test))) {
            System.out.println(output);
        } catch (OrtException exc) {
            exc.printStackTrace();
        }

        APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent().withHeaders(headers);
        String body = "{ \"message\": \"made prediction\" }";

        return response
                .withStatusCode(200)
                .withBody(body);
    }
}
```
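The handler above only prints the `Result` object and returns a fixed message. To return an actual prediction, the `[1, 1000]` output could be read inside the try block with `(float[][]) output.get(0).getValue()` and reduced to a top-1 class. The helper below is a hypothetical sketch of that last step. It assumes the model outputs raw logits (not probabilities), so it applies a numerically stable softmax; mapping the index to an ImageNet label is left out.

```java
import java.util.Locale;

// Hypothetical helper: turns raw scores of shape [1, N] into a top-1 class and its probability.
final class TopPrediction {
    final int classIndex;
    final double probability;

    private TopPrediction(int classIndex, double probability) {
        this.classIndex = classIndex;
        this.probability = probability;
    }

    static TopPrediction fromScores(float[][] scores) {
        float[] logits = scores[0]; // batch size is 1
        int best = 0;
        float max = logits[0];
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > max) {
                max = logits[i];
                best = i;
            }
        }
        // Stable softmax: subtract the max logit before exponentiating.
        double sum = 0.0;
        for (float logit : logits) {
            sum += Math.exp(logit - max);
        }
        // The best class has exp(max - max) = 1 as its numerator.
        return new TopPrediction(best, 1.0 / sum);
    }

    // JSON body the handler could return instead of the fixed message.
    String toJson() {
        return String.format(Locale.ROOT, "{ \"class\": %d, \"probability\": %.4f }", classIndex, probability);
    }
}
```

Keeping this step as a pure function makes it easy to unit test without ONNX Runtime or a deployed Lambda.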

Here is what the logs look like (note how Init Duration is replaced with Restore Duration):

  • Traditional Lambda

```text
Picked up JAVA_TOOL_OPTIONS: -XX:+TieredCompilation -XX:TieredStopAtLevel=1
Start model init
Start model load
{x.1=NodeInfo(name=x.1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 299, 299]))}
{924=NodeInfo(name=924,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 1000]))}
Finished model init
START RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Version: $LATEST
ai.onnxruntime.OrtSession$Result@e580929
END RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a
REPORT RequestId: 27b8e0e6-2e26-4356-a015-540fc79c080a Duration: 244.99 ms Billed Duration: 245 ms Memory Size: 1769 MB Max Memory Used: 531 MB Init Duration: 8615.62 ms
```

  • SnapStart Lambda

```text
RESTORE_START Runtime Version: java:11.v15 Runtime Version ARN: arn:aws:lambda:us-east-1::runtime:0a25e3e7a1cc9ce404bc435eeb2ad358d8fa64338e618d0c224fe509403583ca
RESTORE_REPORT Restore Duration: 571.67 ms
START RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Version: 1
ai.onnxruntime.OrtSession$Result@47f6473
END RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029
REPORT RequestId: 9eafdbf2-37f0-430d-930e-de9ca14ad029 Duration: 496.51 ms Billed Duration: 645 ms Memory Size: 1769 MB Max Memory Used: 342 MB Restore Duration: 571.67 ms
```
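When comparing many invocations, it helps to pull the startup overhead out of the REPORT lines automatically. The sketch below is a hypothetical parser keyed on the two field names visible in the logs above, `Init Duration` and `Restore Duration`. It returns an empty result for lines without either field, such as warm invocations.

```java
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Sketch: extracts the startup overhead (Init or Restore Duration) from a Lambda REPORT line.
final class ReportLineParser {
    // Matches e.g. "Init Duration: 8615.62 ms" or "Restore Duration: 571.67 ms",
    // but not "Billed Duration" or the plain "Duration" field.
    private static final Pattern STARTUP =
            Pattern.compile("(Init|Restore) Duration: ([0-9.]+) ms");

    static Optional<Double> startupMillis(String reportLine) {
        Matcher m = STARTUP.matcher(reportLine);
        return m.find() ? Optional.of(Double.parseDouble(m.group(2))) : Optional.empty();
    }
}
```

For the two REPORT lines above, this yields 8615.62 ms for the traditional function and 571.67 ms for the SnapStart function.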

ONNX Benchmark

Below are the ApacheBench results for the ONNX Lambda running Inception V3 inference, with 1,000 requests at a concurrency of 50.

There is still some additional latency from restoring the snapshot, but the tail is now significantly shorter: no request takes longer than about 2.6 seconds, compared with up to 12.8 seconds for the traditional Lambda.

Percentage of the requests served within a certain time (ms):

| Percentile | Traditional Lambda | SnapStart Lambda |
|---|---|---|
| 50% | 352 | 365 |
| 66% | 377 | 445 |
| 75% | 467 | 477 |
| 80% | 473 | 487 |
| 90% | 488 | 556 |
| 95% | 9719 | 1392 |
| 98% | 10329 | 2233 |
| 99% | 10419 | 2319 |
| 100% (longest request) | 12825 | 2589 |
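To quantify the difference, the percentiles from the table can be compared pairwise. This small sketch only re-uses the numbers above. A ratio above 1 means SnapStart is faster at that percentile.

```java
import java.util.Locale;

// Sketch: per-percentile speedup of SnapStart over the traditional Lambda, from the table above.
final class TailLatencyComparison {
    static final String[] PERCENTILES = {"50%", "66%", "75%", "80%", "90%", "95%", "98%", "99%", "100%"};
    static final int[] TRADITIONAL_MS = {352, 377, 467, 473, 488, 9719, 10329, 10419, 12825};
    static final int[] SNAPSTART_MS = {365, 445, 477, 487, 556, 1392, 2233, 2319, 2589};

    // Traditional latency divided by SnapStart latency at percentile index i.
    static double speedup(int i) {
        return (double) TRADITIONAL_MS[i] / SNAPSTART_MS[i];
    }

    public static void main(String[] args) {
        for (int i = 0; i < PERCENTILES.length; i++) {
            System.out.printf(Locale.ROOT, "%-5s %6d ms vs %5d ms -> %.2fx%n",
                    PERCENTILES[i], TRADITIONAL_MS[i], SNAPSTART_MS[i], speedup(i));
        }
    }
}
```

Up to the 90th percentile the ratios are slightly below 1, which reflects the extra restore latency. From the 95th percentile up, SnapStart is roughly 4.5x to 7x faster.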

Summary

SnapStart changes how AWS Lambda can be used for many different tasks, and ML is definitely one of them. Cutting the cold start from seconds to a fraction of a second opens up a lot of possibilities for serverless ML, and I'm looking forward to SnapStart becoming available for more runtimes (specifically Python :-) ). Please share your use cases in the comments; I'd be happy to answer any questions.


I'm a staff machine learning engineer at Instrumental, where I work on analytical models for the manufacturing industry, and an AWS Machine Learning Hero.