RL Training with GRPO

Train a GUI agent using TRL's GRPOTrainer with CUA-Bench environments on Modal

Train a multimodal GUI agent using GRPO (Group Relative Policy Optimization) with TRL and CUA-Bench on Modal.

How It Works

The training follows TRL's OpenEnv pattern:

┌─────────────────────────────────────────────────────────────┐
│                     GRPO Training Loop                      │
├─────────────────────────────────────────────────────────────┤
│  1. rollout_func() called with prompts                      │
│     ├─ Reset CUA-Bench environment                          │
│     ├─ Generate actions via vLLM                            │
│     ├─ Step environment, collect rewards                    │
│     └─ Return {prompt_ids, completion_ids, logprobs, ...}   │
│                                                             │
│  2. Reward function receives env_reward from rollout        │
│                                                             │
│  3. GRPO updates policy using rewards                       │
└─────────────────────────────────────────────────────────────┘
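Concretely, step 1 hinges on the dict that the rollout function hands back to the trainer. A minimal sketch of that contract (shapes only, with dummy token ids; the full implementation appears in the script below):

# Minimal sketch of the rollout contract (dummy values, no real environment).
# GRPOTrainer calls rollout_func(prompts, trainer) and expects one entry per
# prompt; extra keys such as env_reward are forwarded to the reward functions.
def rollout_func(prompts, trainer):
    return {
        "prompt_ids": [[101, 102] for _ in prompts],      # token ids of each prompt
        "completion_ids": [[201, 202] for _ in prompts],  # token ids of each completion
        "logprobs": [[-0.5, -0.7] for _ in prompts],      # per-token logprobs
        "env_reward": [0.0 for _ in prompts],             # consumed by the reward func
    }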

Step 1: Install Modal

pip install modal
modal setup

Optionally, create a wandb secret for logging:

modal secret create wandb-secret WANDB_API_KEY=your_key
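Inside the training container, an attached Modal secret surfaces as an environment variable, which is how the script below decides whether to initialize wandb:

import os

# Present only when modal.Secret.from_name("wandb-secret") is attached
if os.environ.get("WANDB_API_KEY"):
    print("wandb credentials available")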

Step 2: Create the Training Script

Create modal_grpo_training.py:

#!/usr/bin/env python3
"""GRPO Training with CUA-Bench on Modal.

This script runs GRPO (Group Relative Policy Optimization) training using
TRL's GRPOTrainer with CUA-Bench environments on Modal's cloud infrastructure.

Usage:
    # Setup Modal
    pip install modal
    modal setup

    # Run training
    modal run modal_grpo_training.py

    # With custom settings
    modal run modal_grpo_training.py --num-workers 4 --max-steps 20
"""

from __future__ import annotations

from dataclasses import dataclass

import modal


# =============================================================================
# Training Configuration
# =============================================================================

@dataclass
class TrainerConfig:
    """Configuration for GRPO training hyperparameters."""

    # Model
    model_id: str = "Qwen/Qwen3-VL-2B-Instruct"
    task_prompt: str = "Complete the computer task successfully."

    # Environment
    num_workers: int = 2
    max_steps: int = 10
    max_history: int = 2

    # Generation
    num_generations: int = 4
    max_completion_length: int = 256
    temperature: float = 0.7

    # Training
    learning_rate: float = 5e-6
    gradient_accumulation_steps: int = 4
    per_device_train_batch_size: int = 1
    num_train_epochs: int = 1
    warmup_steps: int = 10

    # vLLM
    use_vllm: bool = True
    vllm_mode: str = "colocate"
    vllm_gpu_memory_utilization: float = 0.4
    vllm_max_model_length: int = 32768

    # Checkpointing
    output_dir: str = "/checkpoints/grpo-output"
    save_strategy: str = "steps"
    save_steps: int = 100

    # Logging
    logging_steps: int = 1
    use_wandb: bool = True
    wandb_project: str = "cua-bench-grpo"

    # Dataset
    dataset_size: int = 1000

    # Misc
    bf16: bool = True
    debug: bool = False


DEFAULT_CONFIG = TrainerConfig()


# =============================================================================
# Modal App Configuration
# =============================================================================

app = modal.App("cua-bench-grpo-training")

image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git", "chromium", "chromium-driver")
    .uv_pip_install(
        "trl[vllm]==0.27.1",
        "datasets==3.5.1",
        "verl",
        "pillow",
        "requests",
        "fastapi",
        "uvicorn",
        "playwright",
        "cua-bench",
    )
    .run_commands("playwright install chromium")
    # .add_local_dir("cua_bench", remote_path="/root/cua_bench")
)

checkpoint_volume = modal.Volume.from_name("cua-bench-grpo-checkpoints", create_if_missing=True)

# Optional: wandb secret for logging
# Create with: modal secret create wandb-secret WANDB_API_KEY=your_key


# =============================================================================
# Constants
# =============================================================================

SYSTEM_PROMPT = """You are a GUI agent. Complete tasks by interacting with the screen.

Available actions:
- click(x, y): Click at coordinates (0-1000 range)
- type(text): Type text
- hotkey(key1, key2, ...): Press keyboard shortcut
- scroll(x, y, direction): Scroll up/down
- wait(): Wait 1 second
- done(): Mark task complete

Reply with thinking in <|think_start|>...<|think_end|> tags,
then action in <|action_start|>...<|action_end|> tags.
Once you complete the task, output a done() action.

Example:
<|think_start|>I see a blue button. I'll click it.<|think_end|>
<|action_start|>click(500, 300)<|action_end|>"""

CLICK_TARGET_HTML = """<!DOCTYPE html>
<html>
<head>
    <style>
        body {{ margin: 0; display: flex; justify-content: center; align-items: center; height: 100vh; background: #f0f0f0; }}
        .target {{ position: absolute; left: {x}px; top: {y}px; padding: 12px 24px; background: #3b82f6; color: white; border: none; border-radius: 8px; cursor: pointer; }}
    </style>
</head>
<body>
    <h1>Click the Button</h1>
    <button class="target" onclick="window.__clicked=true;">Click Me</button>
    <script>window.__clicked = false;</script>
</body>
</html>"""

CLICK_TARGET_TASK_PY = '''
import cua_bench as cb
import random

HTML = """{html}"""
pid = None

@cb.tasks_config(split="train")
def get_tasks():
    return [cb.Task(
        description="Click the blue 'Click Me' button",
        computer={{"provider": "simulated", "setup_config": {{"width": 1024, "height": 720}}}},
        metadata={{"x": random.randint(100, 350), "y": random.randint(100, 250)}}
    ) for _ in range(50)]

@cb.setup_task(split="train")
async def setup(task, session):
    global pid
    html = HTML.format(x=task.metadata["x"], y=task.metadata["y"])
    pid = await session.launch_window(html=html, title="Click Target", width=512, height=384)

@cb.evaluate_task(split="train")
async def evaluate(task, session):
    return [1.0 if await session.execute_javascript(pid, "window.__clicked") else 0.0]
'''
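# Note on templating: create_task() below runs
# CLICK_TARGET_TASK_PY.format(html=CLICK_TARGET_HTML), which fills {html} and
# collapses the doubled {{ }} in the task template to literal braces. The
# {x}/{y} placeholders (and the doubled braces in the CSS) survive inside the
# embedded HTML and are only resolved at setup time by HTML.format(x=..., y=...).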


# =============================================================================
# Training Function
# =============================================================================

@app.function(
    image=image,
    gpu="H100",
    timeout=60 * 60 * 4,  # 4 hours
    volumes={"/checkpoints": checkpoint_volume},
    secrets=[modal.Secret.from_name("wandb-secret")],
)
def train_grpo(
    model_id: str = DEFAULT_CONFIG.model_id,
    task_prompt: str = DEFAULT_CONFIG.task_prompt,
    num_workers: int = DEFAULT_CONFIG.num_workers,
    max_steps: int = DEFAULT_CONFIG.max_steps,
    max_history: int = DEFAULT_CONFIG.max_history,
    num_generations: int = DEFAULT_CONFIG.num_generations,
    dataset_size: int = DEFAULT_CONFIG.dataset_size,
    max_completion_length: int = DEFAULT_CONFIG.max_completion_length,
    temperature: float = DEFAULT_CONFIG.temperature,
    learning_rate: float = DEFAULT_CONFIG.learning_rate,
    gradient_accumulation_steps: int = DEFAULT_CONFIG.gradient_accumulation_steps,
    save_steps: int = DEFAULT_CONFIG.save_steps,
    use_wandb: bool = DEFAULT_CONFIG.use_wandb,
    wandb_project: str = DEFAULT_CONFIG.wandb_project,
    debug: bool = DEFAULT_CONFIG.debug,
):
    """Run GRPO training with CUA-Bench environments."""
    import asyncio
    import base64
    import io
    import re
    import tempfile
    from pathlib import Path

    from datasets import Dataset
    from PIL import Image
    from transformers import AutoTokenizer

    from trl import GRPOConfig, GRPOTrainer
    from vllm import SamplingParams

    from cua_bench.workers import CBEnvWorkerClient, cleanup_workers, create_workers

    # -------------------------------------------------------------------------
    # Multimodal rollout generation (based on TRL's generate_rollout_completions)
    # -------------------------------------------------------------------------

    def generate_rollout_completions_multimodal(
        trainer,
        prompts: list[str],
        images: list[list[Image.Image]] | None = None,
    ) -> list[dict]:
        """
        Generate completions for multimodal prompts using vLLM in colocate mode.

        Args:
            trainer: GRPOTrainer instance with vLLM configured
            prompts: List of text prompts
            images: Optional list of image lists, one per prompt

        Returns:
            List of dicts with prompt_ids, completion_ids, logprobs, and text
        """
        if not prompts:
            return []

        if not trainer.use_vllm or trainer.vllm_mode != "colocate":
            raise RuntimeError("Multimodal rollouts require vLLM in colocate mode.")

        # Build sampling params
        sampling_params = SamplingParams(
            n=1,
            temperature=trainer.temperature,
            top_k=-1 if trainer.top_k is None else trainer.top_k,
            min_p=0.0 if trainer.min_p is None else trainer.min_p,
            max_tokens=trainer.max_completion_length,
            logprobs=0,
        )
        if trainer.repetition_penalty is not None:
            sampling_params.repetition_penalty = trainer.repetition_penalty
        if trainer.top_p is not None:
            sampling_params.top_p = trainer.top_p

        # Wake up vLLM if sleep mode is enabled
        if trainer.args.vllm_enable_sleep_mode:
            trainer.llm.wake_up(tags=["kv_cache"])
            trainer.llm.collective_rpc("reload_weights")

        # Build inputs with multimodal data if images provided
        if images:
            inputs = []
            for i, prompt in enumerate(prompts):
                prompt_images = images[i] if i < len(images) else []
                if prompt_images:
                    inputs.append({
                        "prompt": prompt,
                        "multi_modal_data": {"image": prompt_images},
                    })
                else:
                    inputs.append({"prompt": prompt})
            vllm_outputs = trainer.llm.generate(inputs, sampling_params=sampling_params, use_tqdm=False)
        else:
            vllm_outputs = trainer.llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)

        # Process outputs
        results = []
        for request in vllm_outputs:
            if not request.outputs:
                results.append({
                    "prompt_ids": request.prompt_token_ids,
                    "completion_ids": [],
                    "logprobs": [],
                    "text": "",
                })
                continue
            sequence = request.outputs[0]
            logprobs = [
                next(iter(token_logprob.values())).logprob
                for token_logprob in sequence.logprobs
            ] if sequence.logprobs else []
            results.append({
                "prompt_ids": request.prompt_token_ids,
                "completion_ids": list(sequence.token_ids),
                "logprobs": logprobs,
                "text": sequence.text,
            })

        # Sleep vLLM if sleep mode is enabled
        if trainer.args.vllm_enable_sleep_mode:
            trainer.llm.sleep(level=2)

        return results

    # -------------------------------------------------------------------------
    # Helper functions
    # -------------------------------------------------------------------------

    def decode_image(b64_str: str) -> Image.Image:
        img = Image.open(io.BytesIO(base64.b64decode(b64_str)))
        return img.convert("RGB") if img.mode != "RGB" else img

    def extract_images(obs: str) -> tuple[str, list[Image.Image]]:
        """Extract base64 images from obs and replace with vLLM-compatible placeholder."""
        images = []
        pattern = r"<\|vision_start\|>(.*?)<\|vision_end\|>"
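        # Each matched base64 payload is decoded to a PIL image (in document
        # order); the tag body is swapped for Qwen's <|image_pad|> placeholder
        # so vLLM can align images with their text positions.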
        def repl(m):
            images.append(decode_image(m.group(1)))
            # Use Qwen2-VL's placeholder format that vLLM expects
            return "<|vision_start|><|image_pad|><|vision_end|>"
        cleaned = re.sub(pattern, repl, obs, flags=re.DOTALL)
        return cleaned, images

    def make_prompt(
        tok,
        instruction: str,
        obs: str,
        step: int,
        history: list[tuple[str, str]] | None = None,
    ) -> tuple[str, list[Image.Image]]:
        """Build prompt with optional history of previous observations and actions.

        Args:
            tok: Tokenizer
            instruction: Task instruction
            obs: Current observation (may contain base64 images)
            step: Current step number
            history: List of (prev_obs, prev_action) tuples
        """
        all_images = []

        # Build history section
        history_parts = []
        if history:
            for i, (prev_obs, prev_action) in enumerate(history):
                cleaned_prev_obs, prev_images = extract_images(prev_obs)
                all_images.extend(prev_images)
                history_parts.append(f"Step {step - len(history) + i + 1}:\n{cleaned_prev_obs}\nAction: {prev_action}")

        # Current observation
        cleaned_obs, curr_images = extract_images(obs)
        all_images.extend(curr_images)

        # Build user content
        if history_parts:
            history_text = "\n\n".join(history_parts)
            user_content = f"Task: {instruction}\n\nHistory:\n{history_text}\n\nStep {step + 1} (current):\n{cleaned_obs}\n\nWhat action?"
        else:
            user_content = f"Task: {instruction}\n\nStep {step + 1}:\n{cleaned_obs}\n\nWhat action?"

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        return prompt, all_images

    def parse_action(response: str) -> str:
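        # e.g. "<|action_start|>click(500, 300)<|action_end|>" -> "click(500, 300)";
        # falls back to wait() when the model emits no action tag.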
        match = re.search(r"<\|action_start\|>(.*?)<\|action_end\|>", response, re.DOTALL)
        return match.group(1).strip() if match else "wait()"

    # -------------------------------------------------------------------------
    # Episode rollout
    # -------------------------------------------------------------------------

    def rollout_episode(trainer, env, tok, prompt_text: str, max_hist: int = 2) -> dict:
        env_ret, _ = env.reset()

        # Get instruction from env.prompt (set during reset)
        instruction = env.prompt["instruction"] if env.prompt else prompt_text

        all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
        all_images = []
        step_rewards = []
        history = []  # List of (screenshot, action) tuples

        for step in range(max_steps):
            if env_ret.get("done", False):
                break

            # Get current screenshot from env.prompt["steps"][-1]
            # steps contains: [screenshot0, action0, screenshot1, action1, ...]
            current_screenshot = env.prompt["steps"][-1] if env.prompt else ""

            # Get recent history (last max_hist steps)
            recent_history = history[-max_hist:] if history else None

            prompt, images = make_prompt(tok, instruction, current_screenshot, step, history=recent_history)
            if images:
                all_images.extend(images)

            rollout_output = generate_rollout_completions_multimodal(
                trainer, [prompt], images=[images] if images else None
            )[0]
            all_prompt_ids.extend(rollout_output["prompt_ids"])
            all_completion_ids.extend(rollout_output["completion_ids"])
            all_logprobs.extend(rollout_output["logprobs"])

            completion = rollout_output.get("text") or tok.decode(
                rollout_output["completion_ids"], skip_special_tokens=True
            )

            # Add current screenshot and action to history for next step
            action = parse_action(completion)
            history.append((current_screenshot, action))

            if debug:
                print(f"  Step {step + 1}: {action}")

            env_ret, _ = env.step(completion)
            step_rewards.append(float(env_ret.get("reward", 0.0)))

        result = {
            "prompt_ids": all_prompt_ids,
            "completion_ids": all_completion_ids,
            "logprobs": all_logprobs,
            "env_reward": step_rewards[-1] if step_rewards else 0.0,
        }
        if all_images:
            result["images"] = all_images
        return result

    # -------------------------------------------------------------------------
    # Reward function
    # -------------------------------------------------------------------------

    def reward_evaluator_func(completions: list[str], env_reward=None, **_) -> list[float]:
        if env_reward is not None:
            return [float(r) for r in env_reward]
        return [0.0] * len(completions)

    # -------------------------------------------------------------------------
    # Create task
    # -------------------------------------------------------------------------

    def create_task(tmp_dir: Path) -> Path:
        task_dir = tmp_dir / "click-target"
        task_dir.mkdir(exist_ok=True)
        task_code = CLICK_TARGET_TASK_PY.format(html=CLICK_TARGET_HTML)
        (task_dir / "main.py").write_text(task_code)
        return task_dir

    # -------------------------------------------------------------------------
    # Main training logic
    # -------------------------------------------------------------------------

    import os

    # Initialize wandb if available
    wandb_enabled = False
    if use_wandb and os.environ.get("WANDB_API_KEY"):
        try:
            import wandb
            wandb.init(
                project=wandb_project,
                config={
                    "model_id": model_id,
                    "num_workers": num_workers,
                    "max_steps": max_steps,
                    "num_generations": num_generations,
                    "learning_rate": learning_rate,
                    "temperature": temperature,
                    "dataset_size": dataset_size,
                },
            )
            wandb_enabled = True
            print("Wandb logging enabled")
        except ImportError:
            print("Wandb not installed, skipping logging")

    print("=" * 60)
    print("GRPO Training with CUA-Bench on Modal")
    print("=" * 60)
    print(f"Model:       {model_id}")
    print(f"Workers:     {num_workers}")
    print(f"Max steps:   {max_steps}")
    print(f"Generations: {num_generations}")
    print(f"Wandb:       {wandb_enabled}")
    print("=" * 60)

    # Setup tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Create temp task
    tmp_dir = Path(tempfile.mkdtemp(prefix="grpo-modal-"))
    task_path = str(create_task(tmp_dir))
    print(f"Created task: {task_path}")

    # Start workers
    print(f"Starting {num_workers} worker(s)...")
    workers = asyncio.run(create_workers(
        n_workers=num_workers,
        allowed_ips=["127.0.0.1"],
        startup_timeout=120.0,
    ))
    worker_urls = [w.api_url for w in workers]
    print(f"Workers ready: {worker_urls}")

    # Create environment clients
    task_configs = [{"env_path": task_path, "task_index": i % 10, "split": "train"} for i in range(num_workers)]
    envs = [
        CBEnvWorkerClient({
            "server_url": url,
            "task_configs": task_configs,
            "max_step": max_steps,
            "max_hist": 2,
            "timeout": 300,
        })
        for url in worker_urls
    ]

    # Dataset
    dataset = Dataset.from_dict({"prompt": [task_prompt] * dataset_size})

    # GRPO config
    config = GRPOConfig(
        output_dir=DEFAULT_CONFIG.output_dir,
        use_vllm=DEFAULT_CONFIG.use_vllm,
        vllm_mode=DEFAULT_CONFIG.vllm_mode,
        vllm_gpu_memory_utilization=DEFAULT_CONFIG.vllm_gpu_memory_utilization,
        vllm_max_model_length=DEFAULT_CONFIG.vllm_max_model_length,
        num_train_epochs=DEFAULT_CONFIG.num_train_epochs,
        per_device_train_batch_size=DEFAULT_CONFIG.per_device_train_batch_size,
        warmup_steps=DEFAULT_CONFIG.warmup_steps,
        logging_steps=DEFAULT_CONFIG.logging_steps,
        save_strategy=DEFAULT_CONFIG.save_strategy,
        bf16=DEFAULT_CONFIG.bf16,
        # From function args (overridable)
        learning_rate=learning_rate,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_generations=num_generations,
        max_completion_length=max_completion_length,
        temperature=temperature,
        save_steps=save_steps,
        report_to="wandb" if wandb_enabled else "none",
    )

    # Rollout function
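    # A one-element list is used so the nested rollout_func can mutate this
    # round-robin counter without a nonlocal declaration.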
    env_idx = [0]

    def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
        all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
        all_rewards, all_images = [], []

        if debug:
            print(f"\n[Rollout] Processing {len(prompts)} prompts")

        for i, prompt in enumerate(prompts):
            env = envs[env_idx[0] % len(envs)]
            env_idx[0] += 1

            if debug:
                print(f"[Rollout] Episode {i + 1}/{len(prompts)}")

            episode = rollout_episode(trainer, env, tokenizer, prompt, max_hist=max_history)

            all_prompt_ids.append(episode["prompt_ids"])
            all_completion_ids.append(episode["completion_ids"])
            all_logprobs.append(episode["logprobs"])
            all_rewards.append(episode["env_reward"])

            if "images" in episode:
                all_images.append(episode["images"])

        result = {
            "prompt_ids": all_prompt_ids,
            "completion_ids": all_completion_ids,
            "logprobs": all_logprobs,
            "env_reward": all_rewards,
        }
        if all_images:
            result["images"] = all_images
        return result

    # Create trainer
    trainer = GRPOTrainer(
        model=model_id,
        processing_class=tokenizer,
        reward_funcs=[reward_evaluator_func],
        train_dataset=dataset,
        args=config,
        rollout_func=rollout_func,
    )

    try:
        print("\nStarting training...")
        trainer.train()

        # Save final checkpoint
        trainer.save_model("/checkpoints/final")
        checkpoint_volume.commit()

        print("\nTraining complete!")
        return {"status": "complete", "checkpoints": "/checkpoints/final"}

    finally:
        print("Cleaning up workers...")
        asyncio.run(cleanup_workers(workers))
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)


# =============================================================================
# Entrypoint
# =============================================================================

@app.local_entrypoint()
def main(
    model_id: str | None = None,
    num_workers: int | None = None,
    max_steps: int | None = None,
    num_generations: int | None = None,
    dataset_size: int | None = None,
    learning_rate: float | None = None,
    save_steps: int | None = None,
    debug: bool = False,
):
    """Run GRPO training on Modal."""
    kwargs = {
        k: v for k, v in {
            "model_id": model_id,
            "num_workers": num_workers,
            "max_steps": max_steps,
            "num_generations": num_generations,
            "dataset_size": dataset_size,
            "learning_rate": learning_rate,
            "save_steps": save_steps,
            "debug": debug,
        }.items() if v is not None
    }
    result = train_grpo.remote(**kwargs)
    print(f"Training result: {result}")

Step 3: Run Training

modal run modal_grpo_training.py

With custom settings:

modal run modal_grpo_training.py \
    --model-id Qwen/Qwen3-VL-2B-Instruct \
    --num-workers 4 \
    --max-steps 10 \
    --num-generations 4 \
    --learning-rate 5e-6 \
    --debug
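
When the run finishes, the final weights are written to /checkpoints/final on the cua-bench-grpo-checkpoints volume. They can be pulled to your machine with Modal's volume CLI (adjust the command if your Modal CLI version differs):

modal volume get cua-bench-grpo-checkpoints final ./grpo-final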

CLI Options

Option             Default                      Description
--model-id         Qwen/Qwen3-VL-2B-Instruct    Model to train
--num-workers      2                            Number of parallel environment workers
--max-steps        10                           Max steps per episode
--num-generations  4                            Rollouts per prompt
--dataset-size     1000                         Number of training samples
--learning-rate    5e-6                         Learning rate
--save-steps       100                          Checkpoint save interval
--debug            False                        Enable verbose output

Creating Custom Tasks

To train on your own task, create a task directory with main.py:

import cua_bench as cb

pid = None

@cb.tasks_config(split="train")
def get_tasks():
    return [cb.Task(
        description="Your task description",
        computer={"provider": "simulated", "setup_config": {"width": 1024, "height": 720}},
    )]

@cb.setup_task(split="train")
async def setup(task, session):
    global pid
    pid = await session.launch_window(url="https://example.com")

@cb.evaluate_task(split="train")
async def evaluate(task, session):
    # pid identifies the window launched in setup (same pattern as the
    # click-target task above)
    success = await session.execute_javascript(pid, "checkSuccess()")
    return [1.0 if success else 0.0]

Then modify the script to use your task path instead of the built-in click target task.
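
For example, the only wiring change needed in modal_grpo_training.py is to skip the built-in click-target task and point the worker configs at your directory (the path below is illustrative):

# Replace the create_task() call with your own task directory (illustrative
# path); the directory must contain the main.py shown above.
task_path = "/root/my_tasks/my-task"
task_configs = [
    {"env_path": task_path, "task_index": 0, "split": "train"}
    for _ in range(num_workers)
]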
