Accelerating Penetration Testing Simulations: NASim Rewrite in JAX

In the context of the SLATE (Self-Learning Attack Surface Explorer) project, we’re looking for an intern to rewrite one of the Reinforcement Learning (RL) environments we use in our research. Specifically, the environment will be rewritten in Google JAX, a high-performance numerical computing library that combines NumPy’s familiar API with GPU acceleration and automatic differentiation for machine learning research. The environment in question is the Network Attack Simulator (NASim). Within NASim, an agent has various scan, exploit, and privilege escalation actions in its arsenal. The agent’s goal is to learn a policy that uses these actions in an order that allows it to advance through the network and gain root privileges on sensitive hosts.

Reinforcement Learning has long suffered from a performance bottleneck: the environment the agent interacts with runs on the CPU, while most other computations happen on the GPU, so data has to be transferred between the two at every step. JAX allows environments to run on the GPU as well, eliminating this bottleneck. Some RL environments have already been rewritten in JAX and shown to run up to 250x faster than their Python-native implementations.
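To make the idea concrete, here is a minimal sketch under stated assumptions; it is not NASim’s actual API or dynamics. It models a hypothetical toy network of four hosts in a chain, and `reset`, `step`, `rollout`, the (host, kind) action encoding, and the reward rule are all invented for illustration. Because the environment is a pure function of arrays, `jax.vmap` can run many environments in parallel, `jax.lax.scan` can run a whole episode, and `jax.jit` compiles everything so it runs on the accelerator without returning to Python between steps.

```python
import jax
import jax.numpy as jnp

NUM_HOSTS = 4  # hypothetical toy network: hosts 0..3 connected in a chain


def reset(num_hosts: int = NUM_HOSTS) -> jnp.ndarray:
    # Hypothetical state: access level per host (0 = none, 1 = user, 2 = root).
    # Host 0 is the attacker's foothold, so it starts with root access.
    return jnp.zeros(num_hosts, dtype=jnp.int32).at[0].set(2)


def step(state, action):
    # Hypothetical action encoding: (host, kind), kind 0 = exploit, 1 = privesc.
    host, kind = action[0], action[1]
    prev_access = state[jnp.maximum(host - 1, 0)]
    cur = state[host]
    # Exploit grants user access if the preceding host in the chain is compromised.
    exploited = jnp.where((prev_access > 0) & (cur == 0), 1, cur)
    # Privilege escalation upgrades user access to root.
    escalated = jnp.where(cur == 1, 2, cur)
    new_level = jnp.where(kind == 0, exploited, escalated)
    new_state = state.at[host].set(new_level)
    # Toy reward: 1.0 the first time the last ("sensitive") host reaches root.
    is_target = host == state.shape[0] - 1
    reward = jnp.where(is_target & (new_level == 2) & (cur != 2), 1.0, 0.0)
    done = new_state[-1] == 2
    return new_state, reward, done


def rollout(state, actions):
    # Run a whole episode on device; lax.scan carries the state across steps.
    def body(s, a):
        s, r, _ = step(s, a)
        return s, r

    return jax.lax.scan(body, state, actions)


# vmap over a batch of environments (the action plan is shared), then jit-compile.
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, None)))

# Fixed plan: exploit hosts 1, 2, 3 in turn, then escalate on host 3.
plan = jnp.array([[1, 0], [2, 0], [3, 0], [3, 1]])
states = jnp.stack([reset() for _ in range(8)])  # 8 parallel environments
final, rewards = batched_rollout(states, plan)
```

In a real rewrite, the batch would typically hold thousands of environments with per-environment actions chosen by the policy, but the structure (pure `step` function, `vmap` for parallelism, `scan` for episodes, `jit` for compilation) stays the same.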

With a faster environment, the policy can be trained over more episodes in the same time frame. This would allow for larger hyperparameter searches and for more complex scenarios that more closely resemble the challenges faced in real-world penetration testing.

During your internship, you will have the opportunity to learn JAX or gain more hands-on experience with it. You will also discover more about the field of Reinforcement Learning and work with high-end equipment to validate your implementation and the expected speed-up.

Goal

The goal is to implement a JAX version of NASim, with comprehensive tests demonstrating the expected performance improvements.

Expected outcome

  • NASimJAX, a rewrite of NASim in JAX
  • At least one blog post
  • A poster
  • A final presentation of your work

Required skills

To start this project you should have some knowledge of:

  • Python programming and the NumPy library
  • Usage of Git/version control
  • Ability to document code and technical processes clearly

Nice to haves:

  • Experience in Reinforcement Learning
  • Experience in Functional Programming
  • Experience with JAX

Conditions

The applicant’s country of origin must be a member of the EU or NATO.

Interested?

Contact us
