Deploy the Hugging Face (PyAnnote) speaker diarization model to Amazon SageMaker as an asynchronous endpoint



Speaker diarization, an essential process in audio analysis, segments an audio file based on speaker identity. This post shows how to integrate Hugging Face's PyAnnote speaker diarization model with Amazon SageMaker asynchronous endpoints.

We provide a comprehensive guide on how to deploy speaker segmentation and clustering solutions using SageMaker on the AWS Cloud. This solution can be applied to applications that handle multi-speaker (more than 100 speakers) audio recordings.

Solution overview

Amazon Transcribe is the go-to service for speaker diarization on AWS. However, for unsupported languages, you can deploy another model (in our case, PyAnnote) to SageMaker for inference. For short audio files where inference takes up to 60 seconds, you can use real-time inference; for anything longer, you should use asynchronous inference. An added benefit of asynchronous inference is the cost savings from automatically scaling the instance count to zero when there are no requests to process.

Hugging Face is a popular open source hub for machine learning (ML) models. AWS and Hugging Face have a partnership that enables seamless integration through SageMaker, with a set of AWS Deep Learning Containers (DLCs) for training and inference in PyTorch or TensorFlow, and Hugging Face estimators and predictors in the SageMaker Python SDK. These capabilities help developers and data scientists get started with natural language processing (NLP) on AWS easily.

The solution uses Hugging Face's pre-trained speaker diarization model based on the PyAnnote library. PyAnnote is an open source toolkit written in Python for speaker diarization. The model, trained on sample audio data, enables effective speaker segmentation within audio files. It is deployed to SageMaker as an asynchronous endpoint, providing efficient and scalable processing of diarization tasks.

The following diagram shows the solution architecture.

This post uses sample audio files for testing. Stereo or multi-channel audio files are automatically downmixed to mono by averaging the channels, and audio files with a different sampling rate are automatically resampled to 16 kHz when loaded.
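
As a reference, the following minimal sketch reproduces this normalization locally with librosa (the file name sample.wav is a placeholder):

import librosa

# librosa downmixes to mono by averaging channels and resamples on load
waveform, sample_rate = librosa.load("sample.wav", sr=16000, mono=True)
print(waveform.shape, sample_rate)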

Prerequisites

Before you begin, complete the following prerequisites:

  1. Create a SageMaker domain.
  2. Verify that your AWS Identity and Access Management (IAM) user has the necessary permissions to create the SageMaker role.
  3. Make sure your AWS account has a service quota for hosting SageMaker endpoints on ml.g5.2xlarge instances (you can check this with the AWS CLI, as shown below).
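
To confirm the quota in step 3, you can query Service Quotas from the AWS CLI. This is a minimal sketch; the JMESPath filter assumes the quota's display name contains the instance type:

!aws service-quotas list-service-quotas --service-code sagemaker --query "Quotas[?contains(QuotaName, 'ml.g5.2xlarge for endpoint usage')].[QuotaName,Value]"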

Create a model function to access PyAnnote speaker diarization from Hugging Face

You can use the Hugging Face Hub to access the pre-trained PyAnnote speaker diarization model of your choice. The same script downloads the model file when you create the SageMaker endpoint.


See the code below.

from pyannote.audio import Pipeline

def model_fn(model_dir):
    # Load the pre-trained pipeline from the Hugging Face Hub
    model = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token="Replace-with-the-Hugging-Face-auth-token")
    return model
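
Before packaging the model, you can sanity-check the pipeline locally. This is a minimal sketch; sample.wav is a placeholder for any local audio file:

# Run the pipeline on a local file and print each speaker turn
pipeline = model_fn(".")
for turn, _, speaker in pipeline("sample.wav").itertracks(yield_label=True):
    print(f"{turn.start:.1f}s - {turn.end:.1f}s: {speaker}")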

Package the model code

Prepare an inference.py file that contains your inference code:

%%writefile model/code/inference.py
from pyannote.audio import Pipeline
import boto3
from urllib.parse import urlparse
import pandas as pd

def model_fn(model_dir):
    # Load the model from the specified model directory
    model = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token="hf_oBxxxxxxxxxxxx")
    return model 


def diarization_from_s3(model, s3_file, language=None):
    s3 = boto3.client("s3")
    o = urlparse(s3_file, allow_fragments=False)
    bucket = o.netloc
    key = o.path.lstrip("/")
    s3.download_file(bucket, key, "tmp.wav")
    result = model("tmp.wav")
    data = {} 
    for turn, _, speaker in result.itertracks(yield_label=True):
        data[turn] = (turn.start, turn.end, speaker)
    data_df = pd.DataFrame(data.values(), columns=["start", "end", "speaker"])
    print(data_df.shape)
    result = data_df.to_json(orient="split")
    return result


def predict_fn(data, model):
    s3_file = data.pop("s3_file")
    language = data.pop("language", None)
    result = diarization_from_s3(model, s3_file, language)
    return {
        "diarization_from_s3": result
    }

Prepare a requirements.txt file that contains the Python libraries needed to run inference:

with open("model/code/requirements.txt", "w") as f:
    f.write("transformers==4.25.1\n")
    f.write("boto3\n")
    f.write("PyAnnote.audio\n")
    f.write("soundfile\n")
    f.write("librosa\n")
    f.write("onnxruntime\n")
    f.write("wget\n")
    f.write("pandas")

Finally, compress the inference.py and requirements.txt files and save them as model.tar.gz:
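
A minimal sketch of the packaging step, assuming the model/code/ layout created above (the SageMaker Hugging Face container expects the code/ directory at the archive root):

!cd model && tar zcvf ../model.tar.gz *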

Configure the SageMaker model

Define a SageMaker model resource by specifying the image URI, the location of the model data in Amazon Simple Storage Service (Amazon S3), and the SageMaker role:

import sagemaker
import boto3

sess = sagemaker.Session()

sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

Upload the model to Amazon S3

Upload the zipped PyAnnote Hugging Face model file to your S3 bucket.

s3_location = f"s3://{sagemaker_session_bucket}/whisper/model/model.tar.gz"
!aws s3 cp model.tar.gz $s3_location

Create a SageMaker asynchronous endpoint

Configure an asynchronous endpoint to deploy your model to SageMaker using the provided asynchronous inference configuration.

from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join
from sagemaker.utils import name_from_base

async_endpoint_name = name_from_base("custom-async")

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_location,  # path to your model and script
    role=role,  # iam role with permissions to create an Endpoint
    transformers_version="4.17",  # transformers version used
    pytorch_version="1.10",  # pytorch version used
    py_version="py38",  # python version used
)

# create async endpoint configuration
async_config = AsyncInferenceConfig(
    output_path=s3_path_join(
        "s3://", sagemaker_session_bucket, "async_inference/output"
    ),  # Where our results will be stored
    # Add notification SNS topics if needed
    notification_config={
        # "SuccessTopic": "PUT YOUR SUCCESS SNS TOPIC ARN",
        # "ErrorTopic": "PUT YOUR ERROR SNS TOPIC ARN",
    },  #  Notification configuration
)

env = {"MODEL_SERVER_WORKERS": "2"}

# deploy the endpoint
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",  # matches the service quota from the prerequisites
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
    env=env,
)

Test the endpoint

Evaluate the functionality of the endpoint by sending an audio file for diarization and retrieving the JSON output stored in the specified S3 output path.

from sagemaker.async_inference import WaiterConfig

# Replace with a path to an audio object in S3
data = {"s3_file": "s3://<your-bucket>/<path-to-audio>.wav"}

res = async_predictor.predict_async(data=data)
print(f"Response output path: {res.output_path}")
print("Start Polling to get response:")

config = WaiterConfig(
    max_attempts=10,  # number of attempts
    delay=10,  # time in seconds to wait between attempts
)
res.get_result(config)

To deploy this solution at scale, we recommend using AWS Lambda, Amazon Simple Notification Service (Amazon SNS), or Amazon Simple Queue Service (Amazon SQS). These services are designed for scalability, event-driven architectures, and efficient resource utilization. They decouple the asynchronous inference process from result processing, allowing each component to scale independently and handle bursts of inference requests more effectively.
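
For example, a Lambda function subscribed to the success SNS topic could pick up each result as it completes. The following is a minimal sketch; the handler is illustrative, and the field names follow the SageMaker asynchronous inference notification format:

import json
import boto3

s3 = boto3.client("s3")

def lambda_handler(event, context):
    # SNS delivers the SageMaker notification as a JSON string
    message = json.loads(event["Records"][0]["Sns"]["Message"])
    # The success notification carries the S3 location of the output
    output_location = message["responseParameters"]["outputLocation"]
    bucket, key = output_location.replace("s3://", "").split("/", 1)
    body = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
    result = json.loads(body)
    # Hand the diarization result to downstream processing
    print(result["diarization_from_s3"])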

Results

The model output is saved to s3://sagemaker-xxxx/async_inference/output/. The output splits the audio recording into three columns:

  • Start (start time in seconds)
  • End (end time in seconds)
  • Speaker (speaker label)

The following code shows an example result.

[0.9762308998, 8.9049235993, "SPEAKER_01"]
[9.533106961, 12.1646859083, "SPEAKER_01"]
[13.1324278438, 13.9303904924, "SPEAKER_00"]
[14.3548387097, 26.1884550085, "SPEAKER_00"]
[27.2410865874, 28.2258064516, "SPEAKER_01"]
[28.3446519525, 31.298811545, "SPEAKER_01"]
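
Because predict_fn serializes the DataFrame with to_json(orient="split"), you can load the result back into a pandas DataFrame for downstream analysis. A minimal sketch, assuming the default JSON deserializer returns the response as a dictionary:

from io import StringIO
import pandas as pd

output = res.get_result(config)  # deserialized response dictionary
df = pd.read_json(StringIO(output["diarization_from_s3"]), orient="split")
print(df.head())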

Clean up

You can configure the endpoint to scale in to zero instances by setting MinCapacity to 0; asynchronous inference lets you automatically scale to zero when there are no requests. You don't need to delete the endpoint; it scales out again when needed, reducing costs when it isn't in use. See the code below.

# Common class representing application autoscaling for SageMaker 
client = boto3.client('application-autoscaling') 

# This is the format in which application autoscaling references the endpoint
resource_id = 'endpoint/' + async_endpoint_name + '/variant/' + 'variant1'  # replace 'variant1' with your endpoint's production variant name

# Define and register your endpoint variant
response = client.register_scalable_target(
    ServiceNamespace="sagemaker", 
    ResourceId=resource_id,
    ScalableDimension='sagemaker:variant:DesiredInstanceCount', # The number of EC2 instances for your Amazon SageMaker model endpoint variant.
    MinCapacity=0,
    MaxCapacity=5
)
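
Registering the scalable target alone doesn't scale the endpoint; you also attach a scaling policy. The following is a minimal sketch of a target tracking policy on the asynchronous inference backlog metric; the policy name, target value, and cooldowns are illustrative:

# Scale on the queue backlog per instance; with MinCapacity=0 the
# endpoint can scale in to zero instances when the backlog is empty
response = client.put_scaling_policy(
    PolicyName="async-backlog-scaling",
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,
        "CustomizedMetricSpecification": {
            "MetricName": "ApproximateBacklogSizePerInstance",
            "Namespace": "AWS/SageMaker",
            "Dimensions": [
                {"Name": "EndpointName", "Value": async_endpoint_name}
            ],
            "Statistic": "Average",
        },
        "ScaleInCooldown": 300,
        "ScaleOutCooldown": 300,
    },
)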

If you want to delete an endpoint, use the following code:

async_predictor.delete_endpoint()

Advantages of implementing asynchronous endpoints

This solution has the following advantages:

  • This solution can efficiently process multiple or large audio files.
  • This example uses a single instance for demonstration purposes. If you use this solution for hundreds or thousands of audio files, processed across multiple instances behind asynchronous endpoints, you can use autoscaling policies designed for large numbers of source documents. Autoscaling dynamically adjusts the number of instances provisioned for your model in response to changes in workload.
  • This solution optimizes resources and reduces system load by separating long-running tasks from real-time inference.

Conclusion

In this post, we provided a simple approach to deploying a Hugging Face speaker diarization model to SageMaker using a Python script. Asynchronous endpoints provide an efficient and scalable means to offer diarization predictions as a service, seamlessly accommodating concurrent requests.

Get started with asynchronous speaker diarization for your audio projects today. If you have any questions about getting your own asynchronous diarization endpoint up and running, let us know in the comments.


About the author

Sanjay Tiwary is an AI/ML Specialist Solutions Architect who works with strategic customers to define business requirements, deliver L300 sessions around specific use cases, and design AI/ML applications and services that are scalable, reliable, and performant. He has helped launch and scale the AI/ML-powered Amazon SageMaker service and has implemented several proofs of concept using Amazon AI services. He has also developed an advanced analytics platform as part of digital transformation journeys.

Kiran Charapalli is a deep technology business developer in the AWS public sector. He has more than 8 years of experience in AI/ML and 23 years of overall software development and sales experience. Kiran helps public sector enterprises across India explore and co-create cloud-based solutions that use AI, ML, and generative AI technologies, including large language models.



