You've probably seen 8x7b everywhere, with news of Mistral's new model beating ChatGPT on several benchmarks. Besides showing that you shouldn't underestimate a small group of determined Frenchies (having former DeepMind and Meta folks on board certainly helps), it is further evidence that well-crafted smaller models can provide tremendous value at a fraction of the cost. In this edition, instead of our usual library feature, we're diving into a case study of another such well-constructed smaller model.
As usual, feel free to check out our latest blog post, this time diving into the complex landscape of model serving solutions. See you there!
Segment faster with your thumb
With its ability to efficiently detect image components across a wide array of images, Segment Anything (SAM) quickly stood out from the rest (pun intended) when it was released. Yet its size still restricts its use on mobile devices, and more lightweight versions like MobileSAM improve speed only at the expense of accuracy. Enter EdgeSAM, which promises even faster inference with accuracy comparable to the original model!
Why would you care? EdgeSAM is 40x faster than the original model and even 1.6x faster than the edge-friendly MobileSAM. Its performance unlocks real-time edge applications on a large range of devices. The implementation also makes use of interesting techniques that can be re-used when deploying lightweight versions of similar models, which we wanted to share in this piece.
Encoder-only KD trains the feature encoder. Prompt-in-the-loop KD refines the mask decoder.
How does it work? EdgeSAM uses knowledge distillation to build a more compact and edge-efficient version of SAM's architecture. To understand the process, let's quickly recap the original model.
Knowledge Distillation? Knowledge Distillation (KD) is a model compression technique that transfers knowledge from a large, complex (teacher) network to a smaller, simpler (student) network. The goal is to reduce the size of the network by passing what the teacher has learned on to a smaller version.
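To make the teacher/student idea concrete, here is a minimal sketch of the classic logit-matching form of KD, in plain Python. It is not EdgeSAM's loss (its encoder stage matches feature maps pixel by pixel instead); the function names, the temperature value, and the toy logits are all assumptions chosen for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on softened distributions, scaled by T^2."""
    p = softmax(teacher_logits, temperature)  # soft targets from the teacher
    q = softmax(student_logits, temperature)  # student predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2


# Toy logits (hypothetical): one student roughly agrees with the teacher,
# the other ranks the classes in the opposite order.
teacher = [4.0, 1.0, 0.5]
good_student = [3.8, 1.1, 0.4]
bad_student = [0.5, 1.0, 4.0]

print(distillation_loss(good_student, teacher))
print(distillation_loss(bad_student, teacher))
```

The student that mimics the teacher gets the lower loss, which is the signal the student is trained on. A higher temperature softens both distributions so that the teacher's relative preferences among the less likely classes also carry weight.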
At a high level, SAM determines image parts through a mask decoder, which makes segmentation predictions by using self-attention and cross-attention to update (a) the image embeddings generated by a ViT-based image encoder, and (b) the prompt embeddings generated by a prompt encoder that takes points, boxes, masks, or text as input. Along with this interactive design, SAM is trained on a large dataset to improve its generality, making it both a powerful and versatile model.
EdgeSAM is built on top of the above architecture with:
Encoder-level Knowledge Distillation: A pixel-wise distillation loss trains the student model, adjusting for the mismatch between the teacher and student models in (a) downsampling, by introducing a feature upsampler to the student, and (b) feature channels, by using a projection layer to align them, as MobileSAM does. EdgeSAM also uses a CNN backbone instead of the original model's ViT layers, to improve its performance on current mobile accelerators, which tend to be better suited to convolutional structures.
Prompt-in-the-loop Knowledge Distillation: Because fine-tuning the original mask decoder compromises its generalisation capabilities (e.g. training only with points degrades box-prompt performance), EdgeSAM was trained using a dynamic prompt sampling strategy that actively aligns it with SAM's output masks and corrects the regions where it is inaccurate, by iteratively introducing a diverse set of new prompts to guide the decoder. The process starts with an initial prompt, either a box or a point with equal probability, which is fed into both the teacher's and the student's decoders to determine the areas where their outputs diverge. Inclusion prompts are then placed on the parts the student incorrectly removed, while exclusion prompts are placed on the parts it incorrectly included.
Granularity Priors: The original SAM identifies granular inputs better with box prompts, while point prompts tend to produce partial masks and require interactive refining as a result. Because mobile use has applications that favour one-click inputs (let's face it, we're too lazy to tap a hundred times to get the perfect contour), EdgeSAM comes equipped with a new module that embeds the granularity priors of specific datasets to improve the one-shot performance of point prompts.
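The prompt-in-the-loop step described above can be sketched in a few lines of plain Python. This is an illustrative toy rather than EdgeSAM's training code: masks are small nested lists, every function name is hypothetical, a single corrective point is drawn uniformly from all disagreeing pixels (an assumption), and the "student update" simply flips the prompted pixel where a real loop would re-run the student's decoder with the new prompt and take a distillation loss against the teacher's mask.

```python
import random


def disagreement_regions(teacher_mask, student_mask):
    """Return pixels the student missed and pixels it wrongly included."""
    false_neg, false_pos = [], []
    for y, (t_row, s_row) in enumerate(zip(teacher_mask, student_mask)):
        for x, (t, s) in enumerate(zip(t_row, s_row)):
            if t and not s:
                false_neg.append((x, y))  # student removed a teacher pixel
            elif s and not t:
                false_pos.append((x, y))  # student added a spurious pixel
    return false_neg, false_pos


def sample_correction_prompt(teacher_mask, student_mask, rng):
    """Pick one point prompt: label 1 means include, label 0 means exclude."""
    false_neg, false_pos = disagreement_regions(teacher_mask, student_mask)
    candidates = [(p, 1) for p in false_neg] + [(p, 0) for p in false_pos]
    if not candidates:
        return None  # the masks already agree
    return rng.choice(candidates)


# Toy 3x4 masks (hypothetical data): the student misses one pixel
# and hallucinates another.
teacher = [[0, 1, 1, 0],
           [0, 1, 1, 0],
           [0, 0, 0, 0]]
student = [[0, 1, 0, 0],
           [0, 1, 1, 1],
           [0, 0, 0, 0]]

rng = random.Random(0)
prompts = []
for _ in range(5):  # cap on refinement rounds
    prompt = sample_correction_prompt(teacher, student, rng)
    if prompt is None:
        break
    (x, y), label = prompt
    prompts.append(prompt)
    student[y][x] = label  # toy stand-in for re-running the student decoder

print(prompts)
```

Each round adds exactly one prompt where the two masks still disagree, so the prompts concentrate on the student's weak spots instead of being sampled blindly.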
Check out the model's repo to get started using it.
The Lab
Do you even multi-task? Because training large DNN models is resource-intensive, many organisations invest in DL clusters to parallelise jobs across GPUs with complex parallelisation strategies, and must manage those clusters efficiently to minimise the high cost of running accelerators. A scheduler typically assigns submitted jobs to GPUs and needs to change the GPU allocation dynamically at runtime to manage resources efficiently, but current DL training frameworks (PyTorch, TensorFlow, MindSpore) do not permit this, and existing solutions either support a limited range of applications or don’t support multi-dimensional parallelisation. Researchers from Imperial College London, the University of Edinburgh and Aalto University propose TENPLEX, a state management library for DL frameworks that lets jobs with multi-dimensional parallelism support dynamic changes to GPU resources during training, cutting training time and costs by making the most of GPU clusters! It does so with a tensor abstraction called a parallelizable tensor collection (PTC), which decouples the job state from the data parallelism pipeline; state transformations and optimisations are then applied to the PTC-contained job to ensure consistency and timely allocation across workers. TENPLEX integrates with existing DL frameworks like PyTorch and model libraries like Megatron-LM and DeepSpeed, and achieves a 24% reduction in training time while reducing resource reconfiguration time by 43% compared to naive state migration, and by 75% compared to central state maintenance.
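To give a feel for what changing the GPU allocation means for job state, here is a hedged sketch that re-partitions a one-dimensional "tensor" sliced across four workers onto two. It does not use TENPLEX's API or its PTC abstraction; the function names and the even contiguous slicing scheme are assumptions made purely for illustration.

```python
def shard_bounds(length, world_size):
    """Split [0, length) into contiguous ranges, as evenly as possible."""
    base, extra = divmod(length, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def reshard_plan(length, old_world, new_world):
    """For each new rank, list the (old_rank, start, end) ranges to fetch."""
    old = shard_bounds(length, old_world)
    plan = []
    for new_start, new_end in shard_bounds(length, new_world):
        pieces = []
        for old_rank, (old_start, old_end) in enumerate(old):
            lo, hi = max(new_start, old_start), min(new_end, old_end)
            if lo < hi:  # the old shard overlaps the new one
                pieces.append((old_rank, lo, hi))
        plan.append(pieces)
    return plan


def apply_plan(old_shards, plan, old_bounds):
    """Assemble the new shards by copying slices out of the old ones."""
    new_shards = []
    for pieces in plan:
        shard = []
        for old_rank, lo, hi in pieces:
            offset = old_bounds[old_rank][0]  # global to local index
            shard.extend(old_shards[old_rank][lo - offset:hi - offset])
        new_shards.append(shard)
    return new_shards


# Toy state: 10 parameters spread over 4 workers, then shrunk to 2 workers.
tensor = list(range(10))
old_bounds = shard_bounds(len(tensor), 4)
old_shards = [tensor[a:b] for a, b in old_bounds]
plan = reshard_plan(len(tensor), 4, 2)
new_shards = apply_plan(old_shards, plan, old_bounds)

print(plan)
print(new_shards)
```

The plan only names the slices that actually have to move, so each new worker fetches just the overlapping ranges instead of pulling the full state through a central copy.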
Two brains are better than one - Much like with supervised learning algorithms, growing reinforcement learning (RL) network structures lead to increasing computation time, which is even more problematic in RL because prolonged training times exacerbate the algorithms’ inherent instability. Full parallelism is also difficult to achieve, as RL needs to continuously interact with the environment to update. Despite the advent of parallelisation-enabling frameworks, parts of the pipeline remain sequential, with the main computing bottleneck lying in the data transmission between CPU and GPU. Spreeze addresses this issue and maximises the performance of the various hardware components involved by making information sharing more efficient. It introduces an asynchronous setup in which multiple sampling processes continuously run and interact with the environment using actions generated by the policy network, and the experience data they collect is transmitted to a single network update process, which the GPU parallelises using the batch size that achieves the best training efficiency given the available memory and compute. Actor-critic model parallelism further optimises the process by having one GPU update the policy network and another update the value network, reducing the time required for network updates. Shared memory is used throughout to minimise waiting time during the transmission of sampled experience data. Spreeze reduces training time by an average of 73% and can accommodate different types of hardware, making it a versatile RL optimisation framework.
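The sampler/learner split described above can be sketched with the Python standard library. This is a rough illustration under loud assumptions, not Spreeze's implementation: threads and a `queue.Queue` stand in for its separate sampling processes and shared memory, the transitions are fake, no network is actually updated, and the two-GPU actor-critic split is not shown. All names are hypothetical.

```python
import queue
import random
import threading


def sampler(worker_id, n_steps, buffer):
    """Stand-in for one environment-sampling process."""
    rng = random.Random(worker_id)
    for step in range(n_steps):
        # Fake transition: (who produced it, when, made-up reward).
        buffer.put((worker_id, step, rng.random()))


def learner(buffer, total, batch_size):
    """Single update process: drain the shared buffer in batches."""
    batches, seen = [], 0
    while seen < total:
        take = min(batch_size, total - seen)
        batch = [buffer.get() for _ in range(take)]  # blocks until data arrives
        batches.append(batch)  # a real learner would run a GPU update here
        seen += take
    return batches


def run(n_workers=4, n_steps=25, batch_size=16):
    """Start the samplers, then consume everything they produce."""
    buffer = queue.Queue()
    threads = [
        threading.Thread(target=sampler, args=(i, n_steps, buffer))
        for i in range(n_workers)
    ]
    for t in threads:
        t.start()
    batches = learner(buffer, n_workers * n_steps, batch_size)
    for t in threads:
        t.join()
    return batches


batches = run()
print(len(batches), sum(len(b) for b in batches))
```

The samplers never wait for an update to finish and the learner only blocks when the buffer is empty, which is the decoupling that keeps both sides of the pipeline busy.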
The Pulse
Big Brother is (better be ready to pay for) watching you: The AI Act dropped last week following a series of lengthy debates. The new European legislation bans AI applications that involve any form of manipulative use of AI technology or segregation-based data collection. Additionally, a set of transparency requirements is imposed on general-purpose AI systems, such as disseminating information on the data used for training and providing technical documentation, with more stringent measures for models deemed high-impact, e.g. conducting pre-release adversarial tests or monitoring systemic risks. AI applications intended for use within the EU will need to comply with this yet-to-be-formally-adopted law in the near future.
Transformers have become a ubiquitous archite-: This probably won’t age very well, as we’re starting to see more models with alternative architectures that outperform, or at least perform on par with, current SOTA. In the previous edition, we discussed Mamba, which promises manifold performance improvements over typical transformers. It’s now Together AI’s turn to release a series of StripedHyena models, which build on similar SSM building blocks and achieve performance comparable to Llama-2, Yi and Mistral 7B on OpenLLM leaderboard tasks while outperforming them on long-context summarization! We generally avoid discussing new model releases in our newsletter, but this could mark a turning point, paving the way for architecturally memory-efficient model releases!
Apples can do AI too: Apple recently released MLX, their custom framework dedicated to building and deploying machine learning models on Apple silicon devices. MLX offers a low-level Python API as well as a fully featured C++ API that closely mirrors it, along with a higher-level API, modelled on PyTorch’s, for creating more complex models. One of MLX’s unique perks is its efficient memory management, which uses Apple silicon’s unified memory to minimise data transfers between the CPU and GPU, providing better overall performance. Notably, this allows it to provide around 40% more throughput than PyTorch on medium batch sizes and ~15% more when comparing optimal batch sizes. By fully leveraging Apple’s bespoke hardware, MLX could greatly improve local inference on Apple devices, bringing local AI closer to peak performance!
And that’s all for this edition! We hope you enjoyed reading through, and wish you a wonderful holiday season!
The Unify Dev Team