| import torch |
| from torch.profiler import profile |
| from bit_transformer import ( |
| BitTransformerLM, |
| quantize_dynamic, |
| hil_safe_inference, |
| collapse_submodel, |
| ) |
| from bit_transformer.training import train_loop |
| from bit_transformer.torch_utils import cpu_autocast |
|
|
def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 3,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Fit ``model`` to bit sequences, optionally mixing in compressed batches.

    This is a thin convenience wrapper: every argument is handed unchanged
    to ``bit_transformer.training.train_loop``.

    When ``direct_prob`` is greater than zero, a share of the batches is
    presented in run-length encoded form packed into bits, and the loss on
    those direct-compressed batches is reported separately.

    Returns:
        One metrics dictionary per epoch, holding raw and compressed
        loss/accuracy figures together with the mean compression ratio.
    """
    # Gather the tunables once so the delegation below stays a one-liner.
    loop_options = {
        "epochs": epochs,
        "compress_prob": compress_prob,
        "direct_prob": direct_prob,
        "log": log,
        "forward_kwargs": forward_kwargs,
    }
    return train_loop(model, data, **loop_options)
|
|
|
|
def main() -> None:
    """Run the end-to-end scale-up, quantization, collapse and profiling demo.

    The script performs, in order:

    1. Twelve scale-up steps on random bit data. Odd steps call
       ``model.double_layers()``, even steps call ``model.double_width()``;
       after each step the model is trained and its telemetry on a held-out
       random batch is checked against fixed floors.
    2. One forward pass inside ``cpu_autocast()``.
    3. ``quantize_dynamic`` followed by ``hil_safe_inference`` on a single
       input sequence.
    4. ``collapse_submodel`` to obtain a smaller student model, which is
       wrapped with ``torch.compile`` when that API exists.
    5. A profiled forward pass of the student, exported to ``trace12.json``.

    Raises:
        AssertionError: if, after any scale-up step, the mean negentropy,
            LZ-complexity or symbiosis telemetry is not strictly above its
            floor.
    """
    # Floors used by the per-step gate, the safe-inference call and the
    # submodel collapse. Named once so the three uses cannot drift apart.
    negentropy_floor = 0.3
    lz_floor = 0.35
    symbiosis_floor = 0.5

    # Random binary data; the draw order is kept as-is so a seeded run
    # consumes the RNG stream identically.
    data = torch.randint(0, 2, (64, 128), dtype=torch.long)
    validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
    bit_sequence_data = data.tolist()

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=128,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=128,
    )

    for step in range(1, 13):
        # Both methods return a model object that replaces the current one.
        # NOTE(review): presumably they grow width / depth as the names
        # suggest -- confirm against bit_transformer.
        if step % 2 == 0:
            model = model.double_width()
        else:
            model = model.double_layers()

        train(model, data, epochs=3, compress_prob=0.5, log=True)

        _, telemetry = model(validation_bits)
        negentropy = telemetry["negentropy_logits"].mean().item()
        lz_complexity = telemetry["lz_complexity_logits"].mean().item()
        symbiosis = telemetry["symbiosis_score"].mean().item()

        floors_met = (
            negentropy > negentropy_floor
            and lz_complexity > lz_floor
            and symbiosis > symbiosis_floor
        )
        # Raise explicitly instead of using ``assert``: assert statements are
        # removed under ``python -O``, which would silently disable this gate.
        # The exception type and message match the former assert.
        if not floors_met:
            raise AssertionError(f"Step {step} telemetry floor failure")

    # Forward pass under the package's CPU autocast context; the output is
    # discarded, so this only exercises that code path.
    with cpu_autocast():
        model(input_bits)

    quantized_model = quantize_dynamic(model)
    quantized_model.eval()

    safe_output, _ = hil_safe_inference(
        quantized_model, input_bits, c_floor=lz_floor, s_floor=symbiosis_floor
    )

    # Build the smaller student from the raw bit lists.
    # NOTE(review): presumably a distillation step -- confirm in
    # bit_transformer.collapse_submodel.
    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=128,
        ),
        floors={
            "negentropy": negentropy_floor,
            "lz_complexity": lz_floor,
            "symbiosis_score": symbiosis_floor,
        },
    )

    # Use the uncompiled student when this torch build has no torch.compile.
    compiled_model = (
        torch.compile(student_model)
        if hasattr(torch, "compile")
        else student_model
    )
    compiled_model.eval()

    with profile() as prof:
        compiled_model(input_bits)

    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output.squeeze(0).tolist())
|
|
|
|
# Run the demo only when this file is executed as a script, so that importing
# the module (e.g. to reuse ``train``) has no side effects.
if __name__ == "__main__":
    main()
|
|