Image Classification in the Browser
How we deployed a 7MB quantized image classification model that runs entirely in the browser with TensorFlow.js, using the WebGL backend with WASM/CPU fallbacks.
Why Run ML in the Browser?
Training happens on powerful servers, but what about inference? Running models in the browser offers unique advantages:
- Privacy - User data never leaves their device
- Speed - No network round-trip to a server
- Cost - Zero inference costs on your infrastructure
- Offline - Works without internet connection
We built an image classification app that recognizes 1,000 ImageNet object categories in real time, entirely client-side.
Choosing the Right Model
Not all models work well in browsers. We needed:
- Small file size (< 10MB)
- Fast inference (< 100ms)
- Good accuracy (> 75% top-5)
Model Options
MobileNetV2
- Size: 14MB
- Speed: 45ms
- Accuracy: 88% top-5
- ✅ Our choice (quantization, covered below, brings it to 7MB, within our 10MB budget)
ResNet50
- Size: 98MB
- Speed: 250ms
- Accuracy: 93% top-5
- ❌ Too large and slow
Tiny YOLO
- Size: 35MB
- Speed: 120ms
- Accuracy: Good for object detection
- ❌ Overkill for simple classification
Converting the Model
TensorFlow.js requires models in a specific format. Here’s how to convert:
1. Export from Python/TensorFlow
import tensorflow as tf
import tensorflowjs as tfjs
# Load your trained model
model = tf.keras.models.load_model('model.h5')
# Convert to TensorFlow.js format
tfjs.converters.save_keras_model(model, 'tfjs_model')
This creates:
- model.json - Model architecture and weights metadata
- group1-shard1of1.bin - Binary weight data
2. Optimize for Size
# Quantize weights from float32 to uint16 (50% size reduction)
tensorflowjs_converter \
--input_format=keras \
--output_format=tfjs_graph_model \
--quantization_bytes=2 \
model.h5 \
tfjs_model_quantized
Result: 14MB → 7MB with minimal accuracy loss
Loading the Model in JavaScript
import * as tf from '@tensorflow/tfjs'
// Enable WebGL backend for GPU acceleration
await tf.setBackend('webgl')
// Load model
const model = await tf.loadGraphModel('tfjs_model/model.json')
console.log('Model loaded successfully!')
Preloading for Better UX
Don’t make users wait. Load the model on page load:
let model = null
// Start loading immediately
const modelPromise = tf.loadGraphModel('tfjs_model/model.json')
// Use when needed
async function classify(image) {
  if (!model) {
    model = await modelPromise
  }
  return predict(model, image) // preprocess + run inference (see the next sections)
}
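If you also want a loading indicator (it appears in the production checklist later), tf.loadGraphModel accepts an onProgress callback in its load options. A minimal variant of the snippet above; updateProgressBar is a hypothetical UI helper:

// Same idea, but report download progress to the UI
const modelPromise = tf.loadGraphModel('tfjs_model/model.json', {
  // fraction goes from 0 to 1 as the weight shards download
  onProgress: (fraction) => updateProgressBar(Math.round(fraction * 100))
})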
Preprocessing Images
Neural networks expect specific input formats:
function preprocessImage(imgElement) {
  return tf.tidy(() => {
    // Convert HTML image to tensor
    let tensor = tf.browser.fromPixels(imgElement)
    // Resize to 224x224 (MobileNet input size)
    tensor = tf.image.resizeBilinear(tensor, [224, 224])
    // Normalize pixel values to [-1, 1]
    tensor = tensor.div(127.5).sub(1)
    // Add batch dimension
    tensor = tensor.expandDims(0)
    return tensor
  })
}
Important: Use tf.tidy() to prevent memory leaks!
Running Inference
async function classifyImage(imgElement) {
  // Preprocess
  const input = preprocessImage(imgElement)
  // Run inference
  const predictions = await model.predict(input)
  // Get top 5 predictions
  const top5 = await getTop5(predictions)
  // Clean up tensors
  input.dispose()
  predictions.dispose()
  return top5
}
async function getTop5(predictions) {
  const values = await predictions.data()
  const top5 = Array.from(values)
    .map((prob, index) => ({ prob, index }))
    .sort((a, b) => b.prob - a.prob)
    .slice(0, 5)
    .map(x => ({
      className: IMAGENET_CLASSES[x.index],
      probability: x.prob
    }))
  return top5
}
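getTop5 relies on an IMAGENET_CLASSES lookup from class index to label, which isn't shown above. A minimal sketch of loading it from a JSON file shipped with the app; the file name imagenet_classes.json is a placeholder:

let IMAGENET_CLASSES = {}

async function loadLabels() {
  // Hypothetical labels file mapping class index -> human-readable name
  const response = await fetch('imagenet_classes.json')
  IMAGENET_CLASSES = await response.json()
}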
Performance Optimization
1. Use WebGL Backend
WebGL uses the GPU for calculations:
// Check which backend is being used
console.log(tf.getBackend()) // Should be 'webgl'
// Manually set if needed
await tf.setBackend('webgl')
WebGL is 10-100x faster than CPU!
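The exact speedup depends on the device, so measure on your own hardware. A rough sketch that times one prediction per backend; timeBackend is our helper and assumes the single-output graph model used earlier:

async function timeBackend(backendName) {
  await tf.setBackend(backendName)
  await tf.ready()
  const model = await tf.loadGraphModel('tfjs_model/model.json')
  const input = tf.zeros([1, 224, 224, 3])
  model.predict(input).dispose() // warm-up run
  const start = performance.now()
  const output = model.predict(input)
  await output.data() // wait for the backend to finish the computation
  const elapsed = performance.now() - start
  output.dispose()
  input.dispose()
  return elapsed
}
console.log('webgl:', await timeBackend('webgl'), 'ms')
console.log('cpu:', await timeBackend('cpu'), 'ms')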
2. Warm Up the Model
The first inference is always slower. Warm it up:
async function warmUpModel(model) {
  const warmupTensor = tf.zeros([1, 224, 224, 3])
  const result = await model.predict(warmupTensor)
  // Dispose both the dummy input and the warm-up output (assumes a single output tensor)
  warmupTensor.dispose()
  result.dispose()
}
// After loading model
await warmUpModel(model)
3. Batch Predictions
Processing multiple images? Use batching:
async function batchPredict(images) {
  // preprocessImage already adds a batch dimension, so concatenate along axis 0
  // to get a [numImages, 224, 224, 3] batch (tf.stack would add an extra dimension)
  const batch = tf.tidy(() => {
    const tensors = images.map(preprocessImage)
    return tf.concat(tensors, 0)
  })
  const predictions = await model.predict(batch)
  batch.dispose()
  return predictions
}
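predict on a batch returns one row of class probabilities per image. A sketch of unpacking that into per-image top-5 lists; batchTop5 is our helper name and it reuses the IMAGENET_CLASSES lookup from earlier:

async function batchTop5(images) {
  const predictions = await batchPredict(images)
  const rows = await predictions.array() // shape [numImages][numClasses]
  predictions.dispose()
  return rows.map(row =>
    row
      .map((prob, index) => ({ prob, index }))
      .sort((a, b) => b.prob - a.prob)
      .slice(0, 5)
      .map(x => ({ className: IMAGENET_CLASSES[x.index], probability: x.prob }))
  )
}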
4. Use Web Workers
Don’t block the main thread:
// worker.js
import * as tf from '@tensorflow/tfjs'
let model
self.addEventListener('message', async (e) => {
  const { type, data } = e.data
  if (type === 'load') {
    model = await tf.loadGraphModel(data.modelUrl)
    self.postMessage({ type: 'loaded' })
  }
  if (type === 'predict') {
    // data.image must be transferable (e.g. ImageData or ImageBitmap);
    // DOM elements cannot be posted to a worker
    const result = await classify(data.image)
    self.postMessage({ type: 'result', data: result })
  }
})
// main.js
// Module worker so the ES import in worker.js works
const worker = new Worker('worker.js', { type: 'module' })
worker.postMessage({
  type: 'load',
  data: { modelUrl: 'model.json' }
})
worker.addEventListener('message', (e) => {
  if (e.data.type === 'result') {
    console.log('Predictions:', e.data.data)
  }
})
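One detail the worker example glosses over: DOM elements such as <img> or <video> cannot be posted to a worker. A minimal sketch of sending a transferable ImageBitmap instead (tf.browser.fromPixels accepts ImageBitmap on the worker side; classifyInWorker is our helper name):

// main.js: convert the element to an ImageBitmap and transfer it
async function classifyInWorker(imgElement) {
  const bitmap = await createImageBitmap(imgElement)
  // Listing the bitmap as a transferable moves it instead of copying it
  worker.postMessage({ type: 'predict', data: { image: bitmap } }, [bitmap])
}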
Real-Time Video Classification
Classify webcam frames in real-time:
async function setupCamera() {
  const video = document.getElementById('webcam')
  const stream = await navigator.mediaDevices.getUserMedia({
    video: { width: 640, height: 480 }
  })
  video.srcObject = stream
  return new Promise(resolve => {
    video.onloadedmetadata = () => resolve(video)
  })
}
async function detectFrame(video) {
  const predictions = await classifyImage(video)
  // Display results
  displayPredictions(predictions)
  // Request next frame
  requestAnimationFrame(() => detectFrame(video))
}
// Start
const video = await setupCamera()
await video.play()
detectFrame(video)
Throttling for Performance
Don’t classify every frame:
let lastPredictionTime = 0
const PREDICTION_INTERVAL = 200 // ms
async function detectFrame(video) {
  const now = Date.now()
  if (now - lastPredictionTime > PREDICTION_INTERVAL) {
    const predictions = await classifyImage(video)
    displayPredictions(predictions)
    lastPredictionTime = now
  }
  requestAnimationFrame(() => detectFrame(video))
}
Memory Management
TensorFlow.js stores tensors in GPU memory that the JavaScript garbage collector cannot reclaim. Leaked tensors accumulate until the tab crashes.
Manual Disposal
const tensor = tf.tensor([1, 2, 3])
// Use tensor...
tensor.dispose() // Free memory
Use tf.tidy()
Automatically cleans up intermediate tensors:
const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3])
  const b = tf.tensor([4, 5, 6])
  const c = a.add(b) // Intermediate tensor
  return c.mean() // Return value is NOT disposed
})
// Only 'result' survives; a, b, and c were cleaned up automatically
Monitor Memory
console.log('Num tensors:', tf.memory().numTensors)
console.log('Num bytes:', tf.memory().numBytes)
// Watch for leaks
setInterval(() => {
  const mem = tf.memory()
  console.log(`Tensors: ${mem.numTensors}, Bytes: ${mem.numBytes}`)
}, 5000)
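A cheap way to catch leaks during development is to compare tf.memory().numTensors before and after a code path that should be tensor-neutral. A small sketch; checkForLeaks is our own helper:

async function checkForLeaks(label, fn) {
  const before = tf.memory().numTensors
  await fn()
  const after = tf.memory().numTensors
  if (after > before) {
    console.warn(`${label} leaked ${after - before} tensor(s)`)
  }
}
// Usage
await checkForLeaks('classifyImage', () => classifyImage(imgElement))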
Handling Edge Cases
Mobile Performance
Mobile GPUs are slower:
function isMobile() {
  return /Android|iPhone|iPad/i.test(navigator.userAgent)
}
const config = isMobile()
  ? { inputSize: 192, maxFPS: 15 } // Smaller input, lower frame rate
  : { inputSize: 224, maxFPS: 30 } // Full quality
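A sketch of wiring that config into the earlier helpers. Both preprocessImageAt and the interval calculation are our additions, and the resizing assumes the exported model accepts variable input sizes:

// Resize to the device-appropriate input size
function preprocessImageAt(imgElement, inputSize) {
  return tf.tidy(() =>
    tf.image.resizeBilinear(tf.browser.fromPixels(imgElement), [inputSize, inputSize])
      .div(127.5)
      .sub(1)
      .expandDims(0)
  )
}
// Derive the throttle interval from the target frame rate
const PREDICTION_INTERVAL = 1000 / config.maxFPS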
Fallback for Unsupported Browsers
async function setupTensorFlow() {
  // setBackend resolves to false (rather than throwing) if the backend
  // cannot be initialized, so check the return value as well
  const webglOk = await tf.setBackend('webgl').catch(() => false)
  if (!webglOk) {
    console.warn('WebGL not available, falling back to CPU')
    await tf.setBackend('cpu')
  }
  await tf.ready()
}
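The production checklist below also suggests the WASM backend as a middle step. A sketch of a webgl → wasm → cpu chain, assuming the @tensorflow/tfjs-backend-wasm package is installed and bundled (it may also need setWasmPaths pointed at its .wasm binaries):

import '@tensorflow/tfjs-backend-wasm' // registers the 'wasm' backend

async function setupBestBackend() {
  // Try backends from fastest to slowest; setBackend resolves to false on failure
  for (const backend of ['webgl', 'wasm', 'cpu']) {
    const ok = await tf.setBackend(backend).catch(() => false)
    if (ok) break
  }
  await tf.ready()
  console.log('Using backend:', tf.getBackend())
}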
Progressive Loading
Load a tiny model first, then upgrade:
// Load tiny model (1MB)
let model = await tf.loadGraphModel('tiny_model/model.json')
console.log('Tiny model ready!')
// Upgrade to full model in background
tf.loadGraphModel('full_model/model.json').then(fullModel => {
  model = fullModel
  console.log('Full model ready!')
})
Bundle Size Optimization
TensorFlow.js is large. Optimize your bundle:
Tree Shaking
Import only what you need:
// ❌ Imports everything (500KB)
import * as tf from '@tensorflow/tfjs'
// ✅ Import specific modules (150KB)
import '@tensorflow/tfjs-backend-webgl'
import { loadGraphModel } from '@tensorflow/tfjs-converter'
import { browser, tidy, image } from '@tensorflow/tfjs-core'
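With named imports, ops are called as functions instead of methods on tf. A brief sketch of the earlier preprocessing written that way; the extra names div, sub, expandDims, and setBackend would need to be added to the import list above:

async function loadAndClassify(imgElement) {
  await setBackend('webgl')
  const model = await loadGraphModel('tfjs_model/model.json')
  const input = tidy(() => {
    const resized = image.resizeBilinear(browser.fromPixels(imgElement), [224, 224])
    return expandDims(sub(div(resized, 127.5), 1), 0)
  })
  const predictions = model.predict(input)
  input.dispose()
  return predictions
}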
Code Splitting
Lazy load TensorFlow.js:
async function loadModel() {
const tf = await import('@tensorflow/tfjs')
return tf.loadGraphModel('model.json')
}
Use a CDN
Let browsers cache TensorFlow.js:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
Production Checklist
Before deploying:
- Quantize model weights (50% size reduction)
- Warm up model on load
- Use tf.tidy() everywhere
- Monitor memory usage
- Add loading indicators
- Test on mobile devices
- Add error handling
- Consider Web Worker for long predictions
- Implement progressive loading
- Use WASM backend as fallback
Results
Our production app:
| Metric | Value |
|---|---|
| Model size | 7MB (quantized) |
| Load time | 1.2s (3G network) |
| Inference time | 45ms (desktop GPU) |
| Inference time | 180ms (mobile GPU) |
| Memory usage | ~150MB |
| Accuracy | 87% top-5 |
Lessons Learned
- Quantization is essential - 50% size reduction with <1% accuracy loss
- Memory leaks are easy to introduce - Always use tf.tidy() or manual disposal
- Mobile is much slower - Plan for 4-8x slower inference
- WebGL isn’t always available - Have a CPU fallback
- Bundle size matters - Tree-shake and code-split TensorFlow.js