Image Classification in the Browser

How we deployed a sub-10MB image classification model that runs entirely in the browser using TensorFlow.js, with WebGL acceleration and a WebAssembly fallback.

Why Run ML in the Browser?

Training happens on powerful servers, but what about inference? Running models in the browser offers unique advantages:

  • Privacy - User data never leaves their device
  • Speed - No network round-trip to a server
  • Cost - Zero inference costs on your infrastructure
  • Offline - Works without internet connection

We built an image classification app that identifies 1000+ objects in real-time, entirely client-side.

Choosing the Right Model

Not all models work well in browsers. We needed:

  • Small file size (< 10MB)
  • Fast inference (< 100ms)
  • Good accuracy (> 75% top-5)

Model Options

MobileNetV2

  • Size: 14MB
  • Speed: 45ms
  • Accuracy: 88% top-5
  • ✅ Our choice (14MB raw, but quantization below brings it to 7MB, under our size budget)

ResNet50

  • Size: 98MB
  • Speed: 250ms
  • Accuracy: 93% top-5
  • ❌ Too large and slow

Tiny YOLO

  • Size: 35MB
  • Speed: 120ms
  • Accuracy: Good for object detection
  • ❌ Overkill for simple classification

Converting the Model

TensorFlow.js requires models in a specific format. Here’s how to convert:

1. Export from Python/TensorFlow

import tensorflow as tf
import tensorflowjs as tfjs

# Load your trained model
model = tf.keras.models.load_model('model.h5')

# Convert to TensorFlow.js format
tfjs.converters.save_keras_model(model, 'tfjs_model')

This creates:

  • model.json - Model architecture and weights metadata
  • group1-shard1of1.bin - Binary weight data

2. Optimize for Size

# Quantize weights from float32 to uint16 (50% size reduction)
tensorflowjs_converter \
  --input_format=keras \
  --output_format=tfjs_graph_model \
  --quantization_bytes=2 \
  model.h5 \
  tfjs_model_quantized

Result: 14MB → 7MB with minimal accuracy loss

Loading the Model in JavaScript

import * as tf from '@tensorflow/tfjs'

// Enable WebGL backend for GPU acceleration
await tf.setBackend('webgl')

// Load model
const model = await tf.loadGraphModel('tfjs_model/model.json')

console.log('Model loaded successfully!')

Preloading for Better UX

Don’t make users wait. Load the model on page load:

let model = null

// Start loading immediately
const modelPromise = tf.loadGraphModel('tfjs_model/model.json')

// Use when needed
async function classify(image) {
  if (!model) {
    model = await modelPromise
  }

  return classifyImage(image) // defined below under "Running Inference"
}

Preprocessing Images

Neural networks expect specific input formats:

function preprocessImage(imgElement) {
  return tf.tidy(() => {
    // Convert HTML image to tensor
    let tensor = tf.browser.fromPixels(imgElement)

    // Resize to 224x224 (MobileNet input size)
    tensor = tf.image.resizeBilinear(tensor, [224, 224])

    // Normalize pixel values to [-1, 1]
    tensor = tensor.div(127.5).sub(1)

    // Add batch dimension
    tensor = tensor.expandDims(0)

    return tensor
  })
}

Important: Use tf.tidy() to prevent memory leaks!

Running Inference

async function classifyImage(imgElement) {
  // Preprocess
  const input = preprocessImage(imgElement)

  // Run inference
  const predictions = await model.predict(input)

  // Get top 5 predictions
  const top5 = await getTop5(predictions)

  // Clean up tensors
  input.dispose()
  predictions.dispose()

  return top5
}

async function getTop5(predictions) {
  const values = await predictions.data()
  const top5 = Array.from(values)
    .map((prob, index) => ({ prob, index }))
    .sort((a, b) => b.prob - a.prob)
    .slice(0, 5)
    .map(x => ({
      className: IMAGENET_CLASSES[x.index],
      probability: x.prob
    }))

  return top5
}
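
IMAGENET_CLASSES above is a plain index-to-label lookup that isn't defined in these snippets. One way to load it, assuming the labels ship as a JSON array alongside the model (the file name is an assumption):

// Hypothetical labels file: a JSON array of 1000 class names in ImageNet order
let IMAGENET_CLASSES = []

async function loadLabels() {
  const response = await fetch('tfjs_model/imagenet_classes.json')
  IMAGENET_CLASSES = await response.json()
}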

Performance Optimization

1. Use WebGL Backend

WebGL uses the GPU for calculations:

// Check which backend is being used
console.log(tf.getBackend()) // Should be 'webgl'

// Manually set if needed
await tf.setBackend('webgl')

WebGL is 10-100x faster than CPU!
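
The exact speedup depends heavily on the device and model. A rough way to measure it on your own hardware (an illustrative helper, not part of the app above):

async function benchmarkInference(runs = 20) {
  const input = tf.zeros([1, 224, 224, 3])

  // One warm-up run so shader compilation isn't counted
  const warmup = model.predict(input)
  await warmup.data()
  warmup.dispose()

  const start = performance.now()
  for (let i = 0; i < runs; i++) {
    const out = model.predict(input)
    await out.data() // wait for the GPU/CPU work to actually finish
    out.dispose()
  }
  input.dispose()

  console.log(`${tf.getBackend()}: ${((performance.now() - start) / runs).toFixed(1)}ms per inference`)
}

// Run once on WebGL, then call tf.setBackend('cpu') and run again to compare
await benchmarkInference()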

2. Warm Up the Model

The first inference is always slower. Warm it up:

async function warmUpModel(model) {
  const warmupTensor = tf.zeros([1, 224, 224, 3])
  const warmupResult = model.predict(warmupTensor)
  await warmupResult.data() // make sure the GPU work actually runs
  warmupTensor.dispose()
  warmupResult.dispose() // don't leak the warm-up prediction
}

// After loading model
await warmUpModel(model)

3. Batch Predictions

Processing multiple images? Use batching:

async function batchPredict(images) {
  // Combine preprocessed images into one [N, 224, 224, 3] tensor.
  // preprocessImage already adds a batch dimension, so concatenate along
  // axis 0 (stacking would give the wrong shape [N, 1, 224, 224, 3]).
  const batch = tf.tidy(() => {
    const tensors = images.map(preprocessImage)
    return tf.concat(tensors, 0)
  })

  const predictions = await model.predict(batch)

  batch.dispose()
  return predictions
}
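
The tensor returned above has shape [N, 1000], one probability vector per image. A sketch of turning that into per-image results, reusing getTop5 from earlier:

async function batchClassify(images) {
  const predictions = await batchPredict(images)

  // Split the [N, 1000] batch into one [1000] tensor per image
  const perImage = tf.unstack(predictions)
  const results = []
  for (const scores of perImage) {
    results.push(await getTop5(scores))
    scores.dispose()
  }

  predictions.dispose()
  return results
}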

4. Use Web Workers

Don’t block the main thread:

// worker.js
import * as tf from '@tensorflow/tfjs'

let model

self.addEventListener('message', async (e) => {
  const { type, data } = e.data

  if (type === 'load') {
    model = await tf.loadGraphModel(data.modelUrl)
    self.postMessage({ type: 'loaded' })
  }

  if (type === 'predict') {
    // data.image must be transferable (e.g. ImageData or ImageBitmap);
    // DOM elements such as <img> cannot be posted to a worker
    const result = await classify(data.image)
    self.postMessage({ type: 'result', data: result })
  }
})

// main.js
// The worker uses ES module imports, so create it as a module worker
const worker = new Worker('worker.js', { type: 'module' })

worker.postMessage({
  type: 'load',
  data: { modelUrl: 'model.json' }
})

worker.addEventListener('message', (e) => {
  if (e.data.type === 'result') {
    console.log('Predictions:', e.data.data)
  }
})
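
One detail the snippet glosses over: DOM elements cannot be posted to a worker, so convert the frame to a transferable ImageBitmap on the main thread first (tf.browser.fromPixels accepts an ImageBitmap inside the worker). An illustrative sketch:

// main.js: imgElement stands in for whatever image you want classified
const bitmap = await createImageBitmap(imgElement)

// Transfer the bitmap (zero-copy) instead of cloning it
worker.postMessage({ type: 'predict', data: { image: bitmap } }, [bitmap])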

Real-Time Video Classification

Classify webcam frames in real-time:

async function setupCamera() {
  const video = document.getElementById('webcam')

  const stream = await navigator.mediaDevices.getUserMedia({
    video: { width: 640, height: 480 }
  })

  video.srcObject = stream

  return new Promise(resolve => {
    video.onloadedmetadata = () => resolve(video)
  })
}

async function detectFrame(video) {
  const predictions = await classifyImage(video)

  // Display results
  displayPredictions(predictions)

  // Request next frame
  requestAnimationFrame(() => detectFrame(video))
}

// Start
const video = await setupCamera()
await video.play()
detectFrame(video)

Throttling for Performance

Don’t classify every frame:

let lastPredictionTime = 0
const PREDICTION_INTERVAL = 200 // ms

async function detectFrame(video) {
  const now = Date.now()

  if (now - lastPredictionTime > PREDICTION_INTERVAL) {
    const predictions = await classifyImage(video)
    displayPredictions(predictions)
    lastPredictionTime = now
  }

  requestAnimationFrame(() => detectFrame(video))
}

Memory Management

TensorFlow.js uses GPU memory. Leaks cause crashes!

Manual Disposal

const tensor = tf.tensor([1, 2, 3])
// Use tensor...
tensor.dispose() // Free memory

Use tf.tidy()

Automatically cleans up intermediate tensors:

const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3])
  const b = tf.tensor([4, 5, 6])
  const c = a.add(b) // Intermediate tensor
  return c.mean() // Return value is NOT disposed
})

// Only 'result' exists, a/b/c were cleaned up

Monitor Memory

console.log('Num tensors:', tf.memory().numTensors)
console.log('Num bytes:', tf.memory().numBytes)

// Watch for leaks
setInterval(() => {
  const mem = tf.memory()
  console.log(`Tensors: ${mem.numTensors}, Bytes: ${mem.numBytes}`)
}, 5000)
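
During development it also helps to assert that a call doesn't change the tensor count. An illustrative helper, not part of the app above:

async function assertNoLeak(label, fn) {
  const before = tf.memory().numTensors
  const result = await fn()
  const leaked = tf.memory().numTensors - before

  if (leaked > 0) {
    console.warn(`${label} leaked ${leaked} tensor(s)`)
  }
  return result
}

// classifyImage returns plain objects, so the tensor count should not grow
await assertNoLeak('classifyImage', () => classifyImage(imgElement))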

Handling Edge Cases

Mobile Performance

Mobile GPUs are slower:

function isMobile() {
  // Heuristic only: recent iPads report a desktop (Macintosh) user agent
  return /Android|iPhone|iPad/i.test(navigator.userAgent)
}

const config = isMobile()
  ? { inputSize: 192, maxFPS: 15 }  // Smaller, slower
  : { inputSize: 224, maxFPS: 30 }  // Full quality
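
The config object above isn't wired into the earlier snippets on its own. A sketch of how it could parameterize preprocessing and throttling, assuming the exported model accepts the smaller input size (many MobileNet exports are fixed at 224):

// Lower prediction rate on mobile
const PREDICTION_INTERVAL = 1000 / config.maxFPS

// Resize to the configured input size instead of a hard-coded 224
function preprocessImage(imgElement) {
  return tf.tidy(() => {
    const tensor = tf.browser.fromPixels(imgElement)
    const resized = tf.image.resizeBilinear(tensor, [config.inputSize, config.inputSize])
    return resized.div(127.5).sub(1).expandDims(0)
  })
}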

Fallback for Unsupported Browsers

async function setupTensorFlow() {
  try {
    await tf.setBackend('webgl')
    await tf.ready()
  } catch (e) {
    console.warn('WebGL not available, falling back to CPU')
    await tf.setBackend('cpu')
  }
}
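
The production checklist below also suggests the WASM backend, which typically sits between WebGL and plain CPU in speed. A sketch extending setupTensorFlow, assuming @tensorflow/tfjs-backend-wasm is installed (the CDN path for the .wasm binaries is an assumption):

import '@tensorflow/tfjs-backend-wasm'
import { setWasmPaths } from '@tensorflow/tfjs-backend-wasm'

// Tell the WASM backend where to fetch its .wasm binaries from
setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/dist/')

async function setupTensorFlow() {
  for (const backend of ['webgl', 'wasm', 'cpu']) {
    try {
      if (await tf.setBackend(backend)) {
        await tf.ready()
        console.log('Using backend:', backend)
        return
      }
    } catch (e) {
      // Backend failed to initialize in this browser, try the next one
    }
  }
}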

Progressive Loading

Load a tiny model first, then upgrade:

// Load tiny model (1MB)
let model = await tf.loadGraphModel('tiny_model/model.json')
console.log('Tiny model ready!')

// Upgrade to full model in background
tf.loadGraphModel('full_model/model.json').then(fullModel => {
  const tinyModel = model
  model = fullModel
  tinyModel.dispose() // free the tiny model's weights once the swap is done
  console.log('Full model ready!')
})

Bundle Size Optimization

TensorFlow.js is large. Optimize your bundle:

Tree Shaking

Import only what you need:

// ❌ Imports everything (500KB)
import * as tf from '@tensorflow/tfjs'

// ✅ Import specific modules (150KB)
import '@tensorflow/tfjs-backend-webgl'
import { loadGraphModel } from '@tensorflow/tfjs-converter'
import { browser, tidy, image } from '@tensorflow/tfjs-core'

Code Splitting

Lazy load TensorFlow.js:

async function loadModel() {
  const tf = await import('@tensorflow/tfjs')
  return tf.loadGraphModel('model.json')
}

Use a CDN

Let browsers cache TensorFlow.js:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
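
Loaded this way there is no bundler involved; the script exposes a global tf object:

<script>
  // The CDN bundle attaches the whole API to the global `tf`
  tf.ready().then(() => console.log('TF.js backend:', tf.getBackend()))
</script>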

Production Checklist

Before deploying:

  • Quantize model weights (50% size reduction)
  • Warm up model on load
  • Use tf.tidy() everywhere
  • Monitor memory usage
  • Add loading indicators (see the onProgress sketch after this list)
  • Test on mobile devices
  • Add error handling
  • Consider Web Worker for long predictions
  • Implement progressive loading
  • Use WASM backend as fallback
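
For the loading-indicator item, tf.loadGraphModel accepts an onProgress callback that reports the fraction of model weights fetched so far. A minimal sketch (the progress element is an assumption):

const progressBar = document.getElementById('model-progress') // hypothetical <progress> element

const model = await tf.loadGraphModel('tfjs_model/model.json', {
  onProgress: fraction => {
    // fraction goes from 0 to 1 as the weight shards are downloaded
    progressBar.value = fraction
  }
})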

Results

Our production app:

  • Model size: 7MB (quantized)
  • Load time: 1.2s (3G network)
  • Inference time: 45ms (desktop GPU)
  • Inference time: 180ms (mobile GPU)
  • Memory usage: ~150MB
  • Accuracy: 87% top-5

Lessons Learned

  1. Quantization is essential - 50% size reduction with <1% accuracy loss
  2. Memory leaks are easy - Always use tf.tidy() or manual disposal
  3. Mobile is much slower - Plan for 4-8x slower inference
  4. WebGL isn’t always available - Have a CPU fallback
  5. Bundle size matters - Tree-shake and code-split TensorFlow.js

Resources