
jax 0.9.1

  • PyPI https://pypi.org/project/jax/
    Repository https://github.com/jax-ml/jax

  • py.typed

  • Coverage


    • 67.6% coverage
    • 65.3% coverage (strict)
    • 12091 typable
      • 7898 typed
      • 3915 untyped
      • 278 Any
  • Typables


    Typables by kind (pie chart): functions 7905, classes 3365, other 494
    • 1867 functions (+78 overloads)
      • 6038 parameters
    • 343 classes
      • 1264 methods (+0 overloads)
        • 1839 parameters
      • 262 properties
    • 280 modules
      • 494 attrs
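
The two coverage percentages in the summary above can be reproduced from the raw counts. The sketch below assumes that plain coverage counts `Any` annotations as annotated while strict coverage does not; that reading is inferred from the numbers, not stated by the report.

```python
# Counts copied from the coverage summary above.
typed, any_, untyped = 7898, 278, 3915
typable = typed + any_ + untyped

# Assumption: plain coverage treats `Any` as annotated,
# strict coverage counts only non-`Any` annotations.
coverage = 100 * (typed + any_) / typable
strict = 100 * typed / typable

print(f"{typable} typable")
print(f"{coverage:.1f}% coverage")
print(f"{strict:.1f}% coverage (strict)")
```

Under that assumption the script reproduces the reported figures: 12091 typable, 67.6% coverage and 65.3% strict coverage.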

Modules

Module Coverage Coverage (strict) Typables Ignores
jax 0.0% 0.0% 0 0
jax._src.abstract_arrays 100.0% 100.0% 1 0
jax._src.ad_checkpoint 54.5% 54.5% 22 2
jax._src.ad_util 56.5% 52.2% 23 1
jax._src.api 83.1% 77.4% 124 2
jax._src.api_util 56.4% 56.4% 39 0
jax._src.array 66.7% 66.7% 24 1
jax._src.basearray 96.6% 93.1% 319 7
jax._src.buffer_callback 60.0% 60.0% 10 0
jax._src.callback 88.5% 65.4% 26 0
jax._src.checkify 41.5% 41.5% 41 1
jax._src.compilation_cache 66.7% 66.7% 3 0
jax._src.compiler 91.7% 91.7% 12 2
jax._src.compute_on 50.0% 50.0% 2 0
jax._src.config 5.9% 5.9% 34 1
jax._src.core 26.5% 25.2% 524 18
jax._src.custom_batching 45.5% 45.5% 11 0
jax._src.custom_dce 58.3% 58.3% 12 0
jax._src.custom_derivatives 61.1% 53.7% 54 0
jax._src.custom_partitioning 4.2% 0.0% 24 1
jax._src.custom_partitioning_sharding_rule 56.5% 56.5% 23 0
jax._src.custom_transpose 40.0% 40.0% 10 0
jax._src.debugger.core 57.1% 57.1% 7 0
jax._src.debugging 69.0% 62.1% 29 1
jax._src.dispatch 0.0% 0.0% 5 0
jax._src.distributed 85.7% 85.7% 14 0
jax._src.dlpack 66.7% 66.7% 6 0
jax._src.dtypes 88.2% 70.6% 17 1
jax._src.earray 6.7% 6.7% 30 13
jax._src.effects 100.0% 100.0% 3 0
jax._src.environment_info 100.0% 100.0% 2 0
jax._src.errors 53.8% 53.8% 13 0
jax._src.export._export 81.0% 81.0% 42 0
jax._src.export.shape_poly 58.1% 58.1% 31 3
jax._src.extend.random 100.0% 100.0% 8 0
jax._src.ffi 95.5% 84.1% 44 0
jax._src.flatten_util 100.0% 50.0% 2 0
jax._src.hijax 36.1% 25.4% 122 23
jax._src.image.scale 47.1% 47.1% 17 0
jax._src.indexing 57.1% 57.1% 14 0
jax._src.interpreters.ad 6.1% 5.1% 98 0
jax._src.interpreters.batching 32.1% 32.1% 28 3
jax._src.interpreters.mlir 83.5% 80.6% 170 2
jax._src.interpreters.partial_eval 67.1% 67.1% 70 7
jax._src.interpreters.pxla 66.7% 66.7% 3 3
jax._src.lax.ann 93.3% 93.3% 15 0
jax._src.lax.control_flow.conditionals 50.0% 31.2% 16 1
jax._src.lax.control_flow.loops 68.5% 68.5% 54 4
jax._src.lax.control_flow.solves 57.1% 42.9% 14 0
jax._src.lax.convolution 63.6% 63.6% 66 0
jax._src.lax.fft 40.0% 40.0% 5 0
jax._src.lax.lax 70.3% 69.7% 495 3
jax._src.lax.linalg 80.5% 80.5% 82 0
jax._src.lax.other 100.0% 100.0% 20 0
jax._src.lax.parallel 6.4% 6.4% 94 2
jax._src.lax.scaled_dot 85.7% 85.7% 7 0
jax._src.lax.slicing 91.7% 91.7% 132 0
jax._src.lax.special 72.0% 72.0% 50 0
jax._src.lax.windowed_reductions 39.1% 26.1% 23 0
jax._src.layout 37.9% 37.9% 29 0
jax._src.linear_util 62.3% 62.3% 77 0
jax._src.memory 0.0% 0.0% 1 0
jax._src.mesh 13.8% 13.8% 87 1
jax._src.mesh_utils 100.0% 100.0% 13 0
jax._src.monitoring 97.0% 97.0% 33 0
jax._src.named_sharding 71.8% 71.8% 39 0
jax._src.nn.initializers 100.0% 88.6% 70 0
jax._src.numpy.fft 100.0% 100.0% 86 0
jax._src.numpy.linalg 96.5% 96.5% 115 0
jax._src.ops.scatter 100.0% 100.0% 32 0
jax._src.ops.special 100.0% 100.0% 7 0
jax._src.pallas.core 54.7% 50.9% 53 1
jax._src.pallas.cost_estimate 25.0% 25.0% 4 0
jax._src.pallas.einshape 100.0% 100.0% 5 1
jax._src.pallas.fuser.block_spec 46.7% 46.7% 15 0
jax._src.pallas.fuser.custom_evaluate 33.3% 33.3% 3 0
jax._src.pallas.fuser.custom_fusion_lib 50.0% 50.0% 18 3
jax._src.pallas.fuser.fusible 33.3% 0.0% 3 1
jax._src.pallas.fuser.fusion 37.5% 37.5% 8 0
jax._src.pallas.fuser.jaxpr_fusion 50.0% 50.0% 4 0
jax._src.pallas.helpers 77.4% 71.0% 31 0
jax._src.pallas.mosaic.core 75.0% 75.0% 40 0
jax._src.pallas.mosaic.helpers 44.4% 44.4% 9 0
jax._src.pallas.mosaic.interpret.interpret_pallas_call 50.0% 50.0% 6 0
jax._src.pallas.mosaic.lowering 0.0% 0.0% 0 0
jax._src.pallas.mosaic.pipeline 35.5% 34.9% 152 0
jax._src.pallas.mosaic.primitives 60.8% 59.5% 79 2
jax._src.pallas.mosaic.random 64.3% 64.3% 14 4
jax._src.pallas.mosaic.sc_core 69.0% 69.0% 29 0
jax._src.pallas.mosaic.sc_primitives 98.6% 98.6% 71 5
jax._src.pallas.mosaic.tpu_info 100.0% 100.0% 11 0
jax._src.pallas.mosaic_gpu.core 76.4% 72.6% 106 3
jax._src.pallas.mosaic_gpu.helpers 78.9% 78.9% 19 0
jax._src.pallas.mosaic_gpu.pipeline 95.0% 95.0% 20 0
jax._src.pallas.mosaic_gpu.primitives 77.4% 77.4% 106 0
jax._src.pallas.mosaic_gpu.torch 0.0% 0.0% 2 0
jax._src.pallas.pallas_call 93.8% 81.2% 16 7
jax._src.pallas.primitives 42.2% 38.6% 83 1
jax._src.pallas.triton.core 0.0% 0.0% 0 0
jax._src.pallas.triton.primitives 100.0% 100.0% 21 0
jax._src.pallas.utils 100.0% 100.0% 7 0
jax._src.partition_spec 6.7% 6.7% 30 0
jax._src.pjit 10.5% 10.5% 19 5
jax._src.prng 35.3% 35.3% 17 9
jax._src.profiler 61.5% 61.5% 26 0
jax._src.public_test_util 0.0% 0.0% 24 0
jax._src.random 94.0% 94.0% 233 0
jax._src.ref 100.0% 33.3% 3 0
jax._src.scipy.cluster.vq 100.0% 100.0% 4 0
jax._src.scipy.fft 100.0% 100.0% 24 0
jax._src.scipy.integrate 100.0% 100.0% 5 0
jax._src.scipy.linalg 100.0% 98.5% 134 0
jax._src.scipy.ndimage 83.3% 83.3% 6 0
jax._src.scipy.optimize.minimize 100.0% 100.0% 7 0
jax._src.scipy.signal 100.0% 100.0% 85 0
jax._src.scipy.sparse.linalg 0.0% 0.0% 26 0
jax._src.scipy.spatial.transform 63.3% 63.3% 49 0
jax._src.scipy.special 100.0% 100.0% 117 0
jax._src.scipy.stats._core 100.0% 100.0% 16 1
jax._src.scipy.stats.bernoulli 100.0% 100.0% 14 0
jax._src.scipy.stats.beta 100.0% 100.0% 36 0
jax._src.scipy.stats.betabinom 100.0% 100.0% 12 0
jax._src.scipy.stats.binom 100.0% 100.0% 10 0
jax._src.scipy.stats.cauchy 100.0% 100.0% 32 0
jax._src.scipy.stats.chi2 100.0% 100.0% 30 0
jax._src.scipy.stats.dirichlet 100.0% 100.0% 6 0
jax._src.scipy.stats.expon 100.0% 100.0% 28 0
jax._src.scipy.stats.gamma 100.0% 100.0% 30 0
jax._src.scipy.stats.gennorm 100.0% 100.0% 9 0
jax._src.scipy.stats.geom 100.0% 100.0% 8 0
jax._src.scipy.stats.gumbel_l 100.0% 100.0% 28 0
jax._src.scipy.stats.gumbel_r 100.0% 100.0% 28 0
jax._src.scipy.stats.kde 2.5% 2.5% 40 0
jax._src.scipy.stats.laplace 100.0% 100.0% 12 0
jax._src.scipy.stats.logistic 100.0% 100.0% 24 0
jax._src.scipy.stats.multinomial 100.0% 100.0% 8 0
jax._src.scipy.stats.multivariate_normal 100.0% 100.0% 9 0
jax._src.scipy.stats.nbinom 100.0% 100.0% 10 0
jax._src.scipy.stats.norm 100.0% 100.0% 32 0
jax._src.scipy.stats.pareto 100.0% 100.0% 35 0
jax._src.scipy.stats.poisson 100.0% 100.0% 15 0
jax._src.scipy.stats.t 100.0% 100.0% 10 0
jax._src.scipy.stats.truncnorm 0.0% 0.0% 36 0
jax._src.scipy.stats.uniform 100.0% 100.0% 16 0
jax._src.scipy.stats.vonmises 100.0% 100.0% 6 0
jax._src.scipy.stats.wrapcauchy 100.0% 100.0% 6 0
jax._src.shard_alike 0.0% 0.0% 3 0
jax._src.shard_map 100.0% 83.3% 12 0
jax._src.sharding 96.2% 96.2% 26 0
jax._src.sharding_impls 65.4% 65.4% 78 1
jax._src.source_info_util 87.1% 87.1% 31 0
jax._src.sourcemap 50.0% 50.0% 8 0
jax._src.stages 39.7% 37.2% 78 0
jax._src.state.discharge 100.0% 100.0% 2 2
jax._src.state.primitives 100.0% 100.0% 19 7
jax._src.state.types 16.5% 14.7% 109 5
jax._src.stateful_rng 92.9% 92.9% 28 0
jax._src.third_party.scipy.interpolate 0.0% 0.0% 22 0
jax._src.third_party.scipy.linalg 100.0% 100.0% 4 0
jax._src.third_party.scipy.special 100.0% 100.0% 2 0
jax._src.tpu.linalg.qdwh 66.7% 66.7% 6 0
jax._src.tpu_custom_call 100.0% 95.7% 47 1
jax._src.tree 96.3% 59.3% 54 0
jax._src.tree_util 93.7% 68.4% 95 3
jax._src.typing 100.0% 100.0% 1 0
jax._src.util 85.7% 85.7% 7 2
jax._src.xla_bridge 97.2% 97.2% 36 0
jax._src.xla_metadata 0.0% 0.0% 3 0
jax.collect_profile 46.2% 46.2% 13 0
jax.core 0.0% 0.0% 0 0
jax.example_libraries.optimizers 20.5% 17.8% 73 0
jax.example_libraries.stax 0.0% 0.0% 63 0
jax.experimental.array_serialization.pytree_serialization 100.0% 72.7% 22 0
jax.experimental.array_serialization.pytree_serialization_utils 50.0% 50.0% 4 0
jax.experimental.array_serialization.serialization 51.1% 51.1% 47 0
jax.experimental.array_serialization.tensorstore_impl 69.8% 67.4% 43 0
jax.experimental.colocated_python.api 83.3% 83.3% 6 0
jax.experimental.colocated_python.func 76.5% 76.5% 17 0
jax.experimental.colocated_python.func_backend 0.0% 0.0% 1 0
jax.experimental.colocated_python.obj 75.0% 75.0% 4 0
jax.experimental.colocated_python.obj_backend 0.0% 0.0% 1 0
jax.experimental.colocated_python.serialization 0.0% 0.0% 0 1
jax.experimental.fused 0.0% 0.0% 3 0
jax.experimental.jax2tf.call_tf 47.8% 43.5% 23 0
jax.experimental.jax2tf.jax2tf 61.0% 48.8% 41 0
jax.experimental.jet 1.1% 1.1% 93 0
jax.experimental.mosaic.dialects 0.0% 0.0% 0 0
jax.experimental.mosaic.gpu 0.0% 0.0% 0 0
jax.experimental.mosaic.gpu.constraints 78.0% 78.0% 59 0
jax.experimental.mosaic.gpu.core 46.2% 46.2% 39 0
jax.experimental.mosaic.gpu.dialect_lowering 90.9% 90.9% 33 0
jax.experimental.mosaic.gpu.fragmented_array 64.7% 63.8% 340 0
jax.experimental.mosaic.gpu.inference_utils 96.4% 96.4% 55 0
jax.experimental.mosaic.gpu.launch_context 82.6% 80.4% 138 0
jax.experimental.mosaic.gpu.layout_inference 93.1% 93.1% 58 0
jax.experimental.mosaic.gpu.layouts 100.0% 100.0% 35 0
jax.experimental.mosaic.gpu.mma 55.6% 55.6% 9 0
jax.experimental.mosaic.gpu.mma_utils 73.7% 73.7% 19 0
jax.experimental.mosaic.gpu.profiler 63.3% 63.3% 49 1
jax.experimental.mosaic.gpu.tcgen05 88.4% 88.4% 95 0
jax.experimental.mosaic.gpu.utils 58.9% 58.9% 275 1
jax.experimental.mosaic.gpu.wgmma 58.8% 58.8% 34 0
jax.experimental.multihost_utils 46.8% 21.3% 47 0
jax.experimental.ode 3.4% 3.4% 59 0
jax.experimental.pallas 0.0% 0.0% 0 0
jax.experimental.pallas.ops.gpu.all_gather_mgpu 100.0% 100.0% 7 0
jax.experimental.pallas.ops.gpu.attention 54.2% 50.8% 59 0
jax.experimental.pallas.ops.gpu.attention_mgpu 18.2% 18.2% 22 0
jax.experimental.pallas.ops.gpu.blackwell_matmul_mgpu 33.3% 33.3% 6 0
jax.experimental.pallas.ops.gpu.blackwell_ragged_dot_mgpu 33.3% 33.3% 36 0
jax.experimental.pallas.ops.gpu.collective_matmul_mgpu 85.7% 85.7% 7 0
jax.experimental.pallas.ops.gpu.decode_attention 44.3% 42.0% 88 0
jax.experimental.pallas.ops.gpu.hopper_matmul_mgpu 20.0% 20.0% 15 0
jax.experimental.pallas.ops.gpu.hopper_mixed_type_matmul_mgpu 76.9% 76.9% 13 0
jax.experimental.pallas.ops.gpu.layer_norm 37.1% 37.1% 62 0
jax.experimental.pallas.ops.gpu.paged_attention 83.9% 80.4% 56 0
jax.experimental.pallas.ops.gpu.ragged_dot_mgpu 47.1% 47.1% 17 0
jax.experimental.pallas.ops.gpu.reduce_scatter_mgpu 87.5% 87.5% 8 0
jax.experimental.pallas.ops.gpu.rms_norm 39.0% 39.0% 59 0
jax.experimental.pallas.ops.gpu.softmax 100.0% 100.0% 6 0
jax.experimental.pallas.ops.gpu.transposed_ragged_dot_mgpu 40.0% 40.0% 15 0
jax.experimental.pallas.ops.tpu.all_gather 58.8% 58.8% 17 0
jax.experimental.pallas.ops.tpu.example_kernel 0.0% 0.0% 5 0
jax.experimental.pallas.ops.tpu.flash_attention 32.1% 32.1% 56 0
jax.experimental.pallas.ops.tpu.matmul 50.0% 50.0% 12 0
jax.experimental.pallas.ops.tpu.megablox.common 100.0% 100.0% 9 0
jax.experimental.pallas.ops.tpu.megablox.gmm 100.0% 96.4% 28 0
jax.experimental.pallas.ops.tpu.megablox.ops 0.0% 0.0% 1 0
jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel 28.9% 28.9% 83 0
jax.experimental.pallas.ops.tpu.paged_attention.quantization_utils 100.0% 100.0% 14 0
jax.experimental.pallas.ops.tpu.paged_attention.util 87.5% 87.5% 8 0
jax.experimental.pallas.ops.tpu.ragged_paged_attention.kernel 64.9% 64.9% 97 0
jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes 21.1% 21.1% 19 0
jax.experimental.pallas.ops.tpu.random.philox 42.9% 42.9% 35 0
jax.experimental.pallas.ops.tpu.random.prng_utils 62.5% 62.5% 8 0
jax.experimental.pallas.ops.tpu.random.threefry 54.5% 54.5% 11 0
jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel 55.9% 54.8% 93 0
jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask 66.3% 66.3% 83 0
jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info 0.0% 0.0% 4 0
jax.experimental.profiler 100.0% 100.0% 2 0
jax.experimental.rnn 70.3% 69.3% 101 1
jax.experimental.roofline.roofline 94.1% 88.2% 34 0
jax.experimental.scheduling_groups 27.3% 27.3% 11 0
jax.experimental.serialize_executable 37.5% 37.5% 8 0
jax.experimental.shard_map 0.0% 0.0% 6 0
jax.experimental.source_mapper.common 30.8% 30.8% 13 0
jax.experimental.source_mapper.generate_map 50.0% 50.0% 4 0
jax.experimental.source_mapper.hlo 47.4% 47.4% 19 6
jax.experimental.source_mapper.jaxpr 27.3% 27.3% 11 2
jax.experimental.source_mapper.mlir 50.0% 50.0% 8 0
jax.experimental.sparse 0.0% 0.0% 0 0
jax.experimental.sparse._base 19.0% 19.0% 42 1
jax.experimental.sparse.ad 64.0% 64.0% 25 0
jax.experimental.sparse.api 82.4% 82.4% 17 0
jax.experimental.sparse.bcoo 73.6% 73.6% 174 0
jax.experimental.sparse.bcsr 52.8% 52.8% 89 0
jax.experimental.sparse.coo 78.9% 78.9% 57 0
jax.experimental.sparse.csr 26.5% 26.5% 83 0
jax.experimental.sparse.linalg 30.8% 30.8% 13 0
jax.experimental.sparse.nm 69.2% 69.2% 13 0
jax.experimental.sparse.random 0.0% 0.0% 12 0
jax.experimental.sparse.test_util 23.9% 23.9% 67 0
jax.experimental.sparse.transform 31.5% 20.2% 89 0
jax.experimental.sparse.util 0.0% 0.0% 9 0
jax.experimental.topologies 78.6% 78.6% 14 0
jax.experimental.transfer 52.6% 31.6% 19 0
jax.extend.backend 0.0% 0.0% 0 1
jax.extend.linear_util 50.0% 50.0% 4 0
jax.extend.sharding 100.0% 100.0% 6 0
jax.extend.source_info_util 0.0% 0.0% 0 0
jax.interpreters.mlir 0.0% 0.0% 3 0
jax.interpreters.partial_eval 0.0% 0.0% 0 0
jax.interpreters.xla 0.0% 0.0% 1 0
jax.nn 100.0% 98.3% 117 0
jax.nn.initializers 0.0% 0.0% 0 0
jax.numpy 98.5% 94.2% 1421 0
jax.profiler 0.0% 0.0% 3 0
jax.sharding 0.0% 0.0% 0 0
jax.tree_util 0.0% 0.0% 1 0
jax.version 0.0% 0.0% 2 0

Incomplete Annotations

jax._src.ad_checkpoint (10 missing, 0 any)

Symbol Typable Typed Any
attr checkpoint_policies 1 0 0
func checkpoint_name 3 0 0
func print_saved_residuals 4 0 0
attr remat_p 1 0 0
attr name_p 1 0 0

jax._src.ad_util (10 missing, 1 any)

Symbol Typable Typed Any
func zero_from_primal 3 0 0
meth SymbolicZero.__getattr__ 2 0 0
meth SymbolicZero.from_primal_value 2 2 1
attr SymbolicZero.aval 1 0 0
meth Zero.__init__ 2 1 0
meth Zero.instantiate 1 0 0
attr Zero.aval 1 0 0
attr add_jaxvals_p 1 0 0
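
As a reading aid, the per-module header ("10 missing, 1 any") and the module's row in the Modules table can be recomputed from the symbol rows. This is a sketch under two assumptions inferred from the numbers: the Typed column includes `Any` annotations, and symbols not listed in the table are fully annotated without `Any`.

```python
# (typable, typed, any) triples copied from the jax._src.ad_util rows above, in table order.
rows = [(3, 0, 0), (2, 0, 0), (2, 2, 1), (1, 0, 0),
        (2, 1, 0), (1, 0, 0), (1, 0, 0), (1, 0, 0)]
module_typables = 23  # Typables column for jax._src.ad_util in the Modules table

# "missing" = typables without any annotation; "any" = typables annotated as Any.
missing = sum(typable - typed for typable, typed, _ in rows)
any_count = sum(a for _, _, a in rows)

# Assumption: unlisted symbols are fully typed, so only listed rows lower coverage.
coverage = 100 * (module_typables - missing) / module_typables
strict = 100 * (module_typables - missing - any_count) / module_typables

print(f"{missing} missing, {any_count} any")
print(f"{coverage:.1f}% coverage, {strict:.1f}% strict")
```

The result matches the header (10 missing, 1 any) and the module's 56.5% / 52.2% entry in the Modules table.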

jax._src.api (21 missing, 7 any)

Symbol Typable Typed Any
func vjp 5 5 1
func clear_caches 1 0 0
func vmap 8 8 1
func jit 13 13 2
func disable_jit 2 1 0
func eval_shape 4 1 0
func effects_barrier 1 0 0
func pmap 11 11 2
func device_put 6 4 0
func device_get 2 1 1
func jvp 5 3 0
func linearize 4 3 0
func linear_transpose 4 2 0
func block_until_ready 2 0 0
func copy_to_host_async 2 0 0
func live_arrays 2 0 0
func clear_backends 1 0 0

jax._src.api_util (17 missing, 0 any)

Symbol Typable Typed Any
func flatten_axes 6 0 0
func argnums_partial 5 3 0
func flatten_fun 5 3 0
func donation_vector 5 2 0
func rebase_donate_argnums 3 1 0
func flatten_fun_nokwargs 5 3 0

jax._src.array (8 missing, 0 any)

Symbol Typable Typed Any
func make_array_from_process_local_data 4 0 0
meth Shard.__init__ 5 4 0
meth Shard.__repr__ 1 0 0
prop Shard.device 1 0 0
prop Shard.data 1 0 0

jax._src.basearray (11 missing, 11 any)

Symbol Typable Typed Any
meth Array.__init__ 7 0 0
meth Array.__getitem__ 2 1 0
meth Array.__setitem__ 3 1 0
meth Array.__iter__ 1 1 1
meth Array.__reversed__ 1 1 1
meth Array.__round__ 2 1 0
meth Array.dot 4 4 1
meth Array.item 2 2 1
meth Array.reshape 4 4 1
meth Array.transpose 2 2 1
meth Array.__dlpack__ 1 1 1
meth Array._rewrap_with_aval_and_sharding 3 3 2
prop Array.traceback 1 1 1
attr Array.aval 1 1 1

jax._src.buffer_callback (4 missing, 0 any)

Symbol Typable Typed Any
attr ExecutionContext 1 0 0
attr ExecutionStage 1 0 0
func buffer_callback 7 6 0
attr Buffer 1 0 0

jax._src.callback (3 missing, 6 any)

Symbol Typable Typed Any
func pure_callback 8 7 3
func io_callback 7 6 3
func emit_python_callback 11 10 0

jax._src.checkify (24 missing, 0 any)

Symbol Typable Typed Any
meth Error.throw 1 0 0
meth Error.__str__ 1 0 0
meth Error._update 6 1 0
meth Error._add_placeholder_effects 2 1 0
meth Error._replace 3 0 0
meth Error.tree_flatten 1 0 0
meth Error.tree_unflatten 3 0 0
attr user_checks 1 0 0
func check 6 4 0
attr init_error 1 0 0
func debug_check 5 3 0
attr index_checks 1 0 0
attr nan_checks 1 0 0
attr div_checks 1 0 0

jax._src.compilation_cache (1 missing, 0 any)

Symbol Typable Typed Any
func set_cache_dir 2 1 0

jax._src.compiler (1 missing, 0 any)

Symbol Typable Typed Any
func get_compile_options 12 11 0

jax._src.compute_on (1 missing, 0 any)

Symbol Typable Typed Any
func compute_on 2 1 0

jax._src.config (32 missing, 0 any)

Symbol Typable Typed Any
attr check_tracer_leaks 1 0 0
attr threefry_partitionable 1 0 0
attr checking_leaks 1 0 0
attr debug_infs 1 0 0
attr array_garbage_collection_guard 1 0 0
attr numpy_dtype_promotion 1 0 0
attr transfer_guard_device_to_device 1 0 0
attr remove_size_one_mesh_axis_from_type 1 0 0
attr enable_custom_prng 1 0 0
attr thread_guard 1 0 0
attr numpy_rank_promotion 1 0 0
attr no_tracing 1 0 0
attr no_execution 1 0 0
func make_user_context 2 0 0
attr softmax_custom_jvp 1 0 0
attr default_matmul_precision 1 0 0
attr transfer_guard_device_to_host 1 0 0
attr default_prng_impl 1 0 0
attr legacy_prng_key 1 0 0
attr default_device 1 0 0
attr debug_nans 1 0 0
attr debug_key_reuse 1 0 0
attr jax2tf_associative_scan_reductions 1 0 0
attr allow_f16_reductions 1 0 0
attr transfer_guard_host_to_device 1 0 0
attr config 1 0 0
attr enable_x64 1 0 0
attr enable_checks 1 0 0
attr explain_cache_misses 1 0 0
attr log_compiles 1 0 0
attr enable_custom_vjp_by_custom_transpose 1 0 0

jax._src.core (385 missing, 7 any)

Symbol Typable Typed Any
func empty_ref 3 0 0
meth ShapeDtypeStruct.__init__ 7 0 0
meth ShapeDtypeStruct.__len__ 1 0 0
meth ShapeDtypeStruct.__repr__ 1 0 0
meth ShapeDtypeStruct.__str__ 1 0 0
meth ShapeDtypeStruct.__eq__ 2 0 0
meth ShapeDtypeStruct.__hash__ 1 0 0
meth ShapeDtypeStruct.__setattr__ 3 0 0
meth ShapeDtypeStruct.update 2 0 0
prop ShapeDtypeStruct.sharding 1 0 0
prop ShapeDtypeStruct.format 1 0 0
attr ShapeDtypeStruct.shape 1 0 0
attr ShapeDtypeStruct.dtype 1 0 0
attr ShapeDtypeStruct.weak_type 1 0 0
attr ShapeDtypeStruct.vma 1 0 0
attr ShapeDtypeStruct.is_ref 1 0 0
attr ShapeDtypeStruct.size 1 0 0
attr ShapeDtypeStruct.ndim 1 0 0
meth Ref.__init__ 3 0 0
meth Ref.__getitem__ 2 0 0
meth Ref.__setitem__ 3 0 0
meth Ref.addupdate 3 0 0
meth Ref.unsafe_buffer_pointer 1 0 0
prop Ref.at 1 0 0
attr Ref.refs 1 1 1
attr Ref.aval 1 0 0
attr Ref.shape 1 0 0
attr Ref.size 1 0 0
attr Ref.ndim 1 0 0
attr Ref.dtype 1 0 0
attr Ref.sharding 1 0 0
attr Ref.format 1 0 0
attr Ref.committed 1 0 0
attr Ref._committed 1 0 0
func free_ref 2 1 0
func ensure_compile_time_eval 1 0 0
func typeof 2 2 2
func shaped_abstractify 2 0 0
meth Trace.__init__ 1 0 0
meth Trace.process_primitive 4 0 0
meth Trace.invalidate 1 0 0
meth Trace.is_valid 1 0 0
meth Trace.__repr__ 1 0 0
meth Trace.process_call 5 0 0
meth Trace.process_map 5 0 0
meth Trace.process_custom_jvp_call 6 0 0
meth Trace.process_custom_transpose 5 2 0
meth Trace.process_custom_vjp_call 8 0 0
meth Trace.full_raise 2 0 0
prop Trace.main 1 0 0
attr Trace.requires_low 1 0 0
meth CallPrimitive.bind 3 0 0
meth CallPrimitive.bind_with_trace 4 0 0
meth CallPrimitive.get_bind_params 2 0 0
func get_opaque_trace_state 2 0 0
func eval_context 1 0 0
attr no_axis_name 1 0 0
func find_top_trace 2 0 0
meth Tracer.__init__ 2 1 0
meth Tracer._error_repr 1 0 0
meth Tracer.__array__ 3 0 0
meth Tracer._is_traced_array 1 0 0
meth Tracer.__dlpack__ 3 0 0
meth Tracer.tolist 1 0 0
meth Tracer.tobytes 2 0 0
meth Tracer.full_lower 1 0 0
meth Tracer.__iter__ 1 0 0
meth Tracer.__reversed__ 1 0 0
meth Tracer.__len__ 1 0 0
meth Tracer.to_concrete_value 1 0 0
meth Tracer.get_referent 1 1 1
meth Tracer.__bool__ 1 0 0
meth Tracer.__int__ 1 0 0
meth Tracer.__float__ 1 0 0
meth Tracer.__complex__ 1 0 0
meth Tracer.__hex__ 1 0 0
meth Tracer.__oct__ 1 0 0
meth Tracer.__index__ 1 0 0
meth Tracer.__reduce__ 1 0 0
meth Tracer.__setitem__ 3 0 0
meth Tracer.__array_module__ 2 0 0
meth Tracer.__getattr__ 2 0 0
meth Tracer.__repr__ 1 0 0
meth Tracer._contents 1 0 0
meth Tracer.addressable_data 2 0 0
meth Tracer.delete 1 0 0
meth Tracer.devices 1 0 0
meth Tracer.is_deleted 1 0 0
meth Tracer.on_device_size_in_bytes 1 0 0
meth Tracer.unsafe_buffer_pointer 1 0 0
prop Tracer.sharding 1 0 0
prop Tracer.committed 1 0 0
prop Tracer.device 1 0 0
prop Tracer.addressable_shards 1 0 0
prop Tracer.at 1 0 0
prop Tracer.aval 1 0 0
prop Tracer.block_until_ready 1 0 0
prop Tracer.copy_to_host_async 1 0 0
prop Tracer.global_shards 1 0 0
prop Tracer.is_fully_addressable 1 0 0
prop Tracer.is_fully_replicated 1 0 0
prop Tracer.traceback 1 0 0
attr Tracer.dtype 1 0 0
attr Tracer.ndim 1 0 0
attr Tracer.size 1 0 0
attr Tracer.shape 1 0 0
func concretization_function_error 3 0 0
func new_jaxpr_eqn 8 1 0
func jaxprs_in_params 2 1 0
meth ShapedArray.__init__ 7 2 0
meth ShapedArray.lower_val 2 0 0
meth ShapedArray.raise_val 2 0 0
meth ShapedArray.lo_ty 1 0 0
meth ShapedArray.update 5 0 0
meth ShapedArray.__eq__ 2 0 0
meth ShapedArray.__hash__ 1 0 0
meth ShapedArray.__ne__ 2 0 0
meth ShapedArray.__repr__ 1 0 0
meth ShapedArray.__str__ 1 0 0
meth ShapedArray.to_tangent_aval 1 0 0
meth ShapedArray.to_cotangent_aval 1 0 0
meth ShapedArray.str_short 3 0 0
meth ShapedArray._len 2 0 0
meth ShapedArray.update_vma 2 0 0
meth ShapedArray.update_weak_type 2 0 0
attr ShapedArray.shape 1 0 0
attr ShapedArray.dtype 1 0 0
attr ShapedArray.weak_type 1 0 0
attr ShapedArray.sharding 1 0 0
attr ShapedArray.vma 1 0 0
attr ShapedArray.memory_space 1 0 0
attr ShapedArray.ndim 1 0 0
attr ShapedArray.size 1 0 0
attr ShapedArray._bool 1 0 0
attr ShapedArray._int 1 0 0
attr ShapedArray._float 1 0 0
attr ShapedArray._complex 1 0 0
attr ShapedArray._hex 1 0 0
attr ShapedArray._oct 1 0 0
attr ShapedArray._index 1 0 0
func check_jaxpr 2 1 0
meth DropVar.__init__ 2 1 0
meth DropVar.__repr__ 1 0 0
meth DropVar.pretty_print 3 2 0
func is_concrete 2 0 0
func primal_dtype_to_tangent_dtype 2 0 0
func valid_jaxtype 2 1 0
func concrete_or_error 4 2 2
attr trace_ctx 1 0 0
func eval_jaxpr 5 2 0
meth AbstractValue.to_tangent_aval 1 0 0
meth AbstractValue.to_cotangent_aval 1 0 0
meth AbstractValue.at_least_vspace 1 0 0
meth AbstractValue.__repr__ 1 0 0
meth AbstractValue.update_weak_type 2 0 0
meth AbstractValue.update_vma 2 0 0
meth AbstractValue.update 2 0 0
meth AbstractValue.lo_ty 1 0 0
meth AbstractValue.lo_ty_qdd 2 0 0
meth AbstractValue.str_short 3 0 0
meth AbstractValue.dec_rank 3 0 0
meth AbstractValue.inc_rank 3 0 0
meth AbstractValue.shard 5 0 0
meth AbstractValue.unshard 4 0 0
func cur_qdd 2 0 0
meth AvalQDD.lo_ty 1 0 0
meth AvalQDD.read_loval 2 0 0
meth AvalQDD.new_from_loval 2 0 0
meth AvalQDD.to_tangent_aval 1 0 0
meth ClosedJaxpr.__init__ 3 2 0
meth ClosedJaxpr.map_jaxpr 2 0 0
meth ClosedJaxpr.replace 3 0 0
meth ClosedJaxpr.__str__ 1 0 0
meth ClosedJaxpr.__repr__ 1 0 0
meth ClosedJaxpr.pretty_print 7 0 0
meth ClosedJaxpr._repr_pretty_ 3 0 0
prop ClosedJaxpr.in_avals 1 0 0
prop ClosedJaxpr.out_avals 1 0 0
attr ClosedJaxpr.jaxpr 1 0 0
attr ClosedJaxpr.consts 1 0 0
attr ClosedJaxpr.constvars 1 0 0
attr ClosedJaxpr.invars 1 0 0
attr ClosedJaxpr.outvars 1 0 0
attr ClosedJaxpr.eqns 1 0 0
attr ClosedJaxpr.effects 1 0 0
attr ClosedJaxpr.debug_info 1 0 0
attr ClosedJaxpr.is_high 1 0 0
meth Token.__init__ 2 0 0
meth Token.block_until_ready 1 0 0
meth Var.__init__ 4 1 0
meth Var.__repr__ 1 0 0
meth Var.pretty_print 3 2 0
meth Jaxpr.__init__ 8 7 0
meth Jaxpr.__str__ 1 0 0
meth Jaxpr.__repr__ 1 0 0
meth Jaxpr.pretty_print 7 1 0
meth Jaxpr._repr_pretty_ 3 0 0
meth Jaxpr.replace 2 0 0
prop Jaxpr.in_avals 1 0 0
prop Jaxpr.out_avals 1 0 0
func unmapped_aval 5 4 0
meth JaxprEqn.__init__ 8 0 0
meth JaxprEqn.__repr__ 1 0 0
meth JaxprEqn.replace 8 7 0
func jaxpr_as_fun 3 1 0
func mapped_aval 4 3 0
meth Primitive.__init__ 2 1 0
meth Primitive.__repr__ 1 0 0
meth Primitive.bind 3 0 0
meth Primitive._true_bind 3 0 0
meth Primitive.bind_with_trace 4 0 0
meth Primitive.def_impl 2 0 0
meth Primitive.def_abstract_eval 2 0 0
meth Primitive.def_effectful_abstract_eval 2 0 0
meth Primitive.def_effectful_abstract_eval2 2 0 0
meth Primitive.def_bind_with_trace 2 0 0
meth Primitive.impl 3 0 0
meth Primitive.abstract_eval 3 0 0
meth Primitive.get_bind_params 2 0 0
meth Primitive.is_high 3 1 0
meth Literal.__init__ 3 0 0
meth Literal.pretty_print 3 2 0
meth Literal.__repr__ 1 0 0
prop Literal.hash 1 0 0
attr Literal.val 1 1 1

jax._src.custom_batching (6 missing, 0 any)

Symbol Typable Typed Any
func sequential_vmap 2 0 0
meth custom_vmap.__init__ 2 1 0
meth custom_vmap.__call__ 3 0 0

jax._src.custom_dce (5 missing, 0 any)

Symbol Typable Typed Any
attr custom_dce_p 1 0 0
meth custom_dce.__init__ 3 2 0
meth custom_dce.__call__ 3 0 0

jax._src.custom_derivatives (21 missing, 4 any)

Symbol Typable Typed Any
func custom_gradient 2 0 0
meth custom_jvp.__init__ 4 3 0
meth custom_jvp.__call__ 3 3 2
meth custom_vjp.__new__ 4 0 0
meth custom_vjp.__init__ 4 3 0
meth custom_vjp.__call__ 3 3 2
attr custom_vjp.fun 1 0 0
attr custom_vjp.nondiff_argnums 1 0 0
attr custom_vjp.symbolic_zeros 1 0 0
attr custom_vjp.optimize_remat 1 0 0
func closure_convert 3 2 0
attr remat_opt_p 1 0 0
attr custom_jvp_call_p 1 0 0
attr custom_vjp_call_p 1 0 0
func linear_call 5 2 0
func custom_vjp_primal_tree_values 2 0 0

jax._src.custom_partitioning (23 missing, 1 any)

Symbol Typable Typed Any
meth custom_partitioning.__init__ 3 0 0
meth custom_partitioning.def_partition 10 0 0
meth custom_partitioning.__call__ 3 0 0
attr custom_partitioning.fun 1 0 0
attr custom_partitioning.partition 1 0 0
attr custom_partitioning.static_argnums 1 0 0
attr custom_partitioning.propagate_user_sharding 1 0 0
attr custom_partitioning.infer_sharding_from_operands 1 0 0
attr custom_partitioning.sharding_rule 1 0 0
attr custom_partitioning.__getattr__ 1 1 1
attr custom_partitioning_p 1 0 0

jax._src.custom_partitioning_sharding_rule (10 missing, 0 any)

Symbol Typable Typed Any
meth CompoundFactor.__init__ 2 0 0
meth CompoundFactor.__new__ 2 0 0
meth ArrayMapping.__init__ 2 0 0
meth ArrayMapping.__new__ 2 0 0
meth SdyShardingRule.__init__ 7 6 0
meth SdyShardingRule.__str__ 1 0 0

jax._src.custom_transpose (6 missing, 0 any)

Symbol Typable Typed Any
meth custom_transpose.__init__ 2 1 0
meth custom_transpose.def_transpose 2 1 0
meth custom_transpose.__call__ 4 0 0

jax._src.debugger.core (3 missing, 0 any)

Symbol Typable Typed Any
func breakpoint 7 4 0

jax._src.debugging (9 missing, 2 any)

Symbol Typable Typed Any
func debug_print 8 6 0
attr debug_log 1 0 0
func visualize_sharding 8 7 0
func visualize_array_sharding 3 0 0
func debug_callback 6 6 2
func inspect_array_sharding 3 1 0

jax._src.dispatch (5 missing, 0 any)

Symbol Typable Typed Any
attr device_put_p 1 0 0
func apply_primitive 4 0 0

jax._src.distributed (2 missing, 0 any)

Symbol Typable Typed Any
func initialize 12 11 0
func shutdown 1 0 0

jax._src.dlpack (2 missing, 0 any)

Symbol Typable Typed Any
func from_dlpack 4 2 0

jax._src.dtypes (2 missing, 3 any)

Symbol Typable Typed Any
func scalar_type_of 2 2 1
func canonicalize_dtype 3 3 1
func result_type 3 3 1
func primal_tangent_dtype 4 2 0

jax._src.earray (28 missing, 0 any)

Symbol Typable Typed Any
meth EArray.__init__ 3 0 0
meth EArray.block_until_ready 1 0 0
meth EArray.copy_to_host_async 1 0 0
meth EArray.copy 1 0 0
meth EArray.__repr__ 1 0 0
meth EArray.__iter__ 1 0 0
meth EArray.__len__ 1 0 0
prop EArray.sharding 1 0 0
prop EArray.committed 1 0 0
prop EArray.device 1 0 0
prop EArray.addressable_shards 1 0 0
prop EArray.global_shards 1 0 0
attr EArray.aval 1 0 0
attr EArray.shape 1 0 0
attr EArray.dtype 1 0 0
attr EArray.ndim 1 0 0
attr EArray.size 1 0 0
attr EArray.itemsize 1 0 0
attr EArray.devices 1 0 0
attr EArray._committed 1 0 0
attr EArray.is_fully_addressable 1 0 0
attr EArray.is_fully_replicated 1 0 0
attr EArray.delete 1 0 0
attr EArray.is_deleted 1 0 0
attr EArray.on_device_size_in_bytes 1 0 0
attr EArray.unsafe_buffer_pointer 1 0 0

jax._src.errors (6 missing, 0 any)

Symbol Typable Typed Any
meth TracerIntegerConversionError.__init__ 2 1 0
meth UnexpectedTracerError.__init__ 2 1 0
meth TracerBoolConversionError.__init__ 2 1 0
meth NonConcreteBooleanIndexError.__init__ 2 1 0
meth ConcretizationTypeError.__init__ 3 2 0
meth TracerArrayConversionError.__init__ 2 1 0

jax._src.export._export (8 missing, 0 any)

Symbol Typable Typed Any
meth DisabledSafetyCheck.__init__ 2 1 0
meth DisabledSafetyCheck.__str__ 1 0 0
meth DisabledSafetyCheck.__repr__ 1 0 0
meth DisabledSafetyCheck.__eq__ 2 1 0
meth Exported.__str__ 1 0 0
meth Exported.call 3 0 0

jax._src.export.shape_poly (13 missing, 0 any)

Symbol Typable Typed Any
meth PolyShape.__init__ 2 0 0
meth PolyShape.__new__ 2 0 0
meth PolyShape.__str__ 1 0 0
meth SymbolicScope.__init__ 2 1 0
meth SymbolicScope._parse_and_process_explicit_constraint 2 1 0
meth SymbolicScope._process_explicit_constraint 2 1 0
meth SymbolicScope._check_same_scope 5 4 0
meth SymbolicScope._clear_caches 1 0 0
func symbolic_args_specs 5 2 0

jax._src.ffi (2 missing, 5 any)

Symbol Typable Typed Any
func register_ffi_target 6 6 2
func register_ffi_type_id 4 4 1
func ffi_lowering 7 7 1
func pycapsule 2 0 0
func build_ffi_lowering_function 7 7 1

jax._src.flatten_util (0 missing, 1 any)

Symbol Typable Typed Any
func ravel_pytree 2 2 1

jax._src.hijax (78 missing, 13 any)

Symbol Typable Typed Any
meth HiPrimitive.__init__ 2 0 0
meth HiPrimitive.is_high 3 1 0
meth HiPrimitive.is_effectful 2 1 0
meth HiPrimitive.abstract_eval 3 0 0
meth HiPrimitive.to_lojax 3 0 0
meth HiPrimitive.jvp 4 0 0
meth HiPrimitive.transpose 3 0 0
attr HiPrimitive.name 1 0 0
meth MutableHiType.__hash__ 1 0 0
meth MutableHiType.__eq__ 2 0 0
meth MutableHiType.lo_ty 1 0 0
meth MutableHiType.new_from_loval 3 3 2
meth MutableHiType.read_loval 3 3 1
meth MutableHiType.update_from_loval 4 4 2
attr MutableHiType.type_state 1 0 0
meth HipSpec.to_lo 1 0 0
func register_hitype 3 1 0
meth VJPHiPrimitive.__init__ 1 0 0
meth VJPHiPrimitive.expand 2 0 0
meth VJPHiPrimitive.vjp_fwd 3 0 0
meth VJPHiPrimitive.vjp_bwd 4 0 0
meth VJPHiPrimitive.vjp_bwd_retval 3 0 0
meth VJPHiPrimitive.jvp 3 0 0
meth VJPHiPrimitive.lin 3 0 0
meth VJPHiPrimitive.linearized 3 0 0
meth VJPHiPrimitive.batch 4 0 0
meth VJPHiPrimitive.batch_dim_rule 3 0 0
meth VJPHiPrimitive.dce 2 0 0
meth VJPHiPrimitive.remat 3 0 0
meth VJPHiPrimitive.__call__ 2 0 0
meth VJPHiPrimitive.check 2 0 0
meth VJPHiPrimitive.staging 4 0 0
meth VJPHiPrimitive.__repr__ 1 0 0
meth VJPHiPrimitive.__hash__ 1 0 0
meth VJPHiPrimitive.__eq__ 2 0 0
attr VJPHiPrimitive.out_aval 1 1 1
meth HiType.__hash__ 1 0 0
meth HiType.__eq__ 2 0 0
meth HiType.lower_val 2 2 1
meth HiType.raise_val 2 2 2
meth HiType.vspace_zero 1 1 1
meth HiType.vspace_add 3 3 3
meth HiType.shard 5 4 0
meth HiType.unshard 4 3 0

jax._src.image.scale (9 missing, 0 any)

Symbol Typable Typed Any
func resize 6 3 0
func scale_and_translate 9 4 0
meth ResizeMethod.from_string 2 1 0

jax._src.indexing (6 missing, 0 any)

Symbol Typable Typed Any
meth Slice.__post_init__ 1 0 0
meth Slice.tree_flatten 1 0 0
meth Slice.tree_unflatten 3 1 0
prop Slice.is_dynamic_start 1 0 0
prop Slice.is_dynamic_size 1 0 0

jax._src.interpreters.ad (92 missing, 1 any)

Symbol Typable Typed Any
func instantiate_zeros 2 0 0
func is_undefined_primal 2 0 0
func deflinear 3 0 0
func linearize 5 1 0
meth UndefinedPrimal.__init__ 2 0 0
meth UndefinedPrimal.__repr__ 1 0 0
attr UndefinedPrimal.aval 1 0 0
func add_tangents 3 0 0
func jvp 5 2 1
func defjvp 3 0 0
meth JVPTracer.__init__ 4 0 0
meth JVPTracer._short_repr 1 0 0
meth JVPTracer.cur_qdd 1 0 0
meth JVPTracer.full_lower 1 0 0
meth JVPTracer.to_concrete_value 1 0 0
meth JVPTracer.get_referent 1 0 0
meth JVPTracer.type_state 1 0 0
prop JVPTracer.aval 1 0 0
attr JVPTracer.primal 1 0 0
attr JVPTracer.tangent 1 0 0
meth JVPTrace.__init__ 3 0 0
meth JVPTrace.to_primal_tangent_pair 2 0 0
meth JVPTrace.process_primitive 4 0 0
meth JVPTrace.cur_qdd 2 0 0
meth JVPTrace.process_call 5 0 0
meth JVPTrace.process_map 5 0 0
meth JVPTrace.process_custom_jvp_call 6 0 0
meth JVPTrace.process_custom_vjp_call 8 0 0
meth JVPTrace.process_custom_transpose 5 0 0
attr JVPTrace.tag 1 0 0
attr JVPTrace.parent_trace 1 0 0
attr JVPTrace.requires_low 1 0 0
func defjvp2 3 0 0
func deflinear2 3 0 0
func defbilinear 4 0 0
func get_primitive_transpose 2 0 0

jax._src.interpreters.batching (19 missing, 0 any)

Symbol Typable Typed Any
func defbroadcasting 2 0 0
func register_vmappable 7 6 0
attr axis_primitive_batchers 1 0 0
func bdim_at_front 5 0 0
attr primitive_batchers 1 0 0
func broadcast 5 0 0
func defreducer 2 0 0
func defvectorized 2 0 0

jax._src.interpreters.mlir (28 missing, 5 any)

Symbol Typable Typed Any
func lower_with_sharding_in_types 5 0 0
func register_constant_handler 3 2 0
func ir_attribute 2 2 1
meth TokenSet.init 3 0 0
meth TokenSet.len 1 0 0
func module_to_string 3 2 0
meth ShapePolyLoweringState.init 3 2 0
func i32_attr 2 0 0
meth ModuleContext.init 17 15 0
meth ModuleContext.add_host_callback 2 2 1
meth ModuleContext.add_keepalive 2 2 1
meth ModuleContext.replace 2 0 0
func dense_int_elements 2 1 0
meth LoweringRuleContext.set_tokens_out 2 1 0
meth LoweringRuleContext.replace 2 0 0
func ir_constant 4 4 1
func core_call_lowering 6 2 0
func i64_attr 2 0 0
func register_lowering 6 6 1

jax._src.interpreters.partial_eval (23 missing, 0 any)

Symbol Typable Typed Any
attr Saveable 1 0 0
attr Recompute 1 0 0
meth PartialVal.new 2 1 0
meth DynamicJaxprTracer.init 6 5 0
meth DynamicJaxprTracer._short_repr 1 0 0
meth DynamicJaxprTracer.cur_qdd 1 0 0
meth DynamicJaxprTracer.full_lower 1 0 0
meth DynamicJaxprTracer._contents 1 0 0
meth DynamicJaxprTracer._origin_msg 1 0 0
meth DynamicJaxprTracer.get_const 1 0 0
meth DynamicJaxprTracer.get_referent 1 0 0
prop DynamicJaxprTracer.aval_mutable_qdd 1 0 0
attr DynamicJaxprTracer.aval 1 0 0
attr DynamicJaxprTracer.val 1 0 0
attr DynamicJaxprTracer.mutable_qdd 1 0 0
attr DynamicJaxprTracer.parent 1 0 0
meth JaxprTracer.init 4 3 0
meth JaxprTracer.repr 1 0 0
meth JaxprTracer.full_lower 1 0 0
meth JaxprTracer.is_known 1 0 0
meth JaxprTracer.get_referent 1 0 0
attr JaxprTracer.pval 1 0 0
attr JaxprTracer.recipe 1 0 0

jax._src.interpreters.pxla (1 missing, 0 any)

Symbol Typable Typed Any
attr xla_pmap_p 1 0 0

jax._src.lax.ann (1 missing, 0 any)

Symbol Typable Typed Any
attr approx_top_k_p 1 0 0

jax._src.lax.control_flow.conditionals (8 missing, 3 any)

Symbol Typable Typed Any
attr cond_p 1 0 0
func cond 6 2 0
func platform_dependent 4 3 1
func switch 5 3 2

jax._src.lax.control_flow.loops (17 missing, 0 any)

Symbol Typable Typed Any
attr scan_p 1 0 0
attr cumlogsumexp_p 1 0 0
attr cumprod_p 1 0 0
attr cumsum_p 1 0 0
attr cummax_p 1 0 0
attr cummin_p 1 0 0
attr while_p 1 0 0
func fori_loop 6 1 0
func associative_scan 5 3 0
func map 4 1 0

jax._src.lax.control_flow.solves (6 missing, 2 any)

Symbol Typable Typed Any
attr linear_solve_p 1 0 0
func custom_linear_solve 7 4 1
func custom_root 6 4 1

jax._src.lax.convolution (24 missing, 0 any)

Symbol Typable Typed Any
attr conv_general_dilated_p 1 0 0
func conv_general_shape_tuple 6 0 0
func conv_transpose_shape_tuple 6 0 0
func conv_dimension_numbers 4 1 0
func conv_general_permutations 2 0 0
func conv_shape_tuple 6 0 0

jax._src.lax.fft (3 missing, 0 any)

Symbol Typable Typed Any
attr fft_p 1 0 0
func fft 4 2 0

jax._src.lax.lax (147 missing, 3 any)

Symbol Typable Typed Any
attr rng_uniform_p 1 0 0
attr eq_p 1 0 0
attr reduce_min_p 1 0 0
attr create_token_p 1 0 0
attr round_p 1 0 0
attr sub_p 1 0 0
attr pad_p 1 0 0
attr concatenate_p 1 0 0
attr sinh_p 1 0 0
attr real_p 1 0 0
attr population_count_p 1 0 0
attr bitcast_convert_type_p 1 0 0
attr reduce_max_p 1 0 0
attr xor_p 1 0 0
attr empty2_p 1 0 0
attr rng_bit_generator_p 1 0 0
attr reshape_p 1 0 0
attr sqrt_p 1 0 0
attr mul_p 1 0 0
attr atanh_p 1 0 0
attr reduce_prod_p 1 0 0
attr cosh_p 1 0 0
attr reduce_xor_p 1 0 0
attr cbrt_p 1 0 0
attr complex_p 1 0 0
attr cos_p 1 0 0
attr or_p 1 0 0
attr lt_to_p 1 0 0
attr conj_p 1 0 0
attr tanh_p 1 0 0
attr after_all_p 1 0 0
attr expm1_p 1 0 0
attr ge_p 1 0 0
attr log_p 1 0 0
attr neg_p 1 0 0
attr sign_p 1 0 0
attr acos_p 1 0 0
attr top_k_p 1 0 0
attr asin_p 1 0 0
attr abs_p 1 0 0
attr ceil_p 1 0 0
attr shift_right_logical_p 1 0 0
attr exp_p 1 0 0
attr le_p 1 0 0
attr rem_p 1 0 0
attr asinh_p 1 0 0
attr ne_p 1 0 0
attr is_finite_p 1 0 0
attr lt_p 1 0 0
attr iota_p 1 0 0
attr select_n_p 1 0 0
attr atan_p 1 0 0
attr square_p 1 0 0
attr shift_left_p 1 0 0
attr acosh_p 1 0 0
attr clz_p 1 0 0
attr imag_p 1 0 0
attr clamp_p 1 0 0
attr transpose_p 1 0 0
attr broadcast_in_dim_p 1 0 0
attr integer_pow_p 1 0 0
attr reduce_or_p 1 0 0
attr le_to_p 1 0 0
attr nextafter_p 1 0 0
attr shift_right_arithmetic_p 1 0 0
attr argmin_p 1 0 0
attr argmax_p 1 0 0
attr div_p 1 0 0
attr gt_p 1 0 0
attr tan_p 1 0 0
attr not_p 1 0 0
attr reduce_and_p 1 0 0
attr pow_p 1 0 0
attr log1p_p 1 0 0
attr sort_p 1 0 0
attr dot_general_p 1 0 0
attr atan2_p 1 0 0
attr logistic_p 1 0 0
attr rev_p 1 0 0
attr copy_p 1 0 0
attr reduce_p 1 0 0
attr convert_element_type_p 1 0 0
attr squeeze_p 1 0 0
attr sin_p 1 0 0
attr and_p 1 0 0
attr reduce_precision_p 1 0 0
attr exp2_p 1 0 0
attr rsqrt_p 1 0 0
attr reduce_sum_p 1 0 0
attr eq_to_p 1 0 0
attr floor_p 1 0 0
attr tile_p 1 0 0
func empty 4 0 0
meth RaggedDotDimensionNumbers.init 4 0 0
func rsqrt 3 2 0
func rng_uniform 4 0 0
func broadcast_in_dim 5 4 0
func log1p 3 2 0
func shape_as_value 2 1 0
func sqrt 3 2 0
func dce_sink 2 0 0
attr optimization_barrier_p 1 0 0
func dot_general 7 6 0
func create_token 2 0 0
meth Tolerance.init 4 3 0
attr Tolerance.atol 1 0 0
attr Tolerance.rtol 1 0 0
attr Tolerance.ulps 1 0 0
func exp2 3 2 0
func exp 3 2 0
func cos 3 2 0
attr dce_sink_p 1 0 0
func rng_bit_generator 6 0 0
func reduce 6 6 3
func tanh 3 2 0
attr split_p 1 0 0
func composite 4 3 0
func cbrt 3 2 0
func broadcast 4 3 0
func log 3 2 0
func dot 8 6 0
func optimization_barrier 2 0 0
func tan 3 2 0
func reduce_sum 4 3 0
func sin 3 2 0
func expm1 3 2 0
func logistic 3 2 0
func after_all 2 0 0
func broadcasted_iota 5 4 0

jax._src.lax.linalg (16 missing, 0 any)

Symbol Typable Typed Any
attr schur_p 1 0 0
attr cholesky_p 1 0 0
attr householder_product_p 1 0 0
attr eigh_p 1 0 0
attr svd_p 1 0 0
attr hessenberg_p 1 0 0
attr eig_p 1 0 0
attr tridiagonal_solve_p 1 0 0
attr lu_p 1 0 0
attr tridiagonal_p 1 0 0
attr triangular_solve_p 1 0 0
attr qr_p 1 0 0
func symmetric_product 6 5 0
attr cholesky_update_p 1 0 0
attr lu_pivots_to_permutation_p 1 0 0
attr symmetric_product_p 1 0 0

jax._src.lax.parallel (88 missing, 0 any)

Symbol Typable Typed Any
attr pmin_p 1 0 0
attr ragged_all_to_all_p 1 0 0
attr all_gather_p 1 0 0
attr pmax_p 1 0 0
attr psum_p 1 0 0
attr ppermute_p 1 0 0
attr all_to_all_p 1 0 0
attr axis_index_p 1 0 0
func pshuffle 4 0 0
func psum 4 0 0
func ppermute 4 0 0
func psend 4 0 0
func ragged_all_to_all 9 0 0
func pcast 4 1 0
func precv 5 0 0
func pmean 4 0 0
func pmin 4 0 0
func pswapaxes 5 0 0
func all_gather_start 5 0 0
func all_to_all 7 0 0
func all_gather_done 2 0 0
func all_gather 7 1 0
func pmax 4 0 0
func psum_scatter 6 0 0
func pbroadcast 4 0 0

jax._src.lax.scaled_dot (1 missing, 0 any)

Symbol Typable Typed Any
func scaled_dot 7 6 0

jax._src.lax.slicing (11 missing, 0 any)

Symbol Typable Typed Any
attr scatter_mul_p 1 0 0
attr scatter_p 1 0 0
attr gather_p 1 0 0
attr scatter_add_p 1 0 0
attr dynamic_slice_p 1 0 0
attr scatter_max_p 1 0 0
attr dynamic_update_slice_p 1 0 0
attr scatter_min_p 1 0 0
attr slice_p 1 0 0
func gather 9 8 0
attr scatter_sub_p 1 0 0

jax._src.lax.special (14 missing, 0 any)

Symbol Typable Typed Any
attr bessel_i1e_p 1 0 0
attr lgamma_p 1 0 0
attr erf_inv_p 1 0 0
attr igamma_grad_a_p 1 0 0
attr regularized_incomplete_beta_p 1 0 0
attr erf_p 1 0 0
attr igammac_p 1 0 0
attr bessel_i0e_p 1 0 0
attr erfc_p 1 0 0
attr zeta_p 1 0 0
attr igamma_p 1 0 0
attr polygamma_p 1 0 0
attr digamma_p 1 0 0
func random_gamma_grad 4 3 0

jax._src.lax.windowed_reductions (14 missing, 3 any)

Symbol Typable Typed Any
attr reduce_window_p 1 0 0
attr reduce_window_min_p 1 0 0
attr select_and_gather_add_p 1 0 0
attr select_and_scatter_p 1 0 0
attr reduce_window_max_p 1 0 0
attr reduce_window_sum_p 1 0 0
attr select_and_scatter_add_p 1 0 0
func reduce_window 9 9 3
func reduce_window_shape_tuple 7 0 0

jax._src.layout (18 missing, 0 any)

Symbol Typable Typed Any
meth Format.init 3 2 0
meth Format.repr 1 0 0
meth Format.hash 1 0 0
meth Format.eq 2 0 0
attr Format.layout 1 0 0
attr Format.sharding 1 0 0
meth Layout.init 4 3 0
meth Layout.from_pjrt_layout 2 1 0
meth Layout.repr 1 0 0
meth Layout.hash 1 0 0
meth Layout.eq 2 0 0
meth Layout.update 2 0 0
meth Layout._to_xla_layout 2 1 0
meth Layout.check_compatible_aval 2 1 0
attr Layout.AUTO 1 0 0

jax._src.linear_util (29 missing, 0 any)

Symbol Typable Typed Any
meth DebugInfo.set_result_paths 2 0 0
meth DebugInfo.assert_arg_names 2 1 0
meth DebugInfo.assert_result_paths 2 1 0
func transformation_with_aux 4 2 0
func transformation_with_aux2 6 4 0
func transformation 4 2 0
func cache 3 2 0
func transformation2 4 2 0
func merge_linear_aux 3 0 0
meth WrappedFun.init 8 7 0
meth WrappedFun.wrap 4 2 0
meth WrappedFun.populate_stores 2 0 0
meth WrappedFun.call_wrapped 3 0 0
meth WrappedFun.repr 1 0 0
meth WrappedFun.hash 1 0 0
meth WrappedFun.eq 2 0 0
prop WrappedFun.name 1 0 0

jax._src.memory (1 missing, 0 any)

Symbol Typable Typed Any
meth Space.repr 1 0 0

jax._src.mesh (75 missing, 0 any)

Symbol Typable Typed Any
meth Mesh.new 4 3 0
meth Mesh.reduce 1 0 0
meth Mesh.eq 2 0 0
meth Mesh.hash 1 0 0
meth Mesh.setattr 3 0 0
meth Mesh.enter 1 0 0
meth Mesh.exit 4 0 0
meth Mesh.update 4 0 0
meth Mesh._local_mesh 2 0 0
meth Mesh.str 1 0 0
meth Mesh.repr 1 0 0
prop Mesh.shape 1 0 0
prop Mesh.shape_tuple 1 0 0
prop Mesh.size 1 0 0
prop Mesh.empty 1 0 0
prop Mesh.is_multi_process 1 0 0
prop Mesh.local_mesh 1 0 0
prop Mesh.device_ids 1 0 0
prop Mesh._local_devices_set 1 0 0
prop Mesh._flat_devices_tuple 1 0 0
prop Mesh._internal_device_list 1 0 0
prop Mesh._flat_devices_set 1 0 0
prop Mesh._repr 1 0 0
prop Mesh.local_devices 1 0 0
prop Mesh.abstract_mesh 1 0 0
meth use_abstract_mesh.init 2 1 0
meth use_abstract_mesh.enter 1 0 0
meth use_abstract_mesh.exit 4 0 0
attr use_abstract_mesh.mesh 1 0 0
meth AbstractMesh.init 5 3 0
meth AbstractMesh.hash 1 0 0
meth AbstractMesh.eq 2 0 0
meth AbstractMesh.repr 1 0 0
meth AbstractMesh.update 5 0 0
meth AbstractMesh.update_axis_types 2 1 0
meth AbstractMesh.enter 1 0 0
meth AbstractMesh.exit 4 0 0
prop AbstractMesh.shape 1 0 0
prop AbstractMesh.shape_tuple 1 0 0
prop AbstractMesh._internal_device_list 1 0 0
prop AbstractMesh.empty 1 0 0
prop AbstractMesh.abstract_mesh 1 0 0
prop AbstractMesh.devices 1 0 0
prop AbstractMesh.device_ids 1 0 0
prop AbstractMesh.is_multi_process 1 0 0
prop AbstractMesh.local_devices 1 0 0
prop AbstractMesh.local_mesh 1 0 0
attr AbstractMesh.axis_sizes 1 0 0
attr AbstractMesh.abstract_device 1 0 0
attr AbstractMesh.size 1 0 0
meth AxisType.repr 1 0 0
meth AbstractDevice.repr 1 0 0
meth AbstractDevice._repr 1 0 0

jax._src.monitoring (1 missing, 0 any)

Symbol Typable Typed Any
func clear_event_listeners 1 0 0

jax._src.named_sharding (11 missing, 0 any)

Symbol Typable Typed Any
meth NamedSharding.init 5 3 0
meth NamedSharding.repr 1 0 0
meth NamedSharding.reduce 1 0 0
meth NamedSharding.hash 1 0 0
meth NamedSharding.eq 2 0 0
meth NamedSharding.update 2 1 0
meth AUTO.init 2 1 0
prop AUTO._device_assignment 1 0 0
attr AUTO.mesh 1 0 0

jax._src.nn.initializers (0 missing, 8 any)

Symbol Typable Typed Any
func orthogonal 4 4 1
func uniform 3 3 1
func truncated_normal 5 5 3
func variance_scaling 8 8 1
func normal 3 3 1
func delta_orthogonal 4 4 1

jax._src.numpy.linalg (4 missing, 0 any)

Symbol Typable Typed Any
func cond 3 1 0
func cross 4 2 0

jax._src.pallas.core (24 missing, 2 any)

Symbol Typable Typed Any
meth BoundedSlice.repr 1 0 0
func lower_as_mlir 8 1 0
meth MemorySpace.call 3 2 0
func core_map 9 7 0
attr no_block_spec 1 0 0
attr squeezed 1 0 0
meth CostEstimate.post_init 1 0 0
meth MemoryRef.lt 2 0 0
prop MemoryRef.dtype 1 0 0
prop MemoryRef.shape 1 0 0
meth GridSpec.init 5 4 2
meth GridSpec._make_scalar_ref_aval 2 0 0
meth Blocked.str 1 0 0
meth BlockSpec.post_init 1 0 0
meth Element.str 1 0 0

jax._src.pallas.cost_estimate (3 missing, 0 any)

Symbol Typable Typed Any
func estimate_cost 4 1 0

jax._src.pallas.fuser.block_spec (8 missing, 0 any)

Symbol Typable Typed Any
func get_fusion_values 4 2 0
func pull_block_spec 5 4 0
func push_block_spec 4 1 0
func make_scalar_prefetch_handler 2 0 0

jax._src.pallas.fuser.custom_evaluate (2 missing, 0 any)

Symbol Typable Typed Any
func evaluate 3 1 0

jax._src.pallas.fuser.custom_fusion_lib (9 missing, 0 any)

Symbol Typable Typed Any
meth custom_fusion.init 2 1 0
meth custom_fusion.def_pallas_impl 2 0 0
meth custom_fusion.def_pull_block_spec 2 1 0
meth custom_fusion.def_push_block_spec 2 1 0
meth custom_fusion.def_eval_rule 2 1 0
meth custom_fusion.call 3 0 0

jax._src.pallas.fuser.fusible (2 missing, 1 any)

Symbol Typable Typed Any
func fusible 3 1 1

jax._src.pallas.fuser.fusion (5 missing, 0 any)

Symbol Typable Typed Any
prop Fusion.shape 1 0 0
prop Fusion.dtype 1 0 0
prop Fusion.type 1 0 0
prop Fusion.in_shape 1 0 0
prop Fusion.in_dtype 1 0 0

jax._src.pallas.fuser.jaxpr_fusion (2 missing, 0 any)

Symbol Typable Typed Any
func fuse 4 2 0

jax._src.pallas.helpers (7 missing, 2 any)

Symbol Typable Typed Any
func kernel 11 10 0
func empty_like 2 1 0
attr empty 1 0 0
func with_scoped 4 3 2
func debug_check 3 0 0

jax._src.pallas.mosaic.core (10 missing, 0 any)

Symbol Typable Typed Any
meth MemorySpace.from_type 2 0 0
meth MemorySpace.call 3 2 0
meth MemorySpace.getattr 2 0 0
meth PrefetchScalarGridSpec.init 6 5 0
meth PrefetchScalarGridSpec._make_scalar_ref_aval 2 0 0
meth CompilerParams.init 16 15 0
meth SemaphoreType.call 2 1 0

jax._src.pallas.mosaic.helpers (5 missing, 0 any)

Symbol Typable Typed Any
func sync_copy 4 2 0
func core_barrier 3 1 0
func run_on_first_core 2 1 0

jax._src.pallas.mosaic.interpret.interpret_pallas_call (3 missing, 0 any)

Symbol Typable Typed Any
func reset_tpu_interpret_mode_state 1 0 0
func set_tpu_interpret_mode 2 1 0
func force_tpu_interpret_mode 2 1 0

jax._src.pallas.mosaic.pipeline (98 missing, 1 any)

Symbol Typable Typed Any
func emit_pipeline 12 8 0
meth BufferedRef.post_init 1 0 0
meth BufferedRef.create 10 4 0
meth BufferedRef.input 5 0 0
meth BufferedRef.output 5 0 0
meth BufferedRef.accumulator 5 0 0
meth BufferedRef.input_output 5 0 0
meth BufferedRef.with_next_fetch 2 1 0
meth BufferedRef.bind_existing_ref 3 0 0
meth BufferedRef.unbind_refs 1 0 0
meth BufferedRef.compute_slice 2 0 0
meth BufferedRef.init_slots 1 0 0
meth BufferedRef.save_slots 2 1 0
meth BufferedRef.copy_in 3 0 0
meth BufferedRef.copy_out 3 0 0
meth BufferedRef.wait_in 3 0 0
meth BufferedRef.wait_out 3 0 0
meth BufferedRef.set_accumulator 2 0 0
meth BufferedRef.accumulate 1 0 0
prop BufferedRef.spec 1 0 0
prop BufferedRef.buffer_type 1 0 0
prop BufferedRef.block_shape 1 0 0
prop BufferedRef.compute_index 1 0 0
prop BufferedRef.current_ref 1 0 0
prop BufferedRef.cumulative_copy_in 1 0 0
prop BufferedRef.current_copy_in_slot 1 0 0
prop BufferedRef.cumulative_copy_out 1 0 0
prop BufferedRef.current_copy_out_slot 1 0 0
prop BufferedRef.cumulative_wait_in 1 0 0
prop BufferedRef.current_wait_in_slot 1 0 0
prop BufferedRef.cumulative_wait_out 1 0 0
prop BufferedRef.current_wait_out_slot 1 0 0
prop BufferedRef.next_fetch_indices 1 0 0
meth BufferedRefBase.init_slots 1 0 0
meth BufferedRefBase.save_slots 2 1 0
meth BufferedRefBase.get_dma_slice 3 0 0
meth BufferedRefBase.bind_existing_ref 3 0 0
meth BufferedRefBase.unbind_refs 1 0 0
prop BufferedRefBase.is_input 1 0 0
prop BufferedRefBase.is_output 1 0 0
prop BufferedRefBase.is_accumulator 1 0 0
prop BufferedRefBase.is_input_output 1 0 0
prop BufferedRefBase.is_manual 1 0 0
prop BufferedRefBase.compute_index 1 0 0
func emit_pipeline_with_allocations 6 0 0
func get_pipeline_schedule 2 1 1
func make_pipeline_allocations 8 1 0

jax._src.pallas.mosaic.primitives (31 missing, 1 any)

Symbol Typable Typed Any
func matmul_pop 5 4 0
func stochastic_round 4 0 0
func unpack_elementwise 5 0 0
func async_copy 6 3 0
func make_async_copy 4 1 0
func pack_elementwise 3 0 0
func make_async_remote_copy 7 3 0
func async_remote_copy 7 2 0
func prng_random_bits 2 0 0
func with_memory_space_constraint 3 3 1
func get_barrier_semaphore 1 0 0

jax._src.pallas.mosaic.random (5 missing, 0 any)

Symbol Typable Typed Any
attr stateful_bernoulli 1 0 0
attr stateful_uniform 1 0 0
func sample_block 8 7 0
attr stateful_bits 1 0 0
attr stateful_normal 1 0 0

jax._src.pallas.mosaic.sc_core (9 missing, 0 any)

Symbol Typable Typed Any
meth VectorSubcoreMesh.post_init 1 0 0
meth VectorSubcoreMesh.discharges_effect 2 0 0
prop VectorSubcoreMesh.shape 1 0 0
meth MemoryRef.init 5 4 0
meth ScalarSubcoreMesh.discharges_effect 2 0 0
prop ScalarSubcoreMesh.shape 1 0 0
meth BlockSpec.post_init 1 0 0

jax._src.pallas.mosaic.sc_primitives (1 missing, 0 any)

Symbol Typable Typed Any
func subcore_barrier 1 0 0

jax._src.pallas.mosaic_gpu.core (25 missing, 4 any)

Symbol Typable Typed Any
func kernel 13 12 2
meth PeerMemRef.transform_type 2 0 0
meth RefUnion.init 2 1 1
meth SwizzleTransform.post_init 1 0 0
meth Barrier.post_init 1 0 0
meth WGMMAAccumulatorRef.init 2 0 0
func unswizzle_ref 3 2 0
meth Layout.call 3 1 0
meth Layout.to_mgpu 3 1 0
meth TilingTransform.transform_type 2 0 0
func layout_cast 3 2 1
meth Mesh.post_init 1 0 0
meth CompilerParams.post_init 1 0 0
meth TMEMLayout.call 3 1 0
meth TMEMLayout.to_mgpu 3 1 0
meth SemaphoreType.call 2 1 0
prop WarpMesh.shape 1 0 0
func untile_ref 3 2 0

jax._src.pallas.mosaic_gpu.helpers (4 missing, 0 any)

Symbol Typable Typed Any
func find_swizzle 3 2 0
func planar_snake 5 4 0
func format_tcgen05_sparse_metadata 2 0 0

jax._src.pallas.mosaic_gpu.pipeline (1 missing, 0 any)

Symbol Typable Typed Any
func emit_pipeline 7 6 0

jax._src.pallas.mosaic_gpu.primitives (24 missing, 0 any)

Symbol Typable Typed Any
func wgmma 4 2 0
func wgmma_accumulator_load 3 1 0
func async_store_tmem 3 1 0
func semaphore_signal_multicast 4 2 0
func commit_tmem 1 0 0
func tcgen05_mma 10 9 0
func inline_mgpu 3 0 0
func load 5 4 0
func semaphore_signal_parallel 2 1 0
func set_max_registers 3 2 0
func commit_smem 1 0 0
func multimem_store 4 3 0
func wgmma_wait 2 1 0
func wait_load_tmem 1 0 0
func async_copy_smem_to_tmem 4 3 0
func tcgen05_commit_arrive 3 2 0
func async_copy_scales_to_tmem 4 3 0
func async_copy_sparse_metadata_to_tmem 4 3 0

jax._src.pallas.mosaic_gpu.torch (2 missing, 0 any)

Symbol Typable Typed Any
func as_torch_kernel 2 0 0

jax._src.pallas.pallas_call (1 missing, 2 any)

Symbol Typable Typed Any
func pallas_call 15 15 2
attr pallas_call_p 1 0 0

jax._src.pallas.primitives (48 missing, 3 any)

Symbol Typable Typed Any
func dot 7 3 0
func semaphore_signal 6 4 0
func semaphore_read 2 1 0
func debug_print 3 2 0
func reciprocal 3 0 0
func semaphore_wait 4 2 0
func run_scoped 5 5 3
func atomic_or 5 1 0
func atomic_min 5 1 0
func max_contiguous 3 0 0
func atomic_xor 5 1 0
func atomic_cas 4 0 0
func atomic_max 5 1 0
func atomic_add 5 1 0
func atomic_and 5 1 0
func atomic_xchg 5 1 0

jax._src.partition_spec (28 missing, 0 any)

Symbol Typable Typed Any
meth P.init 4 0 0
meth P.repr 1 0 0
meth P.reduce 1 0 0
meth P.getitem 2 0 0
meth P.iter 1 0 0
meth P.len 1 0 0
meth P.eq 2 0 0
meth P.hash 1 0 0
meth P.add 2 0 0
meth P.radd 2 0 0
meth P.index 2 0 0
meth P.count 2 0 0
meth P.update 2 0 0
meth P.to_lo 1 0 0
meth P._check_compatible_wrt_shape 2 0 0
attr P.unreduced 1 0 0
attr P.reduced 1 0 0

jax._src.pjit (17 missing, 0 any)

Symbol Typable Typed Any
func reshard 3 0 0
func with_layout_constraint 3 0 0
attr jit_p 1 0 0
attr sharding_constraint_p 1 0 0
func with_sharding_constraint 3 0 0
func auto_axes 4 1 0
func explicit_axes 4 1 0

jax._src.prng (11 missing, 0 any)

Symbol Typable Typed Any
attr random_fold_in_p 1 0 0
attr random_seed_p 1 0 0
attr random_split_p 1 0 0
attr threefry2x32_p 1 0 0
attr random_bits_p 1 0 0
func threefry_2x32 3 0 0
attr unsafe_rbg_prng_impl 1 0 0
attr rbg_prng_impl 1 0 0
attr threefry_prng_impl 1 0 0

jax._src.profiler (10 missing, 0 any)

Symbol Typable Typed Any
func stop_server 1 0 0
func stop_trace 1 0 0
func annotate_function 4 2 0
meth StepTraceAnnotation.init 3 1 0
func save_device_memory_profile 3 2 0
func trace 5 2 0

jax._src.public_test_util (24 missing, 0 any)

Symbol Typable Typed Any
func check_grads 8 0 0
func check_vjp 8 0 0
func check_jvp 8 0 0

jax._src.random (14 missing, 0 any)

Symbol Typable Typed Any
attr random_gamma_p 1 0 0
func gumbel 6 5 0
func randint 7 6 0
func clone 2 0 0
func multinomial 7 6 0
func wrap_key_data 3 2 0
func bits 5 4 0
func bernoulli 6 5 0
func uniform 7 6 0
func truncated_normal 7 6 0
func ball 6 5 0
func permutation 6 5 0
func normal 5 4 0

jax._src.ref (0 missing, 2 any)

Symbol Typable Typed Any
func new_ref 3 3 2

jax._src.scipy.linalg (0 missing, 2 any)

Symbol Typable Typed Any
func solve_triangular 9 9 1
func qr 7 7 1

jax._src.scipy.ndimage (1 missing, 0 any)

Symbol Typable Typed Any
func map_coordinates 6 5 0

jax._src.scipy.sparse.linalg (26 missing, 0 any)

Symbol Typable Typed Any
func gmres 10 0 0
func cg 8 0 0
func bicgstab 8 0 0

jax._src.scipy.spatial.transform (18 missing, 0 any)

Symbol Typable Typed Any
meth Slerp.init 3 2 0
meth Slerp.call 2 1 0
meth Rotation.concatenate 2 1 0
meth Rotation.from_euler 4 3 0
meth Rotation.from_matrix 2 1 0
meth Rotation.from_mrp 2 1 0
meth Rotation.from_quat 2 1 0
meth Rotation.from_rotvec 3 2 0
meth Rotation.identity 3 1 0
meth Rotation.random 3 2 0
meth Rotation.getitem 2 0 0
meth Rotation.len 1 0 0
meth Rotation.mul 2 1 0
meth Rotation.as_euler 3 2 0
meth Rotation.inv 1 0 0
meth Rotation.mean 2 1 0

jax._src.scipy.stats.kde (39 missing, 0 any)

Symbol Typable Typed Any
meth gaussian_kde.init 4 1 0
meth gaussian_kde._setattr 3 0 0
meth gaussian_kde.tree_flatten 1 0 0
meth gaussian_kde.tree_unflatten 3 0 0
meth gaussian_kde.evaluate 2 0 0
meth gaussian_kde.call 2 0 0
meth gaussian_kde.integrate_gaussian 3 0 0
meth gaussian_kde.integrate_box_1d 3 0 0
meth gaussian_kde.integrate_kde 2 0 0
meth gaussian_kde.resample 3 0 0
meth gaussian_kde.pdf 2 0 0
meth gaussian_kde.logpdf 2 0 0
meth gaussian_kde.integrate_box 4 0 0
meth gaussian_kde.set_bandwidth 2 0 0
meth gaussian_kde._reshape_points 2 0 0
prop gaussian_kde.d 1 0 0
prop gaussian_kde.n 1 0 0

jax._src.scipy.stats.truncnorm (36 missing, 0 any)

Symbol Typable Typed Any
func cdf 6 0 0
func logsf 6 0 0
func logcdf 6 0 0
func sf 6 0 0
func logpdf 6 0 0
func pdf 6 0 0

jax._src.shard_alike (3 missing, 0 any)

Symbol Typable Typed Any
func shard_alike 3 0 0

jax._src.shard_map (0 missing, 2 any)

Symbol Typable Typed Any
func shard_map 7 7 1
func smap 5 5 1

jax._src.sharding (1 missing, 0 any)

Symbol Typable Typed Any
meth Sharding._to_sdy_sharding 2 1 0

jax._src.sharding_impls (27 missing, 0 any)

Symbol Typable Typed Any
meth set_mesh.init 2 1 0
meth set_mesh.enter 1 0 0
meth set_mesh.exit 4 0 0
attr set_mesh.prev_abstract_mesh 1 0 0
attr set_mesh.prev_mesh 1 0 0
meth GSPMDSharding.init 4 3 0
meth GSPMDSharding.reduce 1 0 0
meth GSPMDSharding.eq 2 0 0
meth GSPMDSharding.hash 1 0 0
meth GSPMDSharding.repr 1 0 0
meth GSPMDSharding.get_replicated 3 1 0
prop GSPMDSharding._hlo_sharding_hash 1 0 0
prop SPMDAxisContext.axis_env 1 0 0
prop SPMDAxisContext.unsafe_axis_env 1 0 0
meth ShardingContext.post_init 1 0 0
prop ShardingContext.axis_env 1 0 0
meth SingleDeviceSharding.init 3 2 0
meth SingleDeviceSharding.reduce 1 0 0
meth SingleDeviceSharding.repr 1 0 0
meth SingleDeviceSharding.hash 1 0 0
meth SingleDeviceSharding.eq 2 0 0

jax._src.source_info_util (4 missing, 0 any)

Symbol Typable Typed Any
func summarize 3 2 0
meth NameStack.len 1 0 0
func register_exclusion 2 1 0
meth SourceInfo.init 3 2 0

jax._src.sourcemap (4 missing, 0 any)

Symbol Typable Typed Any
meth MappingsGenerator.init 1 0 0
meth MappingsGenerator.new_group 1 0 0
meth MappingsGenerator.new_segment 2 0 0

jax._src.stages (47 missing, 2 any)

Symbol Typable Typed Any
meth Lowered.init 7 3 0
prop Lowered.in_avals 1 0 0
prop Lowered.out_info 1 0 0
attr Lowered.args_info 1 1 1
meth Compiled.init 8 1 0
meth Compiled._input_shardings_flat 1 0 0
meth Compiled._input_layouts_flat 1 0 0
meth Compiled.call 3 0 0
meth Compiled.call 3 0 0
prop Compiled.in_avals 1 0 0
prop Compiled.out_info 1 0 0
prop Compiled.input_shardings 1 0 0
prop Compiled.output_shardings 1 0 0
prop Compiled.input_formats 1 0 0
prop Compiled._output_formats_flat 1 0 0
prop Compiled.output_formats 1 0 0
attr Compiled.args_info 1 1 1
prop ArgInfo.shape 1 0 0
prop ArgInfo.dtype 1 0 0
meth Traced.init 6 0 0
meth Traced.call 3 0 0
meth Traced.lower 3 2 0
prop Traced.out_avals 1 0 0
attr Traced.out_tree 1 0 0
attr Traced.jaxpr 1 0 0
attr Traced.fun_name 1 0 0
attr Traced.args_info 1 0 0
attr Traced.out_info 1 0 0
attr Traced._num_consts 1 0 0

jax._src.state.types (91 missing, 2 any)

Symbol Typable Typed Any
meth AbstractRef.init 4 3 2
meth AbstractRef.lo_ty 1 0 0
meth AbstractRef.lower_val 2 0 0
meth AbstractRef.raise_val 2 0 0
meth AbstractRef.update_weak_type 2 0 0
meth AbstractRef.update_vma 2 0 0
meth AbstractRef.update 4 0 0
meth AbstractRef._len 2 1 0
meth AbstractRef.at 1 0 0
meth AbstractRef.bitcast 2 0 0
meth AbstractRef.reshape 2 0 0
meth AbstractRef.transpose 2 0 0
meth AbstractRef.T 1 0 0
meth AbstractRef.get 3 0 0
meth AbstractRef.swap 4 0 0
meth AbstractRef.set 4 0 0
meth AbstractRef.addupdate 4 0 0
meth AbstractRef._getitem 3 1 0
meth AbstractRef._setitem 4 1 0
meth AbstractRef._addupdate 4 0 0
meth AbstractRef.str_short 3 1 0
meth AbstractRef.to_tangent_aval 1 0 0
meth AbstractRef.to_cotangent_aval 1 0 0
meth AbstractRef.eq 2 0 0
meth AbstractRef.hash 1 0 0
prop AbstractRef.is_high 1 0 0
prop AbstractRef.shape 1 0 0
prop AbstractRef.dtype 1 0 0
prop AbstractRef.sharding 1 0 0
prop AbstractRef.vma 1 0 0
attr AbstractRef.inner_aval 1 0 0
attr AbstractRef.memory_space 1 0 0
attr AbstractRef.kind 1 0 0
attr AbstractRef.ndim 1 0 0
attr AbstractRef.size 1 0 0
meth TransformedRef.bitcast 2 0 0
meth TransformedRef.reshape 2 0 0
meth TransformedRef.transpose 2 1 0
meth TransformedRef.set 3 0 0
meth TransformedRef.swap 3 0 0
meth TransformedRef.get 2 0 0
meth TransformedRef.getattr 2 0 0
meth TransformedRef.getitem 2 0 0
meth TransformedRef.setitem 3 0 0
prop TransformedRef.is_dynamic_size 1 0 0
prop TransformedRef.dtype 1 0 0
attr TransformedRef.ndim 1 0 0
attr TransformedRef.size 1 0 0
attr TransformedRef.T 1 0 0
meth TransposeTransform.transform_type 2 0 0

jax._src.stateful_rng (2 missing, 0 any)

Symbol Typable Typed Any
meth StatefulPRNG.post_init 1 0 0
meth StatefulPRNG.random 3 2 0

jax._src.third_party.scipy.interpolate (22 missing, 0 any)

Symbol Typable Typed Any
meth RegularGridInterpolator.init 6 0 0
meth RegularGridInterpolator.call 3 0 0
meth RegularGridInterpolator._evaluate_linear 3 0 0
meth RegularGridInterpolator._evaluate_nearest 3 0 0
meth RegularGridInterpolator._find_indices 2 0 0
attr RegularGridInterpolator.method 1 0 0
attr RegularGridInterpolator.bounds_error 1 0 0
attr RegularGridInterpolator.fill_value 1 0 0
attr RegularGridInterpolator.grid 1 0 0
attr RegularGridInterpolator.values 1 0 0

jax._src.tpu.linalg.qdwh (2 missing, 0 any)

Symbol Typable Typed Any
func qdwh 6 4 0

jax._src.tpu_custom_call (0 missing, 2 any)

Symbol Typable Typed Any
func as_tpu_kernel 22 22 1
func lower_module_to_custom_call 25 25 1

jax._src.tree (2 missing, 20 any)

Symbol Typable Typed Any
func static 2 0 0
func reduce 5 5 1
func reduce_associative 5 5 1
func map_with_path 6 6 3
func flatten_with_path 4 4 1
func leaves_with_path 4 4 1
func map 5 5 3
func unflatten 3 3 1
func structure 3 3 1
func flatten 3 3 1
func all 3 3 1
func broadcast 4 4 3
func leaves 3 3 1
func transpose 4 4 2

jax._src.tree_util (6 missing, 24 any)

Symbol Typable Typed Any
func tree_unflatten 3 3 1
meth Partial.new 4 0 0
attr default_registry 1 0 0
func tree_reduce_associative 5 5 1
func tree_all 3 3 1
func tree_flatten 3 3 1
func tree_reduce 5 5 1
attr DictKey 1 1 1
attr GetAttrKey 1 1 1
func tree_map_with_path 6 6 3
attr SequenceKey 1 1 1
func tree_broadcast 4 4 3
attr FlattenedIndexKey 1 1 1
func tree_map 5 5 3
func register_pytree_with_keys 5 4 0
func tree_leaves_with_path 4 4 1
func tree_leaves 3 3 1
func tree_structure 3 3 1
func tree_flatten_with_path 4 4 1
func tree_transpose 4 4 2

jax._src.util (1 missing, 0 any)

Symbol Typable Typed Any
func safe_map 7 6 0

jax._src.xla_bridge (1 missing, 0 any)

Symbol Typable Typed Any
func backend_xla_version 2 1 0

jax._src.xla_metadata (3 missing, 0 any)

Symbol Typable Typed Any
func set_xla_metadata 3 0 0

jax.collect_profile (7 missing, 0 any)

Symbol Typable Typed Any
attr known_args 1 0 0
func main 3 0 0
attr parser 1 0 0
func collect_profile 7 6 0
attr unknown_flags 1 0 0

jax.example_libraries.optimizers (58 missing, 2 any)

Symbol Typable Typed Any
func rmsprop 4 0 0
func sm3 3 0 0
func inverse_time_decay 5 0 0
meth JoinPoint.init 2 0 0
meth JoinPoint.iter 1 0 0
attr JoinPoint.subtree 1 0 0
func unpack_optimizer_state 2 0 0
func l2_norm 2 0 0
func pack_optimizer_state 2 0 0
func momentum 3 2 0
func polynomial_decay 5 0 0
func constant 2 1 0
func adagrad 3 0 0
func adam 5 0 0
func rmsprop_momentum 5 0 0
func exponential_decay 4 0 0
func nesterov 3 2 0
func sgd 2 0 0
func clip_grads 3 0 0
func adamax 5 0 0
func piecewise_constant 3 2 2

jax.example_libraries.stax (63 missing, 0 any)

Symbol Typable Typed Any
attr SumPool 1 0 0
func FanInSum 1 0 0
func GeneralConv 8 0 0
attr Sigmoid 1 0 0
func BatchNorm 7 0 0
func FanOut 2 0 0
func Flatten 1 0 0
attr Relu 1 0 0
attr Conv 1 0 0
attr Exp 1 0 0
func parallel 2 0 0
attr ConvTranspose 1 0 0
func Dropout 3 0 0
attr Softmax 1 0 0
attr Elu 1 0 0
func Identity 1 0 0
attr Conv1DTranspose 1 0 0
attr Gelu 1 0 0
attr Tanh 1 0 0
func shape_dependent 2 0 0
attr Selu 1 0 0
func elementwise 3 0 0
func serial 2 0 0
attr MaxPool 1 0 0
func FanInConcat 2 0 0
attr Softplus 1 0 0
attr AvgPool 1 0 0
attr LogSoftmax 1 0 0
attr LeakyRelu 1 0 0
func Dense 4 0 0
func GeneralConvTranspose 8 0 0

jax.experimental.array_serialization.pytree_serialization (0 missing, 6 any)

Symbol Typable Typed Any
func nonblocking_save 5 5 1
func load_pytreedef 2 2 1
func load 5 5 2
func save 5 5 1
func nonblocking_load 5 5 1

jax.experimental.array_serialization.pytree_serialization_utils (2 missing, 0 any)

Symbol Typable Typed Any
func serialize_pytreedef 2 1 0
func deserialize_pytreedef 2 1 0

jax.experimental.array_serialization.serialization (23 missing, 0 any)

Symbol Typable Typed Any
attr logger 1 0 0
meth GlobalAsyncCheckpointManager.serialize 5 2 0
meth GlobalAsyncCheckpointManager.serialize_with_paths 5 4 0
meth GlobalAsyncCheckpointManager.deserialize 6 5 0
meth GlobalAsyncCheckpointManager.deserialize_with_paths 6 5 0
attr get_tensorstore_spec 1 0 0
meth GlobalAsyncCheckpointManagerBase.check_for_errors 1 0 0
meth GlobalAsyncCheckpointManagerBase.wait_until_finished 1 0 0
meth GlobalAsyncCheckpointManagerBase.serialize 4 1 0
meth GlobalAsyncCheckpointManagerBase.deserialize 5 4 0
meth AsyncManager.init 2 0 0
meth AsyncManager.del 1 0 0
meth AsyncManager._thread_func 1 0 0
meth AsyncManager._start_async_commit 2 0 0
meth AsyncManager.check_for_errors 1 0 0
meth AsyncManager.wait_until_finished 1 0 0
meth AsyncManager._add_futures 2 1 0

jax.experimental.array_serialization.tensorstore_impl (13 missing, 1 any)

Symbol Typable Typed Any
func async_deserialize 9 5 0
attr logger 1 0 0
func merge_nested_ts_specs 3 2 0
func async_serialize 9 3 0
func is_tensorstore_spec_leaf 2 1 1

jax.experimental.colocated_python.api (1 missing, 0 any)

Symbol Typable Typed Any
func colocated_python 2 1 0

jax.experimental.colocated_python.func (4 missing, 0 any)

Symbol Typable Typed Any
meth Specialization.update 7 6 0
func make_callable 4 3 0
meth WeakSpec.__init__ 3 2 0
meth StrongSpec.__init__ 3 2 0

jax.experimental.colocated_python.func_backend (1 missing, 0 any)

Symbol Typable Typed Any
attr SINGLETON_RESULT_STORE 1 0 0

jax.experimental.colocated_python.obj (1 missing, 0 any)

Symbol Typable Typed Any
attr SINGLETON_INSTANCE_REGISTRY 1 0 0

jax.experimental.colocated_python.obj_backend (1 missing, 0 any)

Symbol Typable Typed Any
attr SINGLETON_OBJECT_STORE 1 0 0

jax.experimental.fused (3 missing, 0 any)

Symbol Typable Typed Any
func fused 2 0 0
attr fused_p 1 0 0

jax.experimental.jax2tf.call_tf (12 missing, 1 any)

Symbol Typable Typed Any
func call_tf 6 2 0
attr call_tf_ordered_effect 1 0 0
func emit_tf_embedded_graph_custom_call 7 2 0
attr call_tf_effect 1 0 0
attr call_tf_p 1 0 0
func add_to_call_tf_concrete_function_list 3 3 1

jax.experimental.jax2tf.jax2tf (16 missing, 5 any)

Symbol Typable Typed Any
func split_to_logical_devices 3 2 1
func dtype_of_val 2 2 2
func eval_polymorphic_shape 3 2 0
func preprocess_arg_tf 3 3 2
func inside_call_tf 1 0 0
meth NativeSerializationImpl.__init__ 6 2 0
meth NativeSerializationImpl.before_conversion 1 0 0
meth NativeSerializationImpl.after_conversion 1 0 0
attr NativeSerializationImpl.convert_kwargs 1 0 0
attr NativeSerializationImpl.fun_jax 1 0 0
attr NativeSerializationImpl.args_specs 1 0 0
attr NativeSerializationImpl.kwargs_specs 1 0 0
attr NativeSerializationImpl.native_serialization_disabled_checks 1 0 0
attr NativeSerializationImpl.native_serialization_platforms 1 0 0
attr DEFAULT_NATIVE_SERIALIZATION 1 0 0

jax.experimental.jet (92 missing, 0 any)

Symbol Typable Typed Any
func deflinear 2 0 0
attr zero_series 1 0 0
func fact 2 0 0
func zero_prop 5 0 0
func deriv_prop 5 0 0
func jet_subtrace 6 0 0
meth JetTracer.__init__ 4 0 0
meth JetTracer.full_lower 1 0 0
prop JetTracer.aval 1 0 0
attr JetTracer.primal 1 0 0
attr JetTracer.terms 1 0 0
meth JetTrace.__init__ 4 0 0
meth JetTrace.to_primal_terms_pair 2 0 0
meth JetTrace.process_primitive 4 0 0
meth JetTrace.process_call 5 0 0
meth JetTrace.process_custom_jvp_call 6 0 0
meth JetTrace.process_custom_vjp_call 7 0 0
attr JetTrace.tag 1 0 0
attr JetTrace.parent_trace 1 0 0
attr JetTrace.order 1 0 0
func traceable 5 0 0
func linear_prop 5 0 0
func def_deriv 3 0 0
attr jet2 1 0 0
func defzero 2 0 0
func jet_fun 5 0 0
func jet 6 0 0
attr zero_term 1 0 0
func def_comp 4 0 0

jax.experimental.mosaic.gpu.constraints (13 missing, 0 any)

Symbol Typable Typed Any
meth Equals.__str__ 1 0 0
meth NotOfType.__str__ 1 0 0
meth IsValidMmaTiling.__str__ 1 0 0
meth ConstraintSystem.__str__ 1 0 0
meth SMEMTiling.__str__ 1 0 0
meth Reduce.__str__ 1 0 0
meth Transpose.__str__ 1 0 0
meth IsTransferable.__str__ 1 0 0
meth RegisterLayout.__str__ 1 0 0
meth Relayout.__str__ 1 0 0
meth TMEMLayout.__str__ 1 0 0
meth Variable.__str__ 1 0 0
meth Divides.__str__ 1 0 0

jax.experimental.mosaic.gpu.core (21 missing, 0 any)

Symbol Typable Typed Any
meth Barrier.__post_init__ 1 0 0
meth Union.__iter__ 1 0 0
meth TMEM.__post_init__ 1 0 0
func supports_cross_device_collectives 1 0 0
func as_gpu_kernel 14 9 0
func as_torch_gpu_kernel 13 8 0
func is_nvshmem_available 1 0 0
attr mosaic_gpu_p 1 0 0
func is_single_process_multi_device_topology 1 0 0
attr libdevice_path 1 0 0
attr PYTHON_RUNFILES 1 0 0
func artificial_shared_memory_limit 2 0 0

jax.experimental.mosaic.gpu.dialect_lowering (3 missing, 0 any)

Symbol Typable Typed Any
func lower_mgpu_dialect 4 3 0
meth LoweringContext.lower_op 2 1 0
attr RECURSED 1 0 0

jax.experimental.mosaic.gpu.fragmented_array (120 missing, 3 any)

Symbol Typable Typed Any
attr WGMMA_COL_LAYOUT 1 0 0
attr TiledLayout 1 1 1
attr WGMMA_ROW_LAYOUT 1 0 0
attr WGMMA_LAYOUT 1 0 0
attr TCGEN05_TRANSPOSED_LAYOUT 1 0 0
attr WGMMA_LAYOUT_UPCAST_4X 1 0 0
func tmem_native_layout 2 1 0
attr TCGEN05_COL_LAYOUT 1 0 0
attr WGMMA_LAYOUT_UPCAST_2X 1 0 0
attr TCGEN05_LAYOUT 1 0 0
meth FragmentedArray.__init__ 4 3 0
meth FragmentedArray.splat 5 2 0
meth FragmentedArray._pointwise 5 3 0
meth FragmentedArray.__pos__ 1 0 0
meth FragmentedArray.__neg__ 1 0 0
meth FragmentedArray.__add__ 2 0 0
meth FragmentedArray.__radd__ 2 0 0
meth FragmentedArray.__mul__ 2 0 0
meth FragmentedArray.__rmul__ 2 0 0
meth FragmentedArray.__sub__ 2 0 0
meth FragmentedArray.__rsub__ 2 0 0
meth FragmentedArray.__truediv__ 2 0 0
meth FragmentedArray.__rtruediv__ 2 0 0
meth FragmentedArray.__floordiv__ 2 0 0
meth FragmentedArray.__rfloordiv__ 2 0 0
meth FragmentedArray.__mod__ 2 0 0
meth FragmentedArray.__rmod__ 2 0 0
meth FragmentedArray.__invert__ 1 0 0
meth FragmentedArray.__or__ 2 0 0
meth FragmentedArray.__ror__ 2 0 0
meth FragmentedArray.__and__ 2 0 0
meth FragmentedArray.__rand__ 2 0 0
meth FragmentedArray.__xor__ 2 0 0
meth FragmentedArray.__rxor__ 2 0 0
meth FragmentedArray.__lshift__ 2 0 0
meth FragmentedArray.__rshift__ 2 0 0
meth FragmentedArray.__eq__ 2 0 0
meth FragmentedArray.__ne__ 2 0 0
meth FragmentedArray.__lt__ 2 0 0
meth FragmentedArray.__le__ 2 0 0
meth FragmentedArray.__gt__ 2 0 0
meth FragmentedArray.__ge__ 2 0 0
meth FragmentedArray._compare 6 0 0
meth FragmentedArray.max 2 1 0
meth FragmentedArray.min 2 1 0
meth FragmentedArray.__getitem__ 2 1 0
meth FragmentedArray.broadcast 2 1 0
meth FragmentedArray.broadcast_minor 2 1 0
meth FragmentedArray.broadcast_in_dim 4 2 0
meth FragmentedArray.select 3 0 0
meth FragmentedArray.foreach 4 1 0
meth FragmentedArray.load_reduce_untiled 6 5 0
meth FragmentedArray._store_untiled_splat 2 1 0
meth FragmentedArray.store_tiled_async 8 7 0
meth FragmentedArray.store_tiled 5 4 0
meth FragmentedArray.load_tiled 9 8 0
meth FragmentedArray.transfer_strided 3 2 0
meth FragmentedArray.transfer_tiled 7 6 0
meth FragmentedArray.tree_flatten 1 0 0
meth FragmentedArray.tree_unflatten 3 0 0
attr TCGEN05_ROW_LAYOUT 1 0 0
func copy_tiled 4 3 0
attr TMA_GATHER_INDICES_LAYOUT 1 0 0
meth WGStridedFragLayout.__post_init__ 1 0 0
meth WGStridedFragLayout.thread_idxs 2 0 0
meth WGStridedFragLayout.linear_thread_idxs 1 0 0
meth WGSplatFragLayout.can_broadcast_to 2 1 0
meth WGSplatFragLayout.thread_idxs 2 0 0
attr WGMMA_LAYOUT_8BIT 1 0 0
attr TMEM_NATIVE_LAYOUT 1 0 0
attr WGMMA_TRANSPOSED_LAYOUT 1 0 0
func subf 3 2 0
func addf 3 2 0
prop StaggeredTransferPlanImpl.tile_index_transforms 1 0 0
attr WGMMA_LAYOUT_ACC_32BIT 1 0 0
meth TiledLayoutImpl.__post_init__ 2 1 0
attr Tiling 1 1 1
attr Replicated 1 1 1
func mulf 3 2 0
prop TrivialTransferPlanImpl.tile_index_transforms 1 0 0

jax.experimental.mosaic.gpu.inference_utils (2 missing, 0 any)

Symbol Typable Typed Any
attr in_transforms_for_operand 1 0 0
attr in_layout_for_operand 1 0 0

jax.experimental.mosaic.gpu.launch_context (24 missing, 3 any)

Symbol Typable Typed Any
attr GLOBAL_BROADCAST 1 0 0
meth LaunchContext.named_region 3 0 0
meth LaunchContext._recompute_peer_id 3 2 0
meth LaunchContext._get_tma_desc 7 6 0
meth LaunchContext._prepare_async_copy 7 6 1
meth LaunchContext._prepare_tma 11 9 0
meth LaunchContext.async_copy 14 13 1
meth LaunchContext.async_prefetch 9 8 1
meth LaunchContext.await_async_copy 4 3 0
meth LaunchContext.await_cp_async_copy 2 1 0
meth LaunchContext._ensure_nvshmem_decls 1 0 0
meth LaunchContext._find_kernel_argument_index 2 1 0
meth LaunchContext.to_remote 5 4 0
meth LaunchContext.to_remote_multicast 2 1 0
meth TileTransform.__post_init__ 1 0 0
meth TransposeTransform.__post_init__ 1 0 0
meth Scratch.__init__ 2 1 0
meth Scratch._create_ops 1 0 0
meth Scratch.finalize_size 1 0 0
func uses_collective_metadata 2 0 0

jax.experimental.mosaic.gpu.layout_inference (4 missing, 0 any)

Symbol Typable Typed Any
func infer_layout 3 2 0
meth ValueSite.__post_init__ 1 0 0
meth ValueSite.__str__ 1 0 0
func traverse_op 3 2 0

jax.experimental.mosaic.gpu.mma (4 missing, 0 any)

Symbol Typable Typed Any
meth MMALayouts.__init__ 2 1 0
attr MMALayouts.lhs 1 0 0
attr MMALayouts.rhs 1 0 0
attr MMALayouts.acc 1 0 0

jax.experimental.mosaic.gpu.mma_utils (5 missing, 0 any)

Symbol Typable Typed Any
func create_descriptor 8 7 0
func tiled_memref_shape 2 1 0
func encode_descriptor 7 5 0
func encode_addr 2 1 0

jax.experimental.mosaic.gpu.profiler (18 missing, 0 any)

Symbol Typable Typed Any
meth Cupti.measure 4 2 0
meth ProfilerSpec.__init__ 3 2 0
meth ProfilerSpec.smem_i32_elements 2 1 0
meth ProfilerSpec.smem_bytes 2 1 0
meth ProfilerSpec.dump 5 2 0
attr ProfilerSpec.entries_per_warpgroup 1 0 0
attr ProfilerSpec.dump_path 1 0 0
meth OnDeviceProfiler.__init__ 5 4 0
meth OnDeviceProfiler._profiler_ctx 1 0 0
meth OnDeviceProfiler.record 2 1 0
meth OnDeviceProfiler.finalize 3 2 0
attr OnDeviceProfiler.spec 1 0 0
attr OnDeviceProfiler.entries_per_wg 1 0 0
attr OnDeviceProfiler.wrap_in_custom_primitive 1 0 0
attr OnDeviceProfiler.ctx 1 0 0

jax.experimental.mosaic.gpu.tcgen05 (11 missing, 0 any)

Symbol Typable Typed Any
meth TMEMRef.__post_init__ 1 0 0
meth TMEMRef.from_alloc 6 5 0
meth TMEMRef.slice 2 1 0
meth TMEMRef.store 2 1 0
func create_instr_descriptor 8 6 0
func tmem_half_lane_layout 3 2 0
func create_scaled_f8f6f4_instr_descriptor 3 1 0
func create_scaled_f4_instr_descriptor 3 1 0

jax.experimental.mosaic.gpu.utils (113 missing, 0 any)

Symbol Typable Typed Any
func debug_print 5 0 0
func bitwidth 2 1 0
func c 3 1 0
func memref_fold 4 2 0
func single_thread 2 1 0
func when 2 0 0
func nanosleep 2 1 0
meth DynamicSlice.__post_init__ 1 0 0
func fori 3 0 0
func try_cluster_cancel 4 2 0
meth Partition1D.__init__ 5 4 0
meth Partition1D.refine 4 3 0
attr Partition1D.base_offset 1 0 0
func is_known_divisible 4 3 0
func bytewidth 2 1 0
func get_cluster_ref 5 4 0
func query_cluster_cancel 2 1 0
func warpgroup_idx 2 0 0
func memref_unsqueeze 3 2 0
meth CollectiveBarrierRef.__iter__ 1 0 0
meth CollectiveBarrierRef.__getitem__ 2 0 0
meth CollectiveBarrierRef.arrive 2 1 0
meth CollectiveBarrierRef.wait 3 0 0
meth CollectiveBarrierRef.wait_parity 3 0 0
meth Partition.__init__ 6 5 0
prop Partition.target_block_shape 1 0 0
meth DialectBarrierRef.wait_parity 3 0 0
meth DialectBarrierRef.wait 2 1 0
meth DialectBarrierRef.arrive 2 1 0
meth DialectBarrierRef.arrive_expect_tx 2 1 0
meth DialectBarrierRef.get_ptr 1 0 0
meth DialectBarrierRef.from_barrier_memref 2 1 0
func commit_shared 1 0 0
func warpgroup_barrier 1 0 0
func memref_unfold 4 2 0
meth BarrierRef.wait_parity 3 0 0
meth BarrierRef.wait 2 1 0
meth BarrierRef.arrive 5 4 0
meth BarrierRef.arrive_expect_tx 3 2 0
meth BarrierRef.complete_tx 3 2 0
meth BarrierRef.get_ptr 1 0 0
func warp_idx 2 0 0
attr thread_idx 1 0 0
func memref_slice 3 2 0
meth SemaphoreRef.signal 4 3 0
meth SemaphoreRef.signal_multimem 4 1 0
meth SemaphoreRef.wait 4 3 0
func single_thread_predicate 2 1 0
func tile_shape 3 0 0
func memref_ptr 3 0 0
meth MultimemRef.store 3 2 0
func warp_barrier 1 0 0
attr block_idx 1 0 0
func redux 4 3 0
func warp_tree_reduce 4 0 0
func prmt 4 3 0
func smid 1 0 0
func get_cluster_ptr 4 3 0
func vector_slice 3 2 0
func shfl_bfly 3 2 0
func multimem_store 3 2 0
func fence_release_sys 1 0 0
func get_contiguous_strides 2 0 0
func parse_indices 4 3 0
prop ForResult.result 1 0 0
func ptr_as_memref 4 2 0
func pack_array 2 0 0
func multimem_load_reduce 5 4 0
func dyn_dot 3 0 0
func clock 1 0 0
func ceil_div 3 2 0
func bitwidth_impl 2 1 0
func cluster_collective_mask 3 2 0
func bitcast 3 2 0
func nvvm_mbarrier_arrive_expect_tx 4 3 0
attr WORKGROUP_NVPTX_ADDRESS_SPACE 1 0 0
func globaltimer 2 1 0

jax.experimental.mosaic.gpu.wgmma (14 missing, 0 any)

Symbol Typable Typed Any
func wgmma 5 4 0
meth WGMMAAccumulator.__init__ 4 3 0
meth WGMMAAccumulator.zero 5 1 0
meth WGMMAAccumulator.from_registers 2 0 0
meth WGMMAAccumulator.tree_flatten 1 0 0
meth WGMMAAccumulator.tree_unflatten 3 0 0
func wgmma_m64 11 9 0

jax.experimental.multihost_utils (25 missing, 12 any)

Symbol Typable Typed Any
func host_local_array_to_global_array_impl 4 3 2
func gtl_abstract_eval 4 0 0
func global_array_to_host_local_array 4 3 2
func process_allgather 3 3 2
attr global_array_to_host_local_array_p 1 0 0
func global_array_to_host_local_array_impl 4 3 2
func ltg_batcher 7 0 0
func ltg_abstract_eval 4 0 0
func broadcast_one_to_all 3 3 2
func assert_equal 3 1 0
attr live_devices 1 0 0
attr host_local_array_to_global_array_p 1 0 0
func host_local_array_to_global_array 4 3 2
func sync_global_devices 2 1 0

jax.experimental.ode (57 missing, 0 any)

Symbol Typable Typed Any
func abs2 2 0 0
func initial_step_size 8 0 0
func mean_error_ratio 6 0 0
func optimal_step_size 7 0 0
func ravel_first_arg 4 2 0
func fit_4th_order_polynomial 7 0 0
func runge_kutta_step 6 0 0
func ravel_first_arg_ 5 0 0
func odeint 9 0 0
func interp_fit_dopri 5 0 0

jax.experimental.pallas.ops.gpu.attention (27 missing, 2 any)

Symbol Typable Typed Any
meth BlockSizes.get_default 1 0 0
attr DEFAULT_MASK_VALUE 1 0 0
func segment_mask 3 2 0
func mha 15 11 0
func mha_reference 7 2 0
func mha_forward_kernel 12 8 2
func mha_backward_kernel 19 8 0

jax.experimental.pallas.ops.gpu.attention_mgpu (18 missing, 0 any)

Symbol Typable Typed Any
func attention_with_pipeline_emitter 6 1 0
func main 2 0 0
func attention_reference 6 0 0
meth TuningConfig.__post_init__ 1 0 0
func attention 6 2 0

jax.experimental.pallas.ops.gpu.blackwell_matmul_mgpu (4 missing, 0 any)

Symbol Typable Typed Any
func main 2 1 0
func matmul_kernel 4 1 0

jax.experimental.pallas.ops.gpu.blackwell_ragged_dot_mgpu (24 missing, 0 any)

Symbol Typable Typed Any
func main 2 1 0
func ragged_dot_kernel 5 1 0
func do_matmul 19 6 0
meth TuningConfig.__str__ 1 0 0
func sample_group_sizes 5 4 0
func ragged_dot_reference 4 0 0

jax.experimental.pallas.ops.gpu.collective_matmul_mgpu (1 missing, 0 any)

Symbol Typable Typed Any
func all_gather_lhs_matmul 6 5 0

jax.experimental.pallas.ops.gpu.decode_attention (49 missing, 2 any)

Symbol Typable Typed Any
func gqa 17 11 0
func gqa_reference 9 0 0
func mqa 17 11 0
func attn_forward_kernel 12 6 2
func mqa_reference 9 0 0
func mha_reference 7 0 0
func decode_attn_unbatched 17 11 0

jax.experimental.pallas.ops.gpu.hopper_matmul_mgpu (12 missing, 0 any)

Symbol Typable Typed Any
func kernel 7 1 0
func main 2 1 0
meth MatmulDimension.__str__ 1 0 0
meth MatmulDimension.__repr__ 1 0 0
func matmul 4 1 0

jax.experimental.pallas.ops.gpu.hopper_mixed_type_matmul_mgpu (3 missing, 0 any)

Symbol Typable Typed Any
func main 2 1 0
meth MatmulDimension.__str__ 1 0 0
meth MatmulDimension.__repr__ 1 0 0

jax.experimental.pallas.ops.gpu.layer_norm (39 missing, 0 any)

Symbol Typable Typed Any
func layer_norm_reference 5 1 0
func layer_norm_forward_kernel 9 2 0
func layer_norm_backward_kernel_dx 10 2 0
func layer_norm 9 5 0
func layer_norm_backward_kernel_dw_db 12 3 0
func layer_norm_backward 8 5 0
func layer_norm_forward 9 5 0

jax.experimental.pallas.ops.gpu.paged_attention (9 missing, 2 any)

Symbol Typable Typed Any
attr DEFAULT_MASK_VALUE 1 0 0
func paged_attention_kernel 14 6 2

jax.experimental.pallas.ops.gpu.ragged_dot_mgpu (9 missing, 0 any)

Symbol Typable Typed Any
meth GroupInfo.create 4 0 0
func ragged_dot 11 8 0
func main 2 0 0

jax.experimental.pallas.ops.gpu.reduce_scatter_mgpu (1 missing, 0 any)

Symbol Typable Typed Any
func reduce_scatter 8 7 0

jax.experimental.pallas.ops.gpu.rms_norm (36 missing, 0 any)

Symbol Typable Typed Any
func rms_norm 9 5 0
func rms_norm_forward 9 5 0
func rms_norm_backward_kernel_dx 9 2 0
func rms_norm_backward 8 5 0
func rms_norm_backward_kernel_dw_db 11 3 0
func rms_norm_reference 5 1 0
func rms_norm_forward_kernel 8 2 0

jax.experimental.pallas.ops.gpu.transposed_ragged_dot_mgpu (9 missing, 0 any)

Symbol Typable Typed Any
func ref_transposed_ragged_dot 4 0 0
func transposed_ragged_dot 9 6 0
func main 2 0 0

jax.experimental.pallas.ops.tpu.all_gather (7 missing, 0 any)

Symbol Typable Typed Any
func ag_kernel 7 2 0
func all_gather 5 3 0

jax.experimental.pallas.ops.tpu.example_kernel (5 missing, 0 any)

Symbol Typable Typed Any
func double_kernel 3 0 0
func double 2 0 0

jax.experimental.pallas.ops.tpu.flash_attention (38 missing, 0 any)

Symbol Typable Typed Any
meth BlockSizes.__post_init__ 1 0 0
meth BlockSizes.get_default 6 0 0
attr DEFAULT_MASK_VALUE 1 0 0
func mha_reference_no_custom_vjp 10 6 0
func below_or_on_diag 5 0 0
func mha_reference_bwd 13 4 0
func mha_reference 9 3 0
func flash_attention 10 4 0

jax.experimental.pallas.ops.tpu.matmul (6 missing, 0 any)

Symbol Typable Typed Any
func matmul 7 6 0
func matmul_kernel 5 0 0

jax.experimental.pallas.ops.tpu.megablox.gmm (0 missing, 1 any)

Symbol Typable Typed Any
func make_group_metadata 7 7 1

jax.experimental.pallas.ops.tpu.megablox.ops (1 missing, 0 any)

Symbol Typable Typed Any
attr gmm 1 0 0

jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel (59 missing, 0 any)

Symbol Typable Typed Any
attr DEFAULT_MASK_VALUE 1 0 0
func paged_flash_attention_kernel 26 6 0
func paged_flash_attention_kernel_inline_seq_dim 25 6 0
meth MultiPageAsyncCopyDescriptor.__init__ 10 0 0
meth MultiPageAsyncCopyDescriptor._make_async_copy 2 0 0
meth MultiPageAsyncCopyDescriptor._make_scales_async_copy 2 0 0
meth MultiPageAsyncCopyDescriptor.start 1 0 0
meth MultiPageAsyncCopyDescriptor._maybe_dequantize 4 0 0

jax.experimental.pallas.ops.tpu.paged_attention.util (1 missing, 0 any)

Symbol Typable Typed Any
attr MASK_VALUE 1 0 0

jax.experimental.pallas.ops.tpu.ragged_paged_attention.kernel (34 missing, 0 any)

Symbol Typable Typed Any
func ref_ragged_paged_attention 13 12 0
func ragged_paged_attention 16 15 0
attr DEFAULT_MASK_VALUE 1 0 0
func dynamic_validate_inputs 16 15 0
func get_dtype_packing 2 0 0
func static_validate_inputs 16 15 0
meth MultiPageAsyncCopyDescriptor.__init__ 6 0 0
meth MultiPageAsyncCopyDescriptor.start 1 0 0
meth MultiPageAsyncCopyDescriptor.wait 1 0 0
func ragged_paged_attention_kernel 20 6 0
func get_min_heads_per_blk 5 0 0

jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes (15 missing, 0 any)

Symbol Typable Typed Any
func get_device_name 2 1 0
func get_min_page_size 3 0 0
func get_tuned_block_sizes 9 1 0
func simplify_key 2 0 0
func next_power_of_2 2 1 0

jax.experimental.pallas.ops.tpu.random.philox (20 missing, 0 any)

Symbol Typable Typed Any
func philox_4x32 8 0 0
func philox_split 3 1 0
func philox_fold_in 3 0 0
func philox_4x32_count 5 3 0
func philox_4x32_kernel 7 5 0
func philox_random_bits 4 2 0
attr plphilox_prng_impl 1 0 0

jax.experimental.pallas.ops.tpu.random.prng_utils (3 missing, 0 any)

Symbol Typable Typed Any
func compute_scalar_offset 4 2 0
func blocked_iota 3 2 0

jax.experimental.pallas.ops.tpu.random.threefry (5 missing, 0 any)

Symbol Typable Typed Any
func plthreefry_random_bits 4 2 0
func threefry_2x32_count 5 3 0
attr plthreefry_prng_impl 1 0 0

jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel (41 missing, 1 any)

Symbol Typable Typed Any
attr make_splash_mha_single_device 1 0 0
meth BlockSizes.__post_init__ 1 0 0
meth BlockSizes.get_default 1 0 0
attr make_splash_mqa 1 0 0
attr make_masked_mha_reference 1 0 0
attr make_splash_mqa_single_device 1 0 0
attr make_masked_mqa_reference 1 0 0
attr make_splash_mha 1 0 0
func from_head_minor 3 2 0
func attention_reference_custom 11 10 0
attr DEFAULT_MASK_VALUE 1 0 0
func make_attention_reference 5 5 1
meth SplashAttentionKernel.__init__ 5 3 0
meth SplashAttentionKernel.__call__ 3 1 0
meth SplashAttentionKernel.manual_sharding_spec 2 1 0
meth SplashAttentionKernel.tree_flatten 1 0 0
meth SplashAttentionKernel.tree_unflatten 3 0 0
attr SplashAttentionKernel.kwargs 1 0 0
attr SplashAttentionKernel.fwd_mask_info 1 0 0
attr SplashAttentionKernel.dq_mask_info 1 0 0
attr SplashAttentionKernel.dkv_mask_info 1 0 0
func flash_attention_kernel 28 11 0

jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask (28 missing, 0 any)

Symbol Typable Typed Any
meth CausalMask.__init__ 4 3 0
meth CausalMask.__eq__ 2 1 0
meth CausalMask.__hash__ 1 0 0
meth NumpyMask.__post_init__ 1 0 0
meth NumpyMask.__getitem__ 2 1 0
meth NumpyMask.__eq__ 2 1 0
meth NumpyMask.__hash__ 1 0 0
meth LocalMask.__init__ 5 4 0
meth LocalMask.__eq__ 2 1 0
meth LocalMask.__hash__ 1 0 0
meth FullMask.__post_init__ 1 0 0
meth FullMask.__getitem__ 2 1 0
meth FullMask.__eq__ 2 1 0
meth FullMask.__hash__ 1 0 0
meth MultiHeadMask.__post_init__ 1 0 0
meth MultiHeadMask.__getitem__ 2 1 0
meth MultiHeadMask.__eq__ 2 1 0
meth MultiHeadMask.__hash__ 1 0 0
meth Mask.__getitem__ 2 1 0
meth LogicalAnd.__init__ 3 2 0
meth LogicalAnd.__getitem__ 2 1 0
meth LogicalAnd.__hash__ 1 0 0
meth LogicalOr.__init__ 3 2 0
meth LogicalOr.__getitem__ 2 1 0
meth LogicalOr.__hash__ 1 0 0
meth ChunkedCausalMask.__init__ 4 3 0
meth ChunkedCausalMask.__eq__ 2 1 0
meth ChunkedCausalMask.__hash__ 1 0 0

jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info (4 missing, 0 any)

Symbol Typable Typed Any
attr process_dynamic_mask_dkv 1 0 0
attr process_mask_dkv 1 0 0
attr process_mask 1 0 0
attr process_dynamic_mask 1 0 0

jax.experimental.rnn (30 missing, 1 any)

Symbol Typable Typed Any
attr rnn_fwd_p 1 0 0
func swap_lstm_gates 6 0 0
func rnn_abstract_eval 12 6 0
func init_lstm_weight 6 5 1
attr rnn_bwd_p 1 0 0
func rnn_bwd_abstract_eval 17 6 0
func lstm_bwd 9 6 0
func lstm_fwd 12 11 0

jax.experimental.roofline.roofline (2 missing, 2 any)

Symbol Typable Typed Any
func roofline_and_grad 8 8 2
func register_standard_roofline 2 1 0
func register_roofline 2 1 0

jax.experimental.scheduling_groups (8 missing, 0 any)

Symbol Typable Typed Any
func xla_metadata_call 3 0 0
func scheduling_group 2 0 0
func attr_get 2 0 0
attr xla_metadata_call_p 1 0 0

jax.experimental.serialize_executable (5 missing, 0 any)

Symbol Typable Typed Any
func deserialize_and_load 6 2 0
func serialize 2 1 0

jax.experimental.shard_map (6 missing, 0 any)

Symbol Typable Typed Any
func shard_map 6 0 0

jax.experimental.source_mapper.common (9 missing, 0 any)

Symbol Typable Typed Any
func compile_with_env 6 0 0
func register_pass 2 1 0
func flag_env 2 0 0

jax.experimental.source_mapper.generate_map (2 missing, 0 any)

Symbol Typable Typed Any
func generate_sourcemaps 4 2 0

jax.experimental.source_mapper.hlo (10 missing, 0 any)

Symbol Typable Typed Any
func stable_hlo_generate_dump 3 2 0
func original_hlo_generate_dump 3 2 0
func optimized_generate_dump 4 3 0
attr METADATA_REGEX 1 0 0
func trace_and_lower 6 0 0

jax.experimental.source_mapper.jaxpr (8 missing, 0 any)

Symbol Typable Typed Any
func canonicalize_filename 2 1 0
func compile_jaxpr 6 0 0
func make_jaxpr_dump 3 2 0

jax.experimental.source_mapper.mlir (4 missing, 0 any)

Symbol Typable Typed Any
attr LOC_REGEX 1 0 0
attr SCOPED_REGEX 1 0 0
attr SRC_REGEX 1 0 0
attr CALLSITE_REGEX 1 0 0

jax.experimental.sparse._base (34 missing, 0 any)

Symbol Typable Typed Any
meth JAXSparse.__len__ 1 0 0
meth JAXSparse.__init__ 3 2 0
meth JAXSparse.__repr__ 1 0 0
meth JAXSparse.tree_flatten 1 0 0
meth JAXSparse.tree_unflatten 3 0 0
meth JAXSparse.transpose 2 0 0
meth JAXSparse.block_until_ready 1 0 0
meth JAXSparse.sum 3 0 0
meth JAXSparse.__neg__ 1 0 0
meth JAXSparse.__pos__ 1 0 0
meth JAXSparse.__matmul__ 2 0 0
meth JAXSparse.__rmatmul__ 2 0 0
meth JAXSparse.__mul__ 2 0 0
meth JAXSparse.__rmul__ 2 0 0
meth JAXSparse.__add__ 2 0 0
meth JAXSparse.__radd__ 2 0 0
meth JAXSparse.__sub__ 2 0 0
meth JAXSparse.__rsub__ 2 0 0
meth JAXSparse.__getitem__ 2 0 0
prop JAXSparse.T 1 0 0

jax.experimental.sparse.ad (9 missing, 0 any)

Symbol Typable Typed Any
func jacrev 5 4 0
func jacfwd 5 4 0
func value_and_grad 5 3 0
func grad 5 3 0
func flatten_fun_for_sparse_ad 4 2 0
attr is_sparse 1 0 0

jax.experimental.sparse.api (3 missing, 0 any)

Symbol Typable Typed Any
func eye 8 7 0
attr todense_p 1 0 0
func empty 6 5 0

jax.experimental.sparse.bcoo (46 missing, 0 any)

Symbol Typable Typed Any
func bcoo_rev 3 0 0
attr bcoo_sort_indices_p 1 0 0
func bcoo_broadcast_in_dim 5 4 0
func bcoo_dot_general 7 6 0
func bcoo_reshape 5 4 0
attr bcoo_spdot_general_p 1 0 0
func bcoo_gather 9 8 0
attr bcoo_fromdense_p 1 0 0
attr bcoo_extract_p 1 0 0
attr bcoo_todense_p 1 0 0
attr bcoo_dot_general_sampled_p 1 0 0
func bcoo_conv_general_dilated 13 1 0
attr bcoo_transpose_p 1 0 0
meth BCOO.__init__ 5 4 0
meth BCOO.__repr__ 1 0 0
meth BCOO.reshape 3 1 0
meth BCOO.astype 3 1 0
meth BCOO.from_scipy_sparse 5 4 0
meth BCOO.tree_flatten 1 0 0
meth BCOO.tree_unflatten 3 0 0
attr BCOO.nse 1 0 0
attr BCOO.dtype 1 0 0
attr BCOO.n_batch 1 0 0
attr BCOO.n_sparse 1 0 0
attr BCOO.n_dense 1 0 0
attr BCOO._info 1 0 0
attr BCOO._bufs 1 0 0
attr bcoo_dot_general_p 1 0 0
attr bcoo_sum_duplicates_p 1 0 0

jax.experimental.sparse.bcsr (42 missing, 0 any)

Symbol Typable Typed Any
attr bcsr_extract_p 1 0 0
func bcsr_dot_general 7 6 0
attr bcsr_todense_p 1 0 0
func bcsr_broadcast_in_dim 5 4 0
attr bcsr_dot_general_p 1 0 0
meth BCSR.__init__ 5 4 0
meth BCSR.__repr__ 1 0 0
meth BCSR.transpose 3 0 0
meth BCSR.tree_flatten 1 0 0
meth BCSR.tree_unflatten 3 0 0
meth BCSR._empty 7 0 0
meth BCSR.fromdense 6 0 0
meth BCSR.todense 1 0 0
meth BCSR.from_scipy_sparse 5 0 0
prop BCSR._sparse_shape 1 0 0
attr BCSR.nse 1 0 0
attr BCSR.dtype 1 0 0
attr BCSR.n_batch 1 0 0
attr BCSR.n_sparse 1 0 0
attr BCSR.n_dense 1 0 0
attr BCSR._bufs 1 0 0
attr BCSR._info 1 0 0
attr bcsr_fromdense_p 1 0 0

jax.experimental.sparse.coo (12 missing, 0 any)

Symbol Typable Typed Any
attr coo_fromdense_p 1 0 0
attr coo_matvec_p 1 0 0
attr coo_todense_p 1 0 0
attr coo_matmat_p 1 0 0
meth COO.__init__ 5 4 0
meth COO.tree_unflatten 3 0 0
attr COO.nse 1 0 0
attr COO.dtype 1 0 0
attr COO._info 1 0 0
attr COO._bufs 1 0 0

jax.experimental.sparse.csr (61 missing, 0 any)

Symbol Typable Typed Any
attr csr_matmat_p 1 0 0
attr csr_matvec_p 1 0 0
attr csr_fromdense_p 1 0 0
attr csr_todense_p 1 0 0
meth CSR.__init__ 3 0 0
meth CSR.fromdense 4 0 0
meth CSR._empty 4 0 0
meth CSR._eye 6 0 0
meth CSR.todense 1 0 0
meth CSR.transpose 2 0 0
meth CSR.__matmul__ 2 0 0
meth CSR.tree_flatten 1 0 0
meth CSR.tree_unflatten 3 0 0
attr CSR.nse 1 0 0
attr CSR.dtype 1 0 0
attr CSR._bufs 1 0 0
meth CSC.__init__ 3 0 0
meth CSC.fromdense 4 0 0
meth CSC._empty 4 0 0
meth CSC._eye 6 0 0
meth CSC.todense 1 0 0
meth CSC.transpose 2 0 0
meth CSC.__matmul__ 2 0 0
meth CSC.tree_flatten 1 0 0
meth CSC.tree_unflatten 3 0 0
attr CSC.nse 1 0 0
attr CSC.dtype 1 0 0

jax.experimental.sparse.linalg (9 missing, 0 any)

Symbol Typable Typed Any
func spsolve 7 0 0
attr spsolve_p 1 0 0
func lobpcg_standard 5 4 0

jax.experimental.sparse.nm (4 missing, 0 any)

Symbol Typable Typed Any
attr nm_pack_p 1 0 0
attr nm_spmm_p 1 0 0
func nm_pack 4 2 0

jax.experimental.sparse.random (12 missing, 0 any)

Symbol Typable Typed Any
func random_bcoo 12 0 0

jax.experimental.sparse.test_util (51 missing, 0 any)

Symbol Typable Typed Any
func iter_bcsr_layouts 3 2 0
func is_sparse 2 0 0
func iter_sparse_layouts 3 2 0
func rand_bcoo 6 5 0
func rand_sparse 5 0 0
meth SparseTestCase.assertSparseArraysEquivalent 8 0 0
meth SparseTestCase._CheckAgainstDense 10 0 0
meth SparseTestCase._CheckGradsSparse 8 0 0
meth SparseTestCase._random_bdims 2 0 0
meth SparseTestCase._CheckBatchingSparse 12 0 0
func rand_bcsr 6 5 0

jax.experimental.sparse.transform (61 missing, 10 any)

Symbol Typable Typed Any
func sparsify 3 0 0
meth SparseTracer.__init__ 3 1 0
meth SparseTracer.full_lower 1 0 0
prop SparseTracer.spenv 1 0 0
prop SparseTracer.aval 1 0 0
func sparsify_fun 3 1 0
meth SparseTrace.__init__ 4 0 0
meth SparseTrace.to_sparse_tracer 2 0 0
meth SparseTrace.process_primitive 4 0 0
meth SparseTrace.process_call 5 1 0
meth SparseTrace.process_custom_jvp_call 6 0 0
attr SparseTrace.parent_trace 1 0 0
attr SparseTrace.tag 1 0 0
attr SparseTrace.spenv 1 0 0
func sparsify_subtrace 7 0 0
func spvalues_to_avals 3 3 2
func spvalues_to_arrays 3 3 2
meth SparsifyValue.is_sparse 1 0 0
meth SparsifyValue.is_dense 1 0 0
meth SparsifyValue.is_bcoo 1 0 0
meth SparsifyValue.is_bcsr 1 0 0
prop SparsifyValue.ndim 1 0 0
func arrays_to_spvalues 3 3 2
func sparsify_raw 2 0 0
meth SparsifyEnv.__init__ 2 0 0
meth SparsifyEnv._push 2 2 1
meth SparsifyEnv.data 2 2 1
meth SparsifyEnv.indices 2 2 1
meth SparsifyEnv.indptr 2 2 1
meth SparsifyEnv.dense 2 0 0
meth SparsifyEnv.sparse 10 0 0

jax.experimental.sparse.util (9 missing, 0 any)

Symbol Typable Typed Any
func broadcasting_vmap 4 0 0
func nfold_vmap 5 0 0

jax.experimental.topologies (3 missing, 0 any)

Symbol Typable Typed Any
func get_attached_topology 2 1 0
meth TopologyDescription.__init__ 2 1 0
func get_topology_desc 4 3 0

jax.experimental.transfer (9 missing, 4 any)

Symbol Typable Typed Any
meth TransferConnection._pull_flat 4 0 0
meth TransferConnection.pull 3 3 2
func make_error_array 3 0 0
meth TransferServer._await_pull_flat 3 1 0
meth TransferServer.await_pull 3 3 2

jax.extend.linear_util (2 missing, 0 any)

Symbol Typable Typed Any
func wrap_init 4 2 0

jax.interpreters.mlir (3 missing, 0 any)

Symbol Typable Typed Any
attr Token 1 0 0
attr token_type 1 0 0
attr dense_int_array 1 0 0

jax.interpreters.xla (1 missing, 0 any)

Symbol Typable Typed Any
attr Backend 1 0 0

jax.nn (0 missing, 2 any)

Symbol Typable Typed Any
func one_hot 5 5 2

jax.numpy (22 missing, 60 any)

Symbol Typable Typed Any
attr uint2 1 1 1
func iscomplexobj 2 2 1
attr complex128 1 1 1
func result_type 2 2 1
func from_dlpack 4 4 1
func diagonal 5 4 0
attr bfloat16 1 1 1
attr cdouble 1 1 1
func empty 5 5 1
func fromfile 3 0 0
func ones_like 6 6 1
func isscalar 2 2 1
func pad 5 4 0
attr int2 1 1 1
attr float8_e5m2fnuz 1 1 1
func einsum 10 8 0
func full_like 6 6 1
attr float32 1 1 1
attr float8_e4m3 1 1 1
attr uint64 1 1 1
attr single 1 1 1
attr float_ 1 1 1
attr float8_e4m3fn 1 1 1
func einsum_path 5 4 0
func piecewise 6 4 0
attr uint8 1 1 1
attr float4_e2m1fn 1 1 1
attr int16 1 1 1
attr int8 1 1 1
func fromfunction 5 4 1
attr int32 1 1 1
attr float16 1 1 1
attr uint4 1 1 1
func array 8 8 1
attr float8_e4m3fnuz 1 1 1
attr int_ 1 1 1
func vectorize 4 1 0
attr csingle 1 1 1
attr float8_e4m3b11fnuz 1 1 1
func isrealobj 2 2 1
attr uint32 1 1 1
func zeros_like 6 6 1
attr uint16 1 1 1
func apply_along_axis 6 4 0
func full 5 5 1
attr complex_ 1 1 1
attr float8_e3m4 1 1 1
attr complex64 1 1 1
meth ufunc.__init__ 12 11 1
meth ufunc.__call__ 2 2 1
meth ufunc.at 5 5 1
meth ufunc.reduceat 6 6 1
func asarray 7 7 1
func frompyfunc 5 5 1
attr float8_e5m2 1 1 1
func empty_like 5 5 1
attr uint 1 1 1
func unique 10 9 0
func load 4 4 2
func modf 3 2 0
func zeros 5 5 1
attr float8_e8m0fnu 1 1 1
attr bool_ 1 1 1
attr bool 1 1 1
attr int64 1 1 1
attr float64 1 1 1
attr int4 1 1 1
attr double 1 1 1
func fromiter 3 0 0
func ones 5 5 1

jax.profiler (3 missing, 0 any)

Symbol Typable Typed Any
attr ProfileEvent 1 0 0
attr ProfileData 1 0 0
attr ProfilePlane 1 0 0

jax.tree_util (1 missing, 0 any)

Symbol Typable Typed Any
attr PyTreeDef 1 0 0

jax.version (2 missing, 0 any)

Symbol Typable Typed Any
attr __version_info__ 1 0 0
attr __version__ 1 0 0

Type-Ignore Comments

Flavor Count
type: ignore 88
type: ignore[assignment] 25
pyrefly: ignore[bad-override] 23
type: ignore[arg-type] 11
type: ignore[missing-attribute] 6
pyrefly: ignore[bad-param-name-override] 5
type: ignore[method-assign] 5
type: ignore[misc] 5
pyrefly: ignore[missing-attribute] 3
type: ignore[override] 3
type: ignore[bad-return-type, unused-ignore] 2
pyrefly: ignore[bad-assignment] 1
pyrefly: ignore[bad-class-definition] 1
type: ignore[assignment, bad-override] 1
type: ignore[bad-assignment] 1
type: ignore[import-not-found] 1
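
A tally like the one above can be approximated by scanning source text for ignore comments and grouping them by tool and error code. The sketch below is illustrative only and is not how this report was produced: the function name `count_ignore_flavors` and the sample snippet are hypothetical, and a plain regex is used instead of a tokenizer, so it would also count matching text inside string literals.

```python
import re
from collections import Counter

# Matches "# type: ignore", "# type: ignore[codes]" and the "# pyrefly: ignore[...]" variant.
IGNORE_RE = re.compile(r"#\s*(type|pyrefly):\s*ignore(\[[^\]]*\])?")


def count_ignore_flavors(source: str) -> Counter:
    """Count ignore comments in `source`, keyed by flavor such as 'type: ignore[assignment]'."""
    counts: Counter = Counter()
    for match in IGNORE_RE.finditer(source):
        tool = match.group(1)
        codes = match.group(2) or ""  # empty string for a bare "ignore"
        counts[f"{tool}: ignore{codes}"] += 1
    return counts


# Hypothetical sample input, not taken from the jax source tree.
sample = '''
x: int = f()  # type: ignore
y = g()  # type: ignore[assignment]
class B(A):
    def m(self): ...  # pyrefly: ignore[bad-override]
z = h()  # type: ignore[assignment]
'''

# Print flavors from most to least frequent, mirroring the table layout.
for flavor, n in count_ignore_flavors(sample).most_common():
    print(flavor, n)
```

Keeping the bracketed codes as part of the key is what separates a bare `type: ignore` from a targeted one such as `type: ignore[assignment]`, which is the distinction the table draws.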