jax-development
Use this skill when the user is writing, debugging, profiling, refactoring, reviewing, benchmarking, parallelising, exporting, or explaining JAX code, or when they mention JAX, jax.numpy, jit, grad, value_and_grad, vmap, scan, lax, random keys, pytrees, jax.Array, sharding, Mesh, PartitionSpec, NamedSharding, pmap, shard_map, Pallas, XLA, StableHLO, checkify, profiler, or the JAX repo. It helps turn NumPy or PyTorch-style code into pure functional JAX, fix tracer/control-flow/shape/PRNG bugs, remove recompiles and host-device syncs, choose transforms and sharding strategies, inspect jaxpr/lowering/IR, and benchmark compiled code correctly.
What this skill does
# JAX Development

Use this skill for substantial JAX work. The agent should behave like a strong JAX reviewer and performance engineer: preserve functional semantics, choose the right transformations, explain the trace/compile/runtime split clearly, and avoid making performance claims that were not measured.

This version is designed to be unusually agent-friendly. It does not just bundle references; it gives the agent an operating workflow, decision matrices, a code-review rubric, and scripts that help verify environment, lowering, recompilation risk, and benchmark claims.

## Core promise

When this skill is active, the default standard is:

1. produce runnable JAX code, not generic advice
2. explain why the change works in JAX terms
3. call out likely sharp bits even if the user did not ask
4. verify claims with the bundled scripts when possible
5. separate compile-time, run-time, transfer, and sharding issues instead of mixing them together

## When this skill should own the task

Use this skill when the difficult part of the request is any of the following:

- translating NumPy, SciPy, TensorFlow, or PyTorch code into idiomatic JAX
- fixing tracer, control-flow, PRNG, shape, dtype, or side-effect bugs
- choosing between `jit`, `vmap`, `scan`, `fori_loop`, `while_loop`, `cond`, `grad`, `jacrev`, `jacfwd`, `remat`, `shard_map`, or export
- removing recompiles, host-device round trips, Python overhead, or dishonest benchmarking
- reasoning about `jax.Array`, meshes, `PartitionSpec`, `NamedSharding`, explicit sharding, `pmap` migration, multi-host semantics, or collectives
- using `jax.debug.print`, `checkify`, `make_jaxpr`, lowering, compiler IR, profiler traces, or memory profiling
- using custom derivatives, export, AOT lowering, custom partitioning, Pallas, or the JAX source tree

Compose this skill with framework-specific skills when needed, but let this one own the JAX-specific reasoning.
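As a rough illustration of the first two task types in the list above (porting NumPy-style code and fixing control-flow or side-effect bugs), the sketch below ports a small mutating, value-branching function to pure JAX. The function names `clip_negatives_numpy`, `clip_negatives_jax`, and `set_first_jax` are hypothetical and exist only for this example; it is one possible port, not the only idiomatic one.

```python
import jax
import jax.numpy as jnp
import numpy as np


def clip_negatives_numpy(x):
    """NumPy-style original: mutates a copy in place and branches on values."""
    out = x.copy()
    for i in range(out.shape[0]):
        if out[i] < 0:
            out[i] = 0.0
    return out


@jax.jit
def clip_negatives_jax(x):
    """Pure JAX port: no mutation and no Python branch on traced values."""
    # jnp.where selects element-wise, so it traces cleanly under jit.
    return jnp.where(x < 0, 0.0, x)


@jax.jit
def set_first_jax(x, value):
    """Functional replacement for the in-place update `x[0] = value`."""
    # .at[...].set returns a new array; the input x is left untouched.
    return x.at[0].set(value)


x_np = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
x = jnp.asarray(x_np)

ported = clip_negatives_jax(x)      # same values as the NumPy version
updated = set_first_jax(x, 10.0)    # x itself still starts with -1.0
```

A Python `if out[i] < 0` inside a jitted function would fail on a traced value, which is why the port expresses the branch as an array-level selection.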
## Do not over-apply the skill

Do not force JAX when the real problem is one of these instead:

- pure NumPy optimisation where JAX is explicitly out of scope
- generic CUDA, Triton, NCCL, or driver debugging with no meaningful JAX component
- framework-only design questions whose hard part is not JAX
- irregular dynamic object-heavy Python where the right answer is probably to keep the hot path outside JAX

When in doubt, ask: “Is the root of the problem tracing, transformations, array semantics, compilation, sharding, or the JAX runtime?” If yes, use this skill.

## First-response workflow

### 1. Classify the task

Put the request into one or more lanes immediately:

- code design or porting
- debugging or correctness
- performance or compilation
- sharding or distributed execution
- advanced extension points
- JAX repo navigation or source-level questions

Then open the matching reference file:

- `references/EXPERT-WORKFLOW.md` for the overall workflow
- `references/MENTAL-MODEL.md` for tracing and staging semantics
- `references/TRANSFORM-DECISION-MATRIX.md` for choosing primitives
- `references/PORTING-PATTERNS.md` for NumPy or PyTorch rewrites
- `references/CODE-REVIEW-RUBRIC.md` for self-review before replying
- `references/DEBUGGING-TRIAGE.md` for error diagnosis
- `references/PERFORMANCE-PLAYBOOK.md` for speed, memory, and compile-time work
- `references/SHARDING-PLAYBOOK.md` for distributed and multi-device design
- `references/ADVANCED-EXTENSIONS.md` for custom autodiff, export, Pallas, FFI, and internals
- `references/REPO-MAP.md` for local source-tree navigation
- `references/SOURCES.md` for provenance and maintenance notes

### 2. Inspect before guessing

If the problem could be environment-, backend-, or project-specific, inspect first.

Environment:

```bash
python3 scripts/jax_env_report.py --format json
```

Static project scan:

```bash
python3 scripts/jax_project_scan.py PATH --format json
```

Benchmark a callable honestly:

```bash
python3 scripts/jax_benchmark_harness.py --help
```

Inspect jaxpr, lowering, and IR:

```bash
python3 scripts/jax_compile_probe.py --help
```

Check likely recompile behaviour across cases:

```bash
python3 scripts/jax_recompile_explorer.py --help
```

Search a local JAX checkout:

```bash
python3 scripts/jax_repo_locator.py --help
```

### 3. Reduce to a minimal reproducer

Prefer the smallest function that still exhibits the behaviour. JAX problems get much easier once shapes, dtypes, batching axes, randomness, and transformation boundaries are explicit.

### 4. Choose the least powerful mechanism that solves the problem

Default ordering:

- pure eager `jax.numpy` first
- then `jit` or `value_and_grad`
- then `vmap` or `scan`
- then explicit sharding
- then `shard_map`
- then custom derivative, export, custom partitioning, or Pallas
- then FFI or JAX internals

Escalate only with evidence.

### 5. End with a high-signal answer

Unless the user asked for something else, the reply should end with:

- diagnosis or design choice
- corrected code or patch
- why it works in JAX terms
- how to verify it
- remaining risks, backend caveats, or performance unknowns

## Expert operating rules

1. **Treat JAX functions as pure.** Inputs in, outputs out. Hidden mutation, global state, or implicit randomness are usually design bugs once transforms enter the picture.
2. **Make randomness explicit.** Thread keys through the program, split once per consumer, and return updated keys when state continues.
3. **Keep the hot path in JAX space.** Host conversion inside transformed code is almost always a bug or a sync point.
4. **Separate static and dynamic values.** Shapes, dtypes, Python objects, and some configuration values influence tracing and compilation.
5. **Use structured control flow.** If a branch or loop depends on array values, use JAX control-flow primitives instead of Python.
6. **Benchmark honestly.** Warm up, block, and distinguish transfer cost, compile cost, and steady-state execution.
7. **Optimise after evidence.** Use scans, compile probes, profiler traces, or lowering inspection before proposing deep rewrites.
8. **Prefer current JAX idioms.** Typed keys, `jax.Array`, and modern sharding APIs are the default unless the codebase is intentionally legacy.
9. **Think globally for sharding first.** Start with global-view code and explicit placement before dropping to per-device manual code.
10. **Never bluff backend-specific behaviour.** CPU, GPU, TPU, and multi-host runs differ materially. Say what was verified and what was inferred.

## Default red flags to proactively check

Always scan for these, even if the user did not mention them:

- `np.asarray`, `.item()`, `.tolist()`, `jax.device_get`, or printing arrays in a hot path
- Python `if`, `for`, or `while` inside transformed code
- shape construction or indexing based on traced values
- global or reused PRNG keys
- repeated creation of jitted callables inside loops
- changing shapes, dtypes, or static arguments causing compile storms
- very large Python loops that should be `scan` or `fori_loop`
- `pmap` code that may be better expressed with modern sharding APIs
- unexplained precision assumptions or implicit `x64` expectations
- replicated-versus-sharded confusion in distributed code

## Available scripts

- `scripts/jax_env_report.py`: report versions, backend, devices, config, env vars, and an optional smoke test.
- `scripts/jax_project_scan.py`: AST-based scan for common JAX sharp bits and migration targets.
- `scripts/jax_benchmark_harness.py`: benchmark a callable with warm-up, blocking, optional `jit`, and optional donation.
- `scripts/jax_compile_probe.py`: inspect `eval_shape`, jaxpr, lowering, and compiler IR; optionally write artefacts to disk.
- `scripts/jax_recompile_explorer.py`: run several input cases through a jitted function and flag likely recompiles or signature drift.
- `scripts/jax_repo_locator.py`: search a local JAX checkout for relevant docs, tests, or source files by topic.

All scripts are non-interactive, support `--h
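
Three of the expert operating rules stated earlier (explicit randomness, structured control flow, and honest benchmarking) can be sketched by hand as follows. This is a minimal illustration under assumptions, not a stand-in for the bundled scripts: the helper names `noisy_step`, `damp_if_large`, and `time_steady_state` are hypothetical, and `jax.random.key` assumes a JAX release recent enough to provide typed keys.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def noisy_step(key, x):
    """Explicit randomness: split the key, use the subkey, return the new key."""
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, x.shape)
    return key, x + 0.1 * noise


@jax.jit
def damp_if_large(x):
    """Structured control flow: branch on an array value with lax.cond."""
    # A Python `if jnp.sum(x) > 100.0:` would fail under jit on a traced value.
    return lax.cond(jnp.sum(x) > 100.0, lambda v: v * 0.5, lambda v: v, x)


def time_steady_state(fn, *args, repeats=10):
    """Honest timing: warm up first, then block on every result."""
    jax.block_until_ready(fn(*args))  # warm-up call absorbs the compile cost
    start = time.perf_counter()
    for _ in range(repeats):
        # Blocking keeps async dispatch from being mistaken for speed.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


key = jax.random.key(0)  # typed key (assumed available in the installed JAX)
x = jnp.ones((8,))

key, y1 = noisy_step(key, x)  # each call consumes a fresh subkey
key, y2 = noisy_step(key, x)

damped = damp_if_large(jnp.full((8,), 100.0))  # sum is 800, so values are halved
seconds_per_call = time_steady_state(damp_if_large, x)
```

The timing helper reports steady-state execution only; compile time and host-to-device transfer would need to be measured separately.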