-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathnanoeuler_cuda.cu
More file actions
1399 lines (1350 loc) · 97.9 KB
/
Copy pathnanoeuler_cuda.cu
File metadata and controls
1399 lines (1350 loc) · 97.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
============================================================================
nanoeuler — CUDA kernels plus GPU-vs-CPU self-tests.
This file is developed on a machine WITHOUT a GPU, so it is validated on the
user's GPU: the self-tests below run the CUDA kernels and compare their output
against CPU reference implementations, printing the maximum relative error.
The GEMMs run through cuBLAS in TF32 tensor-core mode, so errors well above
FP32 round-off are expected; check() reports OK below 1e-2 and FAIL otherwise.
Build (RTX 40-series = Ada = sm_89; cuBLAS must be linked):
nvcc -O3 -arch=sm_89 nanoeuler_cuda.cu -o nanoeuler_cuda -lcublas
./nanoeuler_cuda
The kernels (matmul, RMSNorm, RoPE, attention, SwiGLU, softmax/cross-entropy
over the MTP heads, AdamW, token embedding) follow the same pattern and mirror
../nanoeuler.c. The matmuls are delegated to cuBLAS; the hand-written tiled
matmul kernel is kept as the reference/study version.
============================================================================
*/
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
/* <math.h> is intentionally not included: on some systems glibc's <math.h> collides with
   CUDA's math headers (an rsqrt exception-specification error). The only thing needed
   from it is an absolute value, so a tiny local one for doubles is provided instead. */
static inline double dabs(double x){
  if(x<0) return -x;
  return x;
}
#define TILE 16
#define CHECK(call) do { cudaError_t e=(call); if(e!=cudaSuccess){ \
fprintf(stderr,"CUDA error %s at %s:%d\n",cudaGetErrorString(e),__FILE__,__LINE__); exit(1);} } while(0)
/* FlashAttention tile sizes + forward declarations (definitions are further down) */
#define FA_BR 32
#define FA_BC 32
__global__ void flash_attn_forward_kernel(float*out,float*lse,const float*qkv,int B,int T,int C,int NH,int NKV);
__global__ void flash_attn_backward_kernel(float*dqkv,const float*qkv,const float*out,const float*dout,const float*lse,int B,int T,int C,int NH,int NKV);
/* out[r,o] = sum_i inp[r,i] * w[o,i], with w stored row-major as [OC][C].
   Viewed as a GEMM (BT x C) * (C x OC)^T -> (BT x OC), tiled in shared memory.
   Launch: blockDim = (TILE, TILE), gridDim = (ceil(OC/TILE), ceil(BT/TILE)). Every thread
   of a block must reach both __syncthreads(), so out-of-range threads load zeros instead
   of returning early.
   Both operands are loaded so that threadIdx.x walks the contiguous C dimension, giving
   coalesced global reads. The w tile is written transposed into shared memory; the +1
   column of padding spreads those transposed writes across banks. */
__global__ void matmul_forward_kernel(float* out, const float* inp,
const float* w, int BT, int C, int OC) {
    __shared__ float As[TILE][TILE];       /* As[m][k] = inp[blockIdx.y*TILE + m][k0 + k] */
    __shared__ float Bs[TILE][TILE + 1];   /* Bs[k][n] = w[blockIdx.x*TILE + n][k0 + k]   */
    int tx = threadIdx.x, ty = threadIdx.y;
    int row  = blockIdx.y * TILE + ty;     /* output row (index along BT) */
    int col  = blockIdx.x * TILE + tx;     /* output column (index along OC) */
    int wrow = blockIdx.x * TILE + ty;     /* row of w this thread loads (index along OC) */
    float acc = 0.0f;
    for (int k0 = 0; k0 < C; k0 += TILE) {
        int ak = k0 + tx;                  /* contiguous along C for adjacent threads */
        As[ty][tx] = (row < BT && ak < C) ? inp[(size_t)row * C + ak] : 0.0f;
        Bs[tx][ty] = (wrow < OC && ak < C) ? w[(size_t)wrow * C + ak] : 0.0f;
        __syncthreads();
        for (int k = 0; k < TILE; k++) acc += As[ty][k] * Bs[k][tx];
        __syncthreads();
    }
    if (row < BT && col < OC) out[(size_t)row * OC + col] = acc;
}
/* cuBLAS handle (lazily created on first use) — optimized GEMM instead of the naive tiled
   kernel. TF32 tensor-core math is enabled, so results are less precise than plain FP32
   (the self-test tolerance in check() accounts for this). No stream is set on the handle,
   so the GEMMs are issued on the default stream like the kernel launches in this file.
   cuBLAS is column-major: a row-major [R][K] buffer read column-major is its transpose,
   and each wrapper below uses that to express a row-major product. */
static cublasHandle_t g_cub=0;
/* Abort with a message if a cuBLAS call failed (the cuBLAS counterpart of CHECK). */
static void cub_check(cublasStatus_t st,const char*what){
    if(st!=CUBLAS_STATUS_SUCCESS){
        fprintf(stderr,"cuBLAS error %d in %s\n",(int)st,what);
        exit(1);
    }
}
static void cub_init(){
    if(g_cub) return;
    cub_check(cublasCreate(&g_cub),"cublasCreate");
    cub_check(cublasSetMathMode(g_cub,CUBLAS_TF32_TENSOR_OP_MATH),"cublasSetMathMode"); /* tensor core */
}
/* row-major out[BT,OC] = inp[BT,C] * w[OC,C]^T (overwrites, beta=0).
   Column-major view: out^T (OC x BT) = op_T(w^T as C x OC) * (inp^T as C x BT). */
void matmul_forward_cuda(float* d_out, const float* d_inp, const float* d_w, int BT, int C, int OC) {
    cub_init(); const float one=1.0f, zero=0.0f;
    cub_check(cublasSgemm(g_cub, CUBLAS_OP_T, CUBLAS_OP_N, OC, BT, C, &one, d_w, C, d_inp, C, &zero, d_out, OC),
              "cublasSgemm (matmul_forward_cuda)");
}
/* dinp[BT,C] += dout[BT,OC] * w[OC,C] (accumulate, beta=1).
   Column-major view: dinp^T (C x BT) += (w^T as C x OC) * (dout^T as OC x BT). */
static void matmul_backward_dinp_cuda(float*dinp,const float*dout,const float*w,int BT,int C,int OC){
    cub_init(); const float one=1.0f;
    cub_check(cublasSgemm(g_cub, CUBLAS_OP_N, CUBLAS_OP_N, C, BT, OC, &one, w, C, dout, OC, &one, dinp, C),
              "cublasSgemm (matmul_backward_dinp_cuda)");
}
/* dw[OC,C] += dout^T * inp (accumulate, beta=1).
   Column-major view: dw^T (C x OC) += (inp^T as C x BT) * op_T(dout^T as OC x BT). */
static void matmul_backward_dw_cuda(float*dw,const float*dout,const float*inp,int BT,int C,int OC){
    cub_init(); const float one=1.0f;
    cub_check(cublasSgemm(g_cub, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, BT, &one, inp, C, dout, OC, &one, dw, C),
              "cublasSgemm (matmul_backward_dw_cuda)");
}
/* CPU reference for the forward matmul: out[r,o] = sum_i inp[r,i] * w[o,i] with w stored
   row-major as [OC][C]; identical math to matmul_forward in ../nanoeuler.c. */
static void matmul_cpu(float* out, const float* inp, const float* w, int BT, int C, int OC) {
    for (int r = 0; r < BT; ++r) {
        const float* x = inp + (size_t)r * C;
        float* y = out + (size_t)r * OC;
        for (int o = 0; o < OC; ++o) {
            const float* wo = w + (size_t)o * C;
            float sum = 0.0f;
            for (int i = 0; i < C; ++i) sum += x[i] * wo[i];
            y[o] = sum;
        }
    }
}
/* Print the max relative error between a[] and b[] (n elements) followed by OK or FAIL.
   The denominator has a +1.0 floor so near-zero elements from random tests don't inflate
   the metric (a real O(1) kernel error still fails). The OK threshold is 1e-2 because the
   cuBLAS GEMMs run in TF32 tensor-core precision.
   A NaN relative error (NaN or Inf in either input) would be silently skipped by the
   `rel > maxrel` comparison, so it is recorded explicitly: it prints as "nan" and, since
   `nan < 1e-2` is false, always reports FAIL. */
static void check(const char*name,const float*a,const float*b,size_t n){
    double maxrel=0;
    for(size_t i=0;i<n;i++){
        double d=dabs(a[i]-b[i]);
        double den=dabs(a[i])+dabs(b[i])+1.0;
        double rel=d/den;
        if(rel!=rel){ maxrel=rel; break; }   /* NaN: keep it and stop, so the result is FAIL */
        if(rel>maxrel) maxrel=rel;
    }
    printf(" %-24s max rel err %.2e %s\n", name, maxrel, maxrel<1e-2 ? "OK" : "FAIL <<<"); /* 1e-2: TF32 tensor-core precision */
}
/* ===== RMSNorm: one thread per (batch,token) row, mirrors ../nanoeuler.c =====
   out[bt,i] = inp[bt,i] * rs * g[i] with rs = 1/sqrt(mean_i(inp[bt,i]^2) + 1e-5); rs is
   also stored in rstd[bt] for the backward pass. Launch with >= BT threads (1-D grid). */
__global__ void rmsnorm_forward_kernel(float*out,float*rstd,const float*inp,const float*g,int BT,int C){
    const int row=blockIdx.x*blockDim.x+threadIdx.x;
    if(row<BT){
        const float*xr=inp+(size_t)row*C;
        float*yr=out+(size_t)row*C;
        float ss=0;
        for(int i=0;i<C;++i) ss+=xr[i]*xr[i];
        const float inv=1.0f/sqrtf(ss/C+1e-5f);
        for(int i=0;i<C;++i) yr[i]=xr[i]*inv*g[i];
        rstd[row]=inv;
    }
}
/* RMSNorm backward, one thread per (batch,token) row (launch >= BT threads):
     q        = rs^3 / C * sum_j dout[bt,j]*g[j]*inp[bt,j]
     dinp[i] += g[i]*dout[bt,i]*rs - inp[bt,i]*q   (accumulated into the residual stream)
     dg[i]   += dout[bt,i]*inp[bt,i]*rs           (summed over all rows -> atomicAdd)
   rs is the per-row value saved in rstd by the forward kernel. */
__global__ void rmsnorm_backward_kernel(float*dinp,float*dg,const float*dout,const float*inp,
const float*g,const float*rstd,int BT,int C){
    const int row=blockIdx.x*blockDim.x+threadIdx.x;
    if(row>=BT) return;
    const size_t base=(size_t)row*C;
    const float*xr=inp+base;
    const float*gy=dout+base;
    float*dxr=dinp+base;
    const float rs=rstd[row];
    float q=0;
    for(int i=0;i<C;++i) q+=gy[i]*g[i]*xr[i];
    q*=rs*rs*rs/C;
    for(int i=0;i<C;++i){
        atomicAdd(&dg[i], gy[i]*xr[i]*rs);
        dxr[i]+=g[i]*gy[i]*rs - xr[i]*q;
    }
}
/* CPU reference for rmsnorm_forward_kernel: same per-row math, rows processed in order. */
static void rmsnorm_cpu_fwd(float*out,float*rstd,const float*inp,const float*g,int BT,int C){
    for(int row=0;row<BT;++row){
        const float*xr=inp+(size_t)row*C;
        float*yr=out+(size_t)row*C;
        float ss=0;
        for(int i=0;i<C;++i) ss+=xr[i]*xr[i];
        const float inv=1.0f/sqrtf(ss/C+1e-5f);
        for(int i=0;i<C;++i) yr[i]=xr[i]*inv*g[i];
        rstd[row]=inv;
    }
}
/* CPU reference for rmsnorm_backward_kernel: accumulates into dinp and dg (no atomics
   needed because rows are processed sequentially). */
static void rmsnorm_cpu_bwd(float*dinp,float*dg,const float*dout,const float*inp,
const float*g,const float*rstd,int BT,int C){
    for(int row=0;row<BT;++row){
        const size_t base=(size_t)row*C;
        const float*xr=inp+base;
        const float*gy=dout+base;
        float*dxr=dinp+base;
        const float rs=rstd[row];
        float q=0;
        for(int i=0;i<C;++i) q+=gy[i]*g[i]*xr[i];
        q*=rs*rs*rs/C;
        for(int i=0;i<C;++i){
            dg[i]+=gy[i]*xr[i]*rs;
            dxr[i]+=g[i]*gy[i]*rs - xr[i]*q;
        }
    }
}
/* ===== RoPE: rotate each pair (2i, 2i+1) of a head by angle t * 10000^(-2i/hs);
   inverse=1 applies the transposed rotation (the backward pass).
   One thread per (batch,token,head), decoded with the head index fastest; the heads of row
   (b,t) start at column base_off of a row of length stride. Mirrors ../nanoeuler.c ===== */
__global__ void rope_kernel(float*buf,int base_off,int stride,int nheads,int hs,int B,int T,int inverse){
    const int idx=blockIdx.x*blockDim.x+threadIdx.x;
    if(idx>=B*T*nheads) return;
    const int h=idx%nheads;
    const int bt=idx/nheads;          /* flattened row b*T + t */
    const int t=bt%T;
    float*vec=buf+(size_t)bt*stride + base_off + h*hs;
    for(int i=0;i<hs/2;i++){
        const float freq=powf(10000.0f,-2.0f*i/hs);
        const float ang=t*freq;
        const float c=cosf(ang), s=sinf(ang);
        const float a0=vec[2*i], a1=vec[2*i+1];
        if(inverse){
            vec[2*i]=a0*c+a1*s;
            vec[2*i+1]=-a0*s+a1*c;
        }else{
            vec[2*i]=a0*c-a1*s;
            vec[2*i+1]=a0*s+a1*c;
        }
    }
}
/* CPU reference for rope_kernel: same rotation, visiting rows (b*T + t) in order and the
   heads of each row in order. */
static void rope_cpu(float*buf,int base_off,int stride,int nheads,int hs,int B,int T,int inverse){
    for(int bt=0;bt<B*T;bt++){
        const int t=bt%T;
        for(int h=0;h<nheads;h++){
            float*vec=buf+(size_t)bt*stride + base_off + h*hs;
            for(int i=0;i<hs/2;i++){
                const float freq=powf(10000.0f,-2.0f*i/hs);
                const float ang=t*freq;
                const float c=cosf(ang), s=sinf(ang);
                const float a0=vec[2*i], a1=vec[2*i+1];
                if(inverse){
                    vec[2*i]=a0*c+a1*s;
                    vec[2*i+1]=-a0*s+a1*c;
                }else{
                    vec[2*i]=a0*c-a1*s;
                    vec[2*i+1]=a0*s+a1*c;
                }
            }
        }
    }
}
/* ===== Causal multi-head attention with GQA. One thread per (batch,head,query t). ===== */
/* Each qkv row holds QKV = C + 2*NKV*hs floats: q for NH heads in [0,C), k for NKV heads in
   [C, C+kvd), v for NKV heads in [C+kvd, QKV). Query head h reads KV head kh = h/(NH/NKV).
   For its (b,h,t) row this thread writes preatt (scaled scores) and att (softmax weights),
   both laid out [B][NH][T][T] with entries t2 > t zeroed, plus the head's slice
   out[(b*T+t)*C + h*hs .. + hs). Launch with >= B*NH*T threads in a 1-D grid.
   NOTE(review): assumes C % NH == 0 and NH % NKV == 0; nothing here checks it. */
__global__ void attention_forward_kernel(float*out,float*preatt,float*att,const float*qkv,
int B,int T,int C,int NH,int NKV){
int idx=blockIdx.x*blockDim.x+threadIdx.x; if(idx>=B*NH*T) return;
/* decode the flat index as [b][h][t] with t fastest */
int t=idx%T, h=(idx/T)%NH, b=idx/(NH*T);
int hs=C/NH,kvd=NKV*hs,QKV=C+2*kvd,grp=NH/NKV,kh=h/grp; float scale=1.0f/sqrtf((float)hs);
const float*q=qkv+(size_t)(b*T+t)*QKV + h*hs;
float*pa=preatt+(((size_t)b*NH+h)*T+t)*T, *a=att+(((size_t)b*NH+h)*T+t)*T;
/* pass 1: scaled dot products q.k[t2] over the causal prefix t2 <= t, tracking the max */
float maxv=-1e30f;
for(int t2=0;t2<=t;t2++){ const float*k=qkv+(size_t)(b*T+t2)*QKV + C + kh*hs;
float v=0; for(int i=0;i<hs;i++) v+=q[i]*k[i]; v*=scale; pa[t2]=v; if(v>maxv)maxv=v; }
/* pass 2: softmax with the max subtracted before expf for numerical stability */
float sum=0; for(int t2=0;t2<=t;t2++){ float e=expf(pa[t2]-maxv); a[t2]=e; sum+=e; }
float inv=sum>0?1.0f/sum:0.0f; for(int t2=0;t2<=t;t2++) a[t2]*=inv;
/* causal mask: future positions get zero score and zero weight */
for(int t2=t+1;t2<T;t2++){ a[t2]=0; pa[t2]=0; }
/* pass 3: attention-weighted sum of the KV head's value vectors */
float*o=out+(size_t)(b*T+t)*C + h*hs; for(int i=0;i<hs;i++) o[i]=0;
for(int t2=0;t2<=t;t2++){ const float*v=qkv+(size_t)(b*T+t2)*QKV + C + kvd + kh*hs;
float aw=a[t2]; for(int i=0;i<hs;i++) o[i]+=aw*v[i]; }
}
/* Backward of attention_forward_kernel. One thread per (batch,head,query t); launch with
   >= B*NH*T threads in a 1-D grid. Inputs: dout (gradient of the attention output, rows
   of C floats), qkv and att from the forward pass. Accumulates (+=) into dqkv:
     - dq for (b,t,h) is touched only by this thread -> plain +=
     - dk/dv rows of position t2 (KV head h/grp) are shared by every query t >= t2 and by
       every query head of the GQA group -> atomicAdd
   datt/dpreatt receive this row's gradients for t2 <= t (entries t2 > t are not written).
   Softmax backward uses dpre[t2] = a[t2] * (da[t2] - sum_t3 a[t3]*da[t3]), which costs
   O(t) per row instead of the O(t^2) Jacobian double loop used by the CPU reference; the
   two agree up to float rounding. */
__global__ void attention_backward_kernel(float*dqkv,float*dpreatt,float*datt,const float*dout,
const float*qkv,const float*att,int B,int T,int C,int NH,int NKV){
    int idx=blockIdx.x*blockDim.x+threadIdx.x; if(idx>=B*NH*T) return;
    int t=idx%T, h=(idx/T)%NH, b=idx/(NH*T);
    int hs=C/NH,kvd=NKV*hs,QKV=C+2*kvd,grp=NH/NKV,kh=h/grp; float scale=1.0f/sqrtf((float)hs);
    const float*a=att+(((size_t)b*NH+h)*T+t)*T;
    float*da=datt+(((size_t)b*NH+h)*T+t)*T, *dpa=dpreatt+(((size_t)b*NH+h)*T+t)*T;
    const float*dop=dout+(size_t)(b*T+t)*C + h*hs;
    /* out = sum_t2 a[t2]*v[t2]  ->  da[t2] = dout.v[t2],  dv[t2] += a[t2]*dout */
    float ada=0.0f;   /* sum_t2 a[t2]*da[t2], shared by every softmax-backward entry */
    for(int t2=0;t2<=t;t2++){
        const float*v=qkv+(size_t)(b*T+t2)*QKV + C + kvd + kh*hs;
        float*dv=dqkv+(size_t)(b*T+t2)*QKV + C + kvd + kh*hs;
        float aw=a[t2], dat=0;
        for(int i=0;i<hs;i++){ dat+=dop[i]*v[i]; atomicAdd(&dv[i], aw*dop[i]); } /* dv shared -> atomic */
        da[t2]=dat;
        ada+=aw*dat;
    }
    /* softmax backward in O(t) */
    for(int t2=0;t2<=t;t2++) dpa[t2]=a[t2]*(da[t2]-ada);
    /* preatt = scale*q.k  ->  dq += dpre*scale*k,  dk += dpre*scale*q */
    const float*q=qkv+(size_t)(b*T+t)*QKV + h*hs; float*dq=dqkv+(size_t)(b*T+t)*QKV + h*hs;
    for(int t2=0;t2<=t;t2++){
        const float*k=qkv+(size_t)(b*T+t2)*QKV + C + kh*hs;
        float*dk=dqkv+(size_t)(b*T+t2)*QKV + C + kh*hs; float dp=dpa[t2]*scale;
        for(int i=0;i<hs;i++){ dq[i]+=dp*k[i]; atomicAdd(&dk[i], dp*q[i]); } /* dq unique, dk shared -> atomic */
    }
}
/* CPU reference for attention_forward_kernel: identical per-(b,h,t) math and the same
   buffer layouts (qkv rows of QKV = C + 2*NKV*hs floats; preatt/att as [B][NH][T][T] with
   entries t2 > t zeroed; out rows of C floats), computed sequentially over b, h, t. */
static void attention_cpu_fwd(float*out,float*preatt,float*att,const float*qkv,int B,int T,int C,int NH,int NKV){
int hs=C/NH,kvd=NKV*hs,QKV=C+2*kvd,grp=NH/NKV; float scale=1.0f/sqrtf((float)hs);
for(int b=0;b<B;b++) for(int h=0;h<NH;h++) for(int t=0;t<T;t++){
/* GQA: query head h reads KV head h/grp */
int kh=h/grp; const float*q=qkv+(size_t)(b*T+t)*QKV+h*hs;
float*pa=preatt+(((size_t)b*NH+h)*T+t)*T,*a=att+(((size_t)b*NH+h)*T+t)*T;
/* scaled scores over the causal prefix, tracking the max for a stable softmax */
float maxv=-1e30f;
for(int t2=0;t2<=t;t2++){ const float*k=qkv+(size_t)(b*T+t2)*QKV+C+kh*hs;
float v=0; for(int i=0;i<hs;i++) v+=q[i]*k[i]; v*=scale; pa[t2]=v; if(v>maxv)maxv=v; }
float sum=0; for(int t2=0;t2<=t;t2++){ float e=expf(pa[t2]-maxv); a[t2]=e; sum+=e; }
float inv=sum>0?1.0f/sum:0.0f; for(int t2=0;t2<=t;t2++) a[t2]*=inv;
/* causal mask: zero scores and weights for future positions */
for(int t2=t+1;t2<T;t2++){ a[t2]=0; pa[t2]=0; }
/* out = sum_t2 a[t2] * v[t2] */
float*o=out+(size_t)(b*T+t)*C+h*hs; for(int i=0;i<hs;i++) o[i]=0;
for(int t2=0;t2<=t;t2++){ const float*v=qkv+(size_t)(b*T+t2)*QKV+C+kvd+kh*hs;
float aw=a[t2]; for(int i=0;i<hs;i++) o[i]+=aw*v[i]; }
}
}
/* CPU reference for attention_backward_kernel, computed sequentially over b, h, t.
   Accumulates (+=) into dqkv: dq for (b,t,h), and dk/dv of KV head h/grp for every t2 <= t.
   Writes datt and dpreatt only for t2 <= t. The softmax backward is the explicit Jacobian
   double loop dpre[t2] = sum_t3 a[t3]*(delta(t2,t3) - a[t2])*da[t3], O(t^2) per row. */
static void attention_cpu_bwd(float*dqkv,float*dpreatt,float*datt,const float*dout,const float*qkv,const float*att,int B,int T,int C,int NH,int NKV){
int hs=C/NH,kvd=NKV*hs,QKV=C+2*kvd,grp=NH/NKV; float scale=1.0f/sqrtf((float)hs);
for(int b=0;b<B;b++) for(int h=0;h<NH;h++) for(int t=0;t<T;t++){
int kh=h/grp; const float*a=att+(((size_t)b*NH+h)*T+t)*T;
float*da=datt+(((size_t)b*NH+h)*T+t)*T,*dpa=dpreatt+(((size_t)b*NH+h)*T+t)*T;
const float*dop=dout+(size_t)(b*T+t)*C+h*hs;
/* out = sum_t2 a[t2]*v[t2]  ->  da[t2] = dout.v[t2],  dv[t2] += a[t2]*dout */
for(int t2=0;t2<=t;t2++){ const float*v=qkv+(size_t)(b*T+t2)*QKV+C+kvd+kh*hs;
float*dv=dqkv+(size_t)(b*T+t2)*QKV+C+kvd+kh*hs; float dat=0;
for(int i=0;i<hs;i++){ dat+=dop[i]*v[i]; dv[i]+=a[t2]*dop[i]; } da[t2]=dat; }
/* softmax backward through the full Jacobian */
for(int t2=0;t2<=t;t2++){ float d=0;
for(int t3=0;t3<=t;t3++){ float ind=(t2==t3)?1.0f:0.0f; d+=a[t3]*(ind-a[t2])*da[t3]; }
dpa[t2]=d; }
/* preatt = scale*q.k  ->  dq += dpre*scale*k,  dk += dpre*scale*q */
const float*q=qkv+(size_t)(b*T+t)*QKV+h*hs; float*dq=dqkv+(size_t)(b*T+t)*QKV+h*hs;
for(int t2=0;t2<=t;t2++){ const float*k=qkv+(size_t)(b*T+t2)*QKV+C+kh*hs;
float*dk=dqkv+(size_t)(b*T+t2)*QKV+C+kh*hs; float dp=dpa[t2]*scale;
for(int i=0;i<hs;i++){ dq[i]+=dp*k[i]; dk[i]+=dp*q[i]; } }
}
}
/* ===== SwiGLU activation: hsil = silu(gate) * up (element-wise) =====
   silu(g) = g * sigmoid(g). One thread per element; launch with >= n threads. */
__global__ void swiglu_forward_kernel(float*hsil,const float*gate,const float*up,size_t n){
    const size_t k=(size_t)blockIdx.x*blockDim.x+threadIdx.x;
    if(k<n){
        const float x=gate[k];
        const float s=1.0f/(1.0f+expf(-x));
        hsil[k]=(x*s)*up[k];
    }
}
/* SwiGLU backward for hsil = silu(g) * u, one thread per element (launch >= n threads):
     dup   = dhsil * silu(g)
     dgate = dhsil * u * silu'(g),   silu'(g) = sig * (1 + g*(1 - sig))
   Both outputs are overwritten, not accumulated. */
__global__ void swiglu_backward_kernel(float*dgate,float*dup,const float*dhsil,
const float*gate,const float*up,size_t n){
    const size_t k=(size_t)blockIdx.x*blockDim.x+threadIdx.x;
    if(k<n){
        const float x=gate[k], u=up[k];
        const float s=1.0f/(1.0f+expf(-x));
        const float silu=x*s;
        const float dsilu=s*(1.0f+x*(1.0f-s));
        dup[k]=dhsil[k]*silu;
        dgate[k]=dhsil[k]*u*dsilu;
    }
}
/* CPU reference for swiglu_forward_kernel: hsil[i] = gate[i]*sigmoid(gate[i]) * up[i]. */
static void swiglu_cpu_fwd(float*hsil,const float*gate,const float*up,size_t n){
    for(size_t k=0;k<n;++k){
        const float x=gate[k];
        const float s=1.0f/(1.0f+expf(-x));
        hsil[k]=(x*s)*up[k];
    }
}
/* CPU reference for swiglu_backward_kernel (overwrites dgate and dup). */
static void swiglu_cpu_bwd(float*dgate,float*dup,const float*dhsil,const float*gate,const float*up,size_t n){
    for(size_t k=0;k<n;++k){
        const float x=gate[k], u=up[k];
        const float s=1.0f/(1.0f+expf(-x));
        const float silu=x*s;
        const float dsilu=s*(1.0f+x*(1.0f-s));
        dup[k]=dhsil[k]*silu;
        dgate[k]=dhsil[k]*u*dsilu;
    }
}
/* ===== Softmax + cross-entropy over the MTP heads. One thread per (token, head j). =====
   Row r = bt*K + j of logits/probs/dlogits holds V floats. The thread writes the softmax
   into probs and dlogits = (probs - onehot(target)) * scale; a negative target marks a
   masked (SFT) position whose gradient row is all zeros. Launch with >= BT*K threads. */
__global__ void softmax_ce_kernel(float*probs,float*dlogits,const float*logits,const int*targets,
int BT,int V,int K,float scale){
    const int r=blockIdx.x*blockDim.x+threadIdx.x;
    if(r>=BT*K) return;
    const size_t off=(size_t)r*V;
    const float*l=logits+off;
    float*p=probs+off;
    float mx=-1e30f;
    for(int i=0;i<V;i++) if(l[i]>mx) mx=l[i];
    float z=0;
    for(int i=0;i<V;i++){ float e=expf(l[i]-mx); p[i]=e; z+=e; }
    const float rz=1.0f/z;
    for(int i=0;i<V;i++) p[i]*=rz;
    const int tg=targets[r];
    float*dl=dlogits+off;
    if(tg<0){
        for(int i=0;i<V;i++) dl[i]=0.0f;   /* masked position (SFT): no gradient */
    }else{
        for(int i=0;i<V;i++) dl[i]=(p[i]-(i==tg?1.0f:0.0f))*scale;
    }
}
/* CPU reference for softmax_ce_kernel: same per-row math, rows r = bt*K + j in order. */
static void softmax_ce_cpu(float*probs,float*dlogits,const float*logits,const int*targets,int BT,int V,int K,float scale){
    const int rows=BT*K;
    for(int r=0;r<rows;r++){
        const size_t off=(size_t)r*V;
        const float*l=logits+off;
        float*p=probs+off;
        float mx=-1e30f;
        for(int i=0;i<V;i++) if(l[i]>mx) mx=l[i];
        float z=0;
        for(int i=0;i<V;i++){ float e=expf(l[i]-mx); p[i]=e; z+=e; }
        const float rz=1.0f/z;
        for(int i=0;i<V;i++) p[i]*=rz;
        const int tg=targets[r];
        float*dl=dlogits+off;
        if(tg<0){
            for(int i=0;i<V;i++) dl[i]=0.0f;   /* masked position (SFT) */
        }else{
            for(int i=0;i<V;i++) dl[i]=(p[i]-(i==tg?1.0f:0.0f))*scale;
        }
    }
}
/* ===== AdamW parameter update (element-wise). c1,c2 are the bias corrections. =====
     m = b1*m + (1-b1)*g        v = b2*v + (1-b2)*g^2
     p -= lr * ( (m/c1) / (sqrt(v/c2) + eps) + wd*p )   (weight decay added outside m/v)
   One thread per parameter; launch with >= n threads. */
__global__ void adamw_kernel(float*params,const float*grads,float*m,float*v,size_t n,
float lr,float b1,float b2,float eps,float wd,float c1,float c2){
    const size_t k=(size_t)blockIdx.x*blockDim.x+threadIdx.x;
    if(k>=n) return;
    const float g=grads[k];
    const float mk=b1*m[k]+(1-b1)*g;
    m[k]=mk;
    const float vk=b2*v[k]+(1-b2)*g*g;
    v[k]=vk;
    const float p=params[k];
    params[k]=p-lr*(mk/c1/(sqrtf(vk/c2)+eps)+wd*p);
}
/* CPU reference for adamw_kernel: same per-element update, applied in index order. */
static void adamw_cpu(float*params,const float*grads,float*m,float*v,size_t n,
float lr,float b1,float b2,float eps,float wd,float c1,float c2){
    for(size_t k=0;k<n;++k){
        const float g=grads[k];
        const float mk=b1*m[k]+(1-b1)*g;
        m[k]=mk;
        const float vk=b2*v[k]+(1-b2)*g*g;
        v[k]=vk;
        const float p=params[k];
        params[k]=p-lr*(mk/c1/(sqrtf(vk/c2)+eps)+wd*p);
    }
}
/* ===== encoder (token embedding lookup) and residual add ===== */
/* out[bt,:] = tok[ids[bt],:] (rows of C floats); one thread per (batch,token) row,
   launch with >= BT threads. */
__global__ void encoder_forward_kernel(float*out,const int*ids,const float*tok,int BT,int C){
    const int row=blockIdx.x*blockDim.x+threadIdx.x;
    if(row<BT){
        const float*src=tok+(size_t)ids[row]*C;
        float*dst=out+(size_t)row*C;
        for(int i=0;i<C;++i) dst[i]=src[i];
    }
}
/* out[i] = a[i] + b[i]; one thread per element, launch with >= n threads. */
__global__ void residual_add_kernel(float*out,const float*a,const float*b,size_t n){
    const size_t k=(size_t)blockIdx.x*blockDim.x+threadIdx.x;
    if(k<n) out[k]=a[k]+b[k];
}
/* CPU reference for encoder_forward_kernel: copy embedding row ids[bt] into output row bt. */
static void encoder_cpu(float*out,const int*ids,const float*tok,int BT,int C){
    for(int bt=0;bt<BT;bt++){
        const float*src=tok+(size_t)ids[bt]*C;
        float*dst=out+(size_t)bt*C;
        for(int i=0;i<C;i++) dst[i]=src[i];
    }
}
/* ===== matmul backward: dinp[r,i]+=sum_o dout[r,o]*w[o,i] ; dw[o,i]+=sum_r dout[r,o]*inp[r,i] =====
   Naive input-gradient kernel: one thread per (r,i), launch with >= BT*C threads.
   Accumulates into dinp. */
__global__ void matmul_backward_dinp_kernel(float*dinp,const float*dout,const float*w,int BT,int C,int OC){
    const int idx=blockIdx.x*blockDim.x+threadIdx.x;
    if(idx>=BT*C) return;
    const int r=idx/C, i=idx%C;
    const float*gy=dout+(size_t)r*OC;
    float acc=0;
    for(int o=0;o<OC;o++) acc+=gy[o]*w[(size_t)o*C+i];
    dinp[(size_t)r*C+i]+=acc;
}
/* Naive weight-gradient kernel: dw[o,i] += sum_r dout[r,o]*inp[r,i].
   One thread per (o,i); launch with >= OC*C threads. */
__global__ void matmul_backward_dw_kernel(float*dw,const float*dout,const float*inp,int BT,int C,int OC){
    const int idx=blockIdx.x*blockDim.x+threadIdx.x;
    if(idx<OC*C){
        const int o=idx/C, i=idx%C;
        float acc=0;
        for(int r=0;r<BT;r++) acc+=dout[(size_t)r*OC+o]*inp[(size_t)r*C+i];
        dw[(size_t)o*C+i]+=acc;
    }
}
/* CPU reference for the matmul backward: accumulates dinp[r,:] += dout[r,o]*w[o,:] and
   dw[o,:] += dout[r,o]*inp[r,:] for every (r,o), visited r-major. */
static void matmul_backward_cpu(float*dinp,float*dw,const float*dout,const float*inp,const float*w,int BT,int C,int OC){
    for(int r=0;r<BT;r++){
        const float*x=inp+(size_t)r*C;
        float*dx=dinp+(size_t)r*C;
        for(int o=0;o<OC;o++){
            const float d=dout[(size_t)r*OC+o];
            const float*wo=w+(size_t)o*C;
            float*dwo=dw+(size_t)o*C;
            for(int i=0;i<C;i++){
                dx[i]+=wo[i]*d;
                dwo[i]+=d*x[i];
            }
        }
    }
}
/* ===== encoder backward: dtok[id] += dout (token ids repeat -> atomic) =====
   One thread per (bt, i) element; launch with >= BT*C threads. */
__global__ void encoder_backward_kernel(float*dtok,const float*dout,const int*ids,int BT,int C){
    const int idx=blockIdx.x*blockDim.x+threadIdx.x;
    if(idx<BT*C){
        const int bt=idx/C, i=idx%C;
        atomicAdd(dtok+(size_t)ids[bt]*C+i, dout[(size_t)bt*C+i]);
    }
}
/* CPU reference for encoder_backward_kernel: add each dout row into the embedding-gradient
   row of its token id. */
static void encoder_backward_cpu(float*dtok,const float*dout,const int*ids,int BT,int C){
    for(int bt=0;bt<BT;bt++){
        float*dst=dtok+(size_t)ids[bt]*C;
        const float*gy=dout+(size_t)bt*C;
        for(int i=0;i<C;i++) dst[i]+=gy[i];
    }
}
/* ===== full-model forward: GPU orchestration vs CPU reference ===== */
/* Model hyper-parameters: V vocabulary size (rows of the token embedding), C model width,
   NH query heads, NKV key/value heads (GQA), NL layers, T sequence length, H SwiGLU hidden
   width, K number of MTP prediction heads (K targets and K*V logits per token). */
typedef struct { int V,C,NH,NKV,NL,T,H,K; } Cfg;
/* Width of one fused q|k|v row: C query dims plus NKV key heads and NKV value heads. */
static int qkvd(Cfg c){ return c.C + 2*c.NKV*(c.C/c.NH); }
/* Views into one flat parameter buffer (filled by wset); per-layer tensors stack NL deep. */
typedef struct { float *tok,*rms1g,*qkvw,*attprojw,*rms2g,*gatew,*upw,*downw,*rmsfg,*headw; } W;
/* Fill s[0..9] with the element count of each W tensor, in W field order; return the sum. */
static size_t wsizes(Cfg c,size_t*s){
    const int C=c.C,V=c.V,NL=c.NL,H=c.H,K=c.K,QKV=qkvd(c);
    s[0]=(size_t)V*C;       /* tok      [V][C]       */
    s[1]=(size_t)NL*C;      /* rms1g    [NL][C]      */
    s[2]=(size_t)NL*QKV*C;  /* qkvw     [NL][QKV][C] */
    s[3]=(size_t)NL*C*C;    /* attprojw [NL][C][C]   */
    s[4]=(size_t)NL*C;      /* rms2g    [NL][C]      */
    s[5]=(size_t)NL*H*C;    /* gatew    [NL][H][C]   */
    s[6]=(size_t)NL*H*C;    /* upw      [NL][H][C]   */
    s[7]=(size_t)NL*C*H;    /* downw    [NL][C][H]   */
    s[8]=(size_t)C;         /* rmsfg    [C]          */
    s[9]=(size_t)K*V*C;     /* headw    [K*V][C]     */
    size_t total=0;
    for(int k=0;k<10;++k) total+=s[k];
    return total;
}
/* Point each field of *w at consecutive slices of base, sized by wsizes (no copying). */
static void wset(W*w,float*base,Cfg c){
    size_t s[10]; wsizes(c,s);
    float**slot[10]={&w->tok,&w->rms1g,&w->qkvw,&w->attprojw,&w->rms2g,
                     &w->gatew,&w->upw,&w->downw,&w->rmsfg,&w->headw};
    float*p=base;
    for(int k=0;k<10;++k){ *slot[k]=p; p+=s[k]; }
}
/* CPU forward, returns mean loss; writes logits for comparison. Mirrors ../nanoeuler.c.
   ids: BT = B*T token ids; tgt: BT*K targets (K MTP heads per token; a negative target
   marks a masked/SFT position); logits_out: BT*K*V floats. All activation buffers are
   allocated here and freed before returning.
   The loss sums -log p[target] over unmasked (token, head) pairs and divides by BT*K, the
   same normalisation as the 1/(BT*K) gradient scale passed to softmax_ce_cpu. */
static float forward_cpu(Cfg c,W w,const int*ids,const int*tgt,int B,float*logits_out){
int C=c.C,T=c.T,NL=c.NL,NH=c.NH,NKV=c.NKV,V=c.V,H=c.H,K=c.K,QKV=qkvd(c),hs=C/NH;
size_t BT=(size_t)B*T,BTC=BT*C,ATT=(size_t)B*NH*T*T,BTH=BT*H;
float*enc=(float*)malloc(BTC*sizeof(float)),*rms=(float*)malloc(BTC*sizeof(float)),*rstd=(float*)malloc(BT*sizeof(float));
float*qkv=(float*)malloc(BT*QKV*sizeof(float)),*atty=(float*)malloc(BTC*sizeof(float));
float*pre=(float*)malloc(ATT*sizeof(float)),*att=(float*)malloc(ATT*sizeof(float));
float*aproj=(float*)malloc(BTC*sizeof(float)),*res2=(float*)malloc(BTC*sizeof(float)),*res3=(float*)malloc(BTC*sizeof(float));
float*gate=(float*)malloc(BTH*sizeof(float)),*up=(float*)malloc(BTH*sizeof(float)),*hsil=(float*)malloc(BTH*sizeof(float)),*mlp=(float*)malloc(BTC*sizeof(float));
float*probs=(float*)malloc(BT*K*V*sizeof(float)),*dl=(float*)malloc(BT*K*V*sizeof(float)); /* dl: scratch required by softmax_ce_cpu */
encoder_cpu(enc,ids,w.tok,(int)BT,C); float*res=enc;
/* per layer: RMSNorm -> QKV + RoPE -> attention -> proj -> residual,
   then RMSNorm -> SwiGLU MLP -> residual */
for(int l=0;l<NL;l++){
rmsnorm_cpu_fwd(rms,rstd,res,w.rms1g+l*C,B*T,C);
matmul_cpu(qkv,rms,w.qkvw+(size_t)l*QKV*C,B*T,C,QKV);
rope_cpu(qkv,0,QKV,NH,hs,B,T,0); rope_cpu(qkv,C,QKV,NKV,hs,B,T,0);
attention_cpu_fwd(atty,pre,att,qkv,B,T,C,NH,NKV);
matmul_cpu(aproj,atty,w.attprojw+(size_t)l*C*C,B*T,C,C);
for(size_t i=0;i<BTC;i++) res2[i]=res[i]+aproj[i];
rmsnorm_cpu_fwd(rms,rstd,res2,w.rms2g+l*C,B*T,C);
matmul_cpu(gate,rms,w.gatew+(size_t)l*H*C,B*T,C,H);
matmul_cpu(up,rms,w.upw+(size_t)l*H*C,B*T,C,H);
swiglu_cpu_fwd(hsil,gate,up,BTH);
matmul_cpu(mlp,hsil,w.downw+(size_t)l*C*H,B*T,H,C);
for(size_t i=0;i<BTC;i++) res3[i]=res2[i]+mlp[i];
res=res3;
}
rmsnorm_cpu_fwd(rms,rstd,res,w.rmsfg,B*T,C);
matmul_cpu(logits_out,rms,w.headw,B*T,C,K*V);
softmax_ce_cpu(probs,dl,logits_out,tgt,(int)BT,V,K,1.0f/((float)BT*K));
float loss=0;
for(size_t bt=0;bt<BT;bt++) for(int j=0;j<K;j++){
    int tt=tgt[bt*K+j];
    /* masked position: no loss term (its gradient is also zero). Using a negative target
       as an index into probs would read out of bounds. */
    if(tt<0) continue;
    float pr=probs[bt*K*V+(size_t)j*V+tt]; loss+=-logf(pr>1e-12f?pr:1e-12f);
}
loss/=(BT*K);
free(enc);free(rms);free(rstd);free(qkv);free(atty);free(pre);free(att);free(aproj);free(res2);free(res3);
free(gate);free(up);free(hsil);free(mlp);free(probs);free(dl);
return loss;
}
#define DM(p,bytes) do{ CHECK(cudaMalloc(&(p),(bytes))); }while(0)
/* GPU forward, returns mean loss; writes logits for comparison.
   Uploads hostparams (nparams floats, carved into W views by wset), ids (BT = B*T) and tgt
   (BT*K targets; negative = masked/SFT position), runs the same layer sequence as
   forward_cpu with CUDA kernels and cuBLAS GEMMs, copies the logits (BT*K*V floats) to
   logits_out, and computes the loss on the host from the downloaded probabilities. All
   device buffers are freed before returning.
   Loss = sum of -log p[target] over unmasked (token, head) pairs, divided by BT*K (the
   same normalisation as the 1/(BT*K) gradient scale given to softmax_ce_kernel). */
static float forward_gpu(Cfg c,const float*hostparams,size_t nparams,const int*ids,const int*tgt,int B,float*logits_out){
int C=c.C,T=c.T,NL=c.NL,NH=c.NH,NKV=c.NKV,V=c.V,H=c.H,K=c.K,QKV=qkvd(c),hs=C/NH;
size_t BT=(size_t)B*T,BTC=BT*C,ATT=(size_t)B*NH*T*T,BTH=BT*H,LG=BT*K*V;
float*dp; DM(dp,nparams*sizeof(float)); CHECK(cudaMemcpy(dp,hostparams,nparams*sizeof(float),cudaMemcpyHostToDevice));
W w; wset(&w,dp,c);
int *d_ids,*d_tgt; DM(d_ids,BT*sizeof(int)); DM(d_tgt,BT*K*sizeof(int));
CHECK(cudaMemcpy(d_ids,ids,BT*sizeof(int),cudaMemcpyHostToDevice));
CHECK(cudaMemcpy(d_tgt,tgt,BT*K*sizeof(int),cudaMemcpyHostToDevice));
float *enc,*rms,*rstd,*qkv,*atty,*pre,*att,*aproj,*res2,*res3,*gate,*up,*hsil,*mlp,*logits,*probs,*dl;
DM(enc,BTC*sizeof(float));DM(rms,BTC*sizeof(float));DM(rstd,BT*sizeof(float));DM(qkv,BT*QKV*sizeof(float));
DM(atty,BTC*sizeof(float));DM(pre,ATT*sizeof(float));DM(att,ATT*sizeof(float));DM(aproj,BTC*sizeof(float));
DM(res2,BTC*sizeof(float));DM(res3,BTC*sizeof(float));DM(gate,BTH*sizeof(float));DM(up,BTH*sizeof(float));
DM(hsil,BTH*sizeof(float));DM(mlp,BTC*sizeof(float));DM(logits,LG*sizeof(float));DM(probs,LG*sizeof(float));DM(dl,LG*sizeof(float));
int blk=128; size_t btc=BTC;
encoder_forward_kernel<<<((int)BT+blk-1)/blk,blk>>>(enc,d_ids,w.tok,(int)BT,C); CHECK(cudaGetLastError());
float*res=enc;
for(int l=0;l<NL;l++){
rmsnorm_forward_kernel<<<((int)BT+blk-1)/blk,blk>>>(rms,rstd,res,w.rms1g+l*C,(int)BT,C); CHECK(cudaGetLastError());
matmul_forward_cuda(qkv,rms,w.qkvw+(size_t)l*QKV*C,B*T,C,QKV);
int rg=(B*T*NH+blk-1)/blk; rope_kernel<<<rg,blk>>>(qkv,0,QKV,NH,hs,B,T,0); CHECK(cudaGetLastError());
int rgk=(B*T*NKV+blk-1)/blk; rope_kernel<<<rgk,blk>>>(qkv,C,QKV,NKV,hs,B,T,0); CHECK(cudaGetLastError());
attention_forward_kernel<<<(B*NH*T+blk-1)/blk,blk>>>(atty,pre,att,qkv,B,T,C,NH,NKV); CHECK(cudaGetLastError());
matmul_forward_cuda(aproj,atty,w.attprojw+(size_t)l*C*C,B*T,C,C);
residual_add_kernel<<<(int)((btc+blk-1)/blk),blk>>>(res2,res,aproj,btc); CHECK(cudaGetLastError());
rmsnorm_forward_kernel<<<((int)BT+blk-1)/blk,blk>>>(rms,rstd,res2,w.rms2g+l*C,(int)BT,C); CHECK(cudaGetLastError());
matmul_forward_cuda(gate,rms,w.gatew+(size_t)l*H*C,B*T,C,H);
matmul_forward_cuda(up,rms,w.upw+(size_t)l*H*C,B*T,C,H);
swiglu_forward_kernel<<<(int)((BTH+blk-1)/blk),blk>>>(hsil,gate,up,BTH); CHECK(cudaGetLastError());
matmul_forward_cuda(mlp,hsil,w.downw+(size_t)l*C*H,B*T,H,C);
residual_add_kernel<<<(int)((btc+blk-1)/blk),blk>>>(res3,res2,mlp,btc); CHECK(cudaGetLastError());
res=res3;
}
rmsnorm_forward_kernel<<<((int)BT+blk-1)/blk,blk>>>(rms,rstd,res,w.rmsfg,(int)BT,C); CHECK(cudaGetLastError());
matmul_forward_cuda(logits,rms,w.headw,B*T,C,K*V);
softmax_ce_kernel<<<((int)(BT*K)+blk-1)/blk,blk>>>(probs,dl,logits,d_tgt,(int)BT,V,K,1.0f/((float)BT*K)); CHECK(cudaGetLastError());
CHECK(cudaDeviceSynchronize());
CHECK(cudaMemcpy(logits_out,logits,LG*sizeof(float),cudaMemcpyDeviceToHost));
float*hp=(float*)malloc(LG*sizeof(float));
if(!hp){ fprintf(stderr,"forward_gpu: host allocation of %zu floats failed\n",LG); exit(1); }
CHECK(cudaMemcpy(hp,probs,LG*sizeof(float),cudaMemcpyDeviceToHost));
float loss=0;
for(size_t bt=0;bt<BT;bt++) for(int j=0;j<K;j++){
    int tt=tgt[bt*K+j];
    /* masked position: no loss term (softmax_ce_kernel zeroes its gradient). Using a
       negative target as an index into hp would read out of bounds. */
    if(tt<0) continue;
    float pr=hp[bt*K*V+(size_t)j*V+tt]; loss+=-logf(pr>1e-12f?pr:1e-12f);
}
loss/=(BT*K); free(hp);
cudaFree(dp);cudaFree(d_ids);cudaFree(d_tgt);cudaFree(enc);cudaFree(rms);cudaFree(rstd);cudaFree(qkv);cudaFree(atty);
cudaFree(pre);cudaFree(att);cudaFree(aproj);cudaFree(res2);cudaFree(res3);cudaFree(gate);cudaFree(up);cudaFree(hsil);
cudaFree(mlp);cudaFree(logits);cudaFree(probs);cudaFree(dl);
return loss;
}
/* Full-model forward consistency check. Builds a small random model (fixed seed 7), runs
   the CPU reference and the GPU path on identical weights, tokens and targets, then prints
   both losses and the max relative error of the logits. */
static void run_fwdcheck(void){
    srand(7);
    const Cfg c={64,64,4,2,2,16,128,2};
    const int B=4; /* small config for the check */
    size_t sizes[10];
    const size_t nparams=wsizes(c,sizes);
    /* random draws happen in a fixed order: weights, then token ids, then targets */
    float*params=(float*)malloc(nparams*sizeof(float));
    for(size_t k=0;k<nparams;++k) params[k]=0.05f*((float)rand()/RAND_MAX-0.5f);
    W w; wset(&w,params,c);
    const size_t BT=(size_t)B*c.T;
    const size_t ntgt=BT*c.K;
    int*ids=(int*)malloc(BT*sizeof(int));
    int*tgt=(int*)malloc(ntgt*sizeof(int));
    for(size_t k=0;k<BT;++k) ids[k]=rand()%c.V;
    for(size_t k=0;k<ntgt;++k) tgt[k]=rand()%c.V;
    const size_t LG=ntgt*c.V;
    float*lg_cpu=(float*)malloc(LG*sizeof(float));
    float*lg_gpu=(float*)malloc(LG*sizeof(float));
    const float lc=forward_cpu(c,w,ids,tgt,B,lg_cpu);
    const float lg=forward_gpu(c,params,nparams,ids,tgt,B,lg_gpu);
    printf("full-model forward check:\n");
    printf(" loss CPU=%.6f GPU=%.6f (diff %.2e)\n", lc, lg, dabs(lc-lg));
    check(" logits GPU vs CPU", lg_gpu, lg_cpu, LG);
    free(params); free(ids); free(tgt); free(lg_cpu); free(lg_gpu);
}
/* ===== full-model forward+backward: GPU gradients vs CPU reference ===== */
/* Gradient check: build a tiny random model and batch, run the CPU reference forward (saving
   per-layer activations) and backward, run the same computation with the GPU kernels, then
   compare the parameter gradients tensor by tensor and overall. All host and device buffers
   allocated here are released before returning.
   Note: the GR(n) launch-size macro defined below stays defined for the rest of the file; the
   training and inference code later in the file relies on it. */
static void run_gradcheck(void){
  srand(11); printf("[g] start\n"); fflush(stdout);
  Cfg c={64,64,4,2,2,16,128,2}; int B=4;
  int C=c.C,T=c.T,NL=c.NL,NH=c.NH,NKV=c.NKV,V=c.V,H=c.H,K=c.K,QKV=qkvd(c),hs=C/NH;
  size_t BT=(size_t)B*T,BTC=BT*C,ATT=(size_t)B*NH*T*T,BTH=BT*H,BTQ=BT*QKV,LG=BT*K*V;
  size_t sz[10]; size_t np=wsizes(c,sz);
  float*params=(float*)malloc(np*sizeof(float)); for(size_t i=0;i<np;i++) params[i]=0.05f*((float)rand()/RAND_MAX-0.5f);
  W w; wset(&w,params,c);
  int*ids=(int*)malloc(BT*sizeof(int)),*tgt=(int*)malloc(BT*K*sizeof(int));
  for(size_t i=0;i<BT;i++) ids[i]=rand()%V; for(size_t i=0;i<BT*K;i++) tgt[i]=rand()%V;
  float scale=1.0f/((float)BT*K); /* mean over all BT*K targets */
  /* ---------- CPU reference: forward (save acts) + backward ---------- */
  float*cg=(float*)calloc(np,sizeof(float)); W gc; wset(&gc,cg,c);
#define A(n) (float*)malloc((n)*sizeof(float))
  float*enc=A(BTC),*Lr1=A(NL*BTC),*Lr1s=A(NL*BT),*Lqkv=A(NL*BTQ),*Latty=A(NL*BTC),*Lpre=A(NL*ATT),*Latt=A(NL*ATT);
  float*Lap=A(NL*BTC),*Lr2=A(NL*BTC),*Lr2n=A(NL*BTC),*Lr2s=A(NL*BT),*Lg=A(NL*BTH),*Lu=A(NL*BTH),*Lh=A(NL*BTH),*Lm=A(NL*BTC),*Lr3=A(NL*BTC);
  float*rmsf=A(BTC),*rmsfs=A(BT),*logits=A(LG),*probs=A(LG),*dlog=A(LG);
  float*drms=A(BTC),*dres=A(BTC),*dHs=A(BTH),*dG=A(BTH),*dU=A(BTH),*datty=A(BTC),*dqkv=A(BTQ),*dpre=A(ATT),*datt=A(ATT);
  encoder_cpu(enc,ids,w.tok,(int)BT,C); float*res=enc;
  for(int l=0;l<NL;l++){
    float*r1=Lr1+l*BTC; rmsnorm_cpu_fwd(r1,Lr1s+l*BT,res,w.rms1g+l*C,(int)BT,C);
    float*qk=Lqkv+(size_t)l*BTQ; matmul_cpu(qk,r1,w.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
    rope_cpu(qk,0,QKV,NH,hs,B,T,0); rope_cpu(qk,C,QKV,NKV,hs,B,T,0);
    float*at=Latty+l*BTC; attention_cpu_fwd(at,Lpre+(size_t)l*ATT,Latt+(size_t)l*ATT,qk,B,T,C,NH,NKV);
    float*ap=Lap+l*BTC; matmul_cpu(ap,at,w.attprojw+(size_t)l*C*C,(int)BT,C,C);
    float*r2=Lr2+l*BTC; for(size_t i=0;i<BTC;i++) r2[i]=res[i]+ap[i];
    float*r2n=Lr2n+l*BTC; rmsnorm_cpu_fwd(r2n,Lr2s+l*BT,r2,w.rms2g+l*C,(int)BT,C);
    float*ga=Lg+l*BTH; matmul_cpu(ga,r2n,w.gatew+(size_t)l*H*C,(int)BT,C,H);
    float*up=Lu+l*BTH; matmul_cpu(up,r2n,w.upw+(size_t)l*H*C,(int)BT,C,H);
    float*hs2=Lh+l*BTH; swiglu_cpu_fwd(hs2,ga,up,BTH);
    float*ml=Lm+l*BTC; matmul_cpu(ml,hs2,w.downw+(size_t)l*C*H,(int)BT,H,C);
    float*r3=Lr3+l*BTC; for(size_t i=0;i<BTC;i++) r3[i]=r2[i]+ml[i];
    res=r3;
  }
  rmsnorm_cpu_fwd(rmsf,rmsfs,res,w.rmsfg,(int)BT,C);
  matmul_cpu(logits,rmsf,w.headw,(int)BT,C,K*V);
  softmax_ce_cpu(probs,dlog,logits,tgt,(int)BT,V,K,scale);
  printf("[g] cpu forward done\n"); fflush(stdout);
  memset(drms,0,BTC*sizeof(float)); matmul_backward_cpu(drms,gc.headw,dlog,rmsf,w.headw,(int)BT,C,K*V);
  memset(dres,0,BTC*sizeof(float)); rmsnorm_cpu_bwd(dres,gc.rmsfg,drms,Lr3+(size_t)(NL-1)*BTC,w.rmsfg,rmsfs,(int)BT,C);
  for(int l=NL-1;l>=0;l--){
    float*res_in=(l==0)?enc:Lr3+(size_t)(l-1)*BTC; /* residual stream entering layer l */
    memset(dHs,0,BTH*sizeof(float)); matmul_backward_cpu(dHs,gc.downw+(size_t)l*C*H,dres,Lh+l*BTH,w.downw+(size_t)l*C*H,(int)BT,H,C);
    swiglu_cpu_bwd(dG,dU,dHs,Lg+l*BTH,Lu+l*BTH,BTH);
    memset(drms,0,BTC*sizeof(float));
    matmul_backward_cpu(drms,gc.gatew+(size_t)l*H*C,dG,Lr2n+l*BTC,w.gatew+(size_t)l*H*C,(int)BT,C,H);
    matmul_backward_cpu(drms,gc.upw+(size_t)l*H*C,dU,Lr2n+l*BTC,w.upw+(size_t)l*H*C,(int)BT,C,H);
    rmsnorm_cpu_bwd(dres,gc.rms2g+l*C,drms,Lr2+l*BTC,w.rms2g+l*C,Lr2s+l*BT,(int)BT,C);
    memset(datty,0,BTC*sizeof(float)); matmul_backward_cpu(datty,gc.attprojw+(size_t)l*C*C,dres,Latty+l*BTC,w.attprojw+(size_t)l*C*C,(int)BT,C,C);
    memset(dqkv,0,BTQ*sizeof(float)); memset(dpre,0,ATT*sizeof(float)); memset(datt,0,ATT*sizeof(float));
    attention_cpu_bwd(dqkv,dpre,datt,datty,Lqkv+(size_t)l*BTQ,Latt+(size_t)l*ATT,B,T,C,NH,NKV);
    rope_cpu(dqkv,0,QKV,NH,hs,B,T,1); rope_cpu(dqkv,C,QKV,NKV,hs,B,T,1);
    memset(drms,0,BTC*sizeof(float)); matmul_backward_cpu(drms,gc.qkvw+(size_t)l*QKV*C,dqkv,Lr1+l*BTC,w.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
    rmsnorm_cpu_bwd(dres,gc.rms1g+l*C,drms,res_in,w.rms1g+l*C,Lr1s+l*BT,(int)BT,C);
  }
  encoder_backward_cpu(gc.tok,dres,ids,(int)BT,C);
  printf("[g] cpu reference done\n"); fflush(stdout);
  /* ---------- GPU: forward (save acts) + backward ---------- */
  float*dp; DM(dp,np*sizeof(float)); CHECK(cudaMemcpy(dp,params,np*sizeof(float),cudaMemcpyHostToDevice)); W wd; wset(&wd,dp,c);
  float*dgr; DM(dgr,np*sizeof(float)); CHECK(cudaMemset(dgr,0,np*sizeof(float))); W gd; wset(&gd,dgr,c);
  int *did,*dtg; DM(did,BT*sizeof(int)); DM(dtg,BT*K*sizeof(int));
  CHECK(cudaMemcpy(did,ids,BT*sizeof(int),cudaMemcpyHostToDevice)); CHECK(cudaMemcpy(dtg,tgt,BT*K*sizeof(int),cudaMemcpyHostToDevice));
  float *Genc,*GLr1,*GLr1s,*GLqkv,*GLatty,*GLap,*GLr2,*GLr2n,*GLr2s,*GLg,*GLu,*GLh,*GLm,*GLr3,*GLlse;
  float *Grmsf,*Grmsfs,*Glog,*Gprob,*Gdlog,*Gdrms,*Gdres,*GdHs,*GdG,*GdU,*Gdatty,*Gdqkv;
  DM(Genc,BTC*4);DM(GLr1,NL*BTC*4);DM(GLr1s,NL*BT*4);DM(GLqkv,NL*BTQ*4);DM(GLatty,NL*BTC*4);
  DM(GLap,NL*BTC*4);DM(GLr2,NL*BTC*4);DM(GLr2n,NL*BTC*4);DM(GLr2s,NL*BT*4);DM(GLg,NL*BTH*4);DM(GLu,NL*BTH*4);DM(GLh,NL*BTH*4);DM(GLm,NL*BTC*4);DM(GLr3,NL*BTC*4);DM(GLlse,(size_t)NL*B*NH*T*4);
  DM(Grmsf,BTC*4);DM(Grmsfs,BT*4);DM(Glog,LG*4);DM(Gprob,LG*4);DM(Gdlog,LG*4);DM(Gdrms,BTC*4);DM(Gdres,BTC*4);DM(GdHs,BTH*4);DM(GdG,BTH*4);DM(GdU,BTH*4);DM(Gdatty,BTC*4);DM(Gdqkv,BTQ*4);
  int blk=128; size_t btc=BTC,bth=BTH;
#define GR(n) (int)(((n)+blk-1)/blk)
  encoder_forward_kernel<<<GR(BT),blk>>>(Genc,did,wd.tok,(int)BT,C); CHECK(cudaGetLastError());
  float*gres=Genc;
  for(int l=0;l<NL;l++){
    float*r1=GLr1+(size_t)l*BTC; rmsnorm_forward_kernel<<<GR(BT),blk>>>(r1,GLr1s+(size_t)l*BT,gres,wd.rms1g+l*C,(int)BT,C);
    float*qk=GLqkv+(size_t)l*BTQ; matmul_forward_cuda(qk,r1,wd.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
    rope_kernel<<<GR((size_t)B*T*NH),blk>>>(qk,0,QKV,NH,hs,B,T,0); rope_kernel<<<GR((size_t)B*T*NKV),blk>>>(qk,C,QKV,NKV,hs,B,T,0);
    float*at=GLatty+(size_t)l*BTC; flash_attn_forward_kernel<<<dim3(B*NH,(T+FA_BR-1)/FA_BR),FA_BR>>>(at,GLlse+(size_t)l*B*NH*T,qk,B,T,C,NH,NKV);
    float*ap=GLap+(size_t)l*BTC; matmul_forward_cuda(ap,at,wd.attprojw+(size_t)l*C*C,(int)BT,C,C);
    float*r2=GLr2+(size_t)l*BTC; residual_add_kernel<<<GR(btc),blk>>>(r2,gres,ap,btc);
    float*r2n=GLr2n+(size_t)l*BTC; rmsnorm_forward_kernel<<<GR(BT),blk>>>(r2n,GLr2s+(size_t)l*BT,r2,wd.rms2g+l*C,(int)BT,C);
    float*ga=GLg+(size_t)l*BTH; matmul_forward_cuda(ga,r2n,wd.gatew+(size_t)l*H*C,(int)BT,C,H);
    float*up=GLu+(size_t)l*BTH; matmul_forward_cuda(up,r2n,wd.upw+(size_t)l*H*C,(int)BT,C,H);
    float*hs2=GLh+(size_t)l*BTH; swiglu_forward_kernel<<<GR(bth),blk>>>(hs2,ga,up,bth);
    float*ml=GLm+(size_t)l*BTC; matmul_forward_cuda(ml,hs2,wd.downw+(size_t)l*C*H,(int)BT,H,C);
    float*r3=GLr3+(size_t)l*BTC; residual_add_kernel<<<GR(btc),blk>>>(r3,r2,ml,btc);
    gres=r3;
  }
  rmsnorm_forward_kernel<<<GR(BT),blk>>>(Grmsf,Grmsfs,gres,wd.rmsfg,(int)BT,C);
  matmul_forward_cuda(Glog,Grmsf,wd.headw,(int)BT,C,K*V);
  softmax_ce_kernel<<<GR(BT*K),blk>>>(Gprob,Gdlog,Glog,dtg,(int)BT,V,K,scale); CHECK(cudaGetLastError());
  CHECK(cudaMemset(Gdrms,0,BTC*4)); matmul_backward_dinp_cuda(Gdrms,Gdlog,wd.headw,(int)BT,C,K*V);
  matmul_backward_dw_cuda(gd.headw,Gdlog,Grmsf,(int)BT,C,K*V);
  CHECK(cudaMemset(Gdres,0,BTC*4)); rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rmsfg,Gdrms,GLr3+(size_t)(NL-1)*BTC,wd.rmsfg,Grmsfs,(int)BT,C);
  for(int l=NL-1;l>=0;l--){
    float*res_in=(l==0)?Genc:GLr3+(size_t)(l-1)*BTC; /* residual stream entering layer l */
    CHECK(cudaMemset(GdHs,0,BTH*4)); matmul_backward_dinp_cuda(GdHs,Gdres,wd.downw+(size_t)l*C*H,(int)BT,H,C);
    matmul_backward_dw_cuda(gd.downw+(size_t)l*C*H,Gdres,GLh+(size_t)l*BTH,(int)BT,H,C);
    swiglu_backward_kernel<<<GR(bth),blk>>>(GdG,GdU,GdHs,GLg+(size_t)l*BTH,GLu+(size_t)l*BTH,bth);
    CHECK(cudaMemset(Gdrms,0,BTC*4));
    matmul_backward_dinp_cuda(Gdrms,GdG,wd.gatew+(size_t)l*H*C,(int)BT,C,H);
    matmul_backward_dw_cuda(gd.gatew+(size_t)l*H*C,GdG,GLr2n+(size_t)l*BTC,(int)BT,C,H);
    matmul_backward_dinp_cuda(Gdrms,GdU,wd.upw+(size_t)l*H*C,(int)BT,C,H);
    matmul_backward_dw_cuda(gd.upw+(size_t)l*H*C,GdU,GLr2n+(size_t)l*BTC,(int)BT,C,H);
    rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rms2g+l*C,Gdrms,GLr2+(size_t)l*BTC,wd.rms2g+l*C,GLr2s+(size_t)l*BT,(int)BT,C);
    CHECK(cudaMemset(Gdatty,0,BTC*4)); matmul_backward_dinp_cuda(Gdatty,Gdres,wd.attprojw+(size_t)l*C*C,(int)BT,C,C);
    matmul_backward_dw_cuda(gd.attprojw+(size_t)l*C*C,Gdres,GLatty+(size_t)l*BTC,(int)BT,C,C);
    CHECK(cudaMemset(Gdqkv,0,BTQ*4));
    flash_attn_backward_kernel<<<dim3(B*NH,(T+FA_BR-1)/FA_BR),FA_BR>>>(Gdqkv,GLqkv+(size_t)l*BTQ,GLatty+(size_t)l*BTC,Gdatty,GLlse+(size_t)l*B*NH*T,B,T,C,NH,NKV);
    rope_kernel<<<GR((size_t)B*T*NH),blk>>>(Gdqkv,0,QKV,NH,hs,B,T,1); rope_kernel<<<GR((size_t)B*T*NKV),blk>>>(Gdqkv,C,QKV,NKV,hs,B,T,1);
    CHECK(cudaMemset(Gdrms,0,BTC*4)); matmul_backward_dinp_cuda(Gdrms,Gdqkv,wd.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
    matmul_backward_dw_cuda(gd.qkvw+(size_t)l*QKV*C,Gdqkv,GLr1+(size_t)l*BTC,(int)BT,C,QKV);
    rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rms1g+l*C,Gdrms,res_in,wd.rms1g+l*C,GLr1s+(size_t)l*BT,(int)BT,C);
  }
  encoder_backward_kernel<<<GR(btc),blk>>>(gd.tok,Gdres,did,(int)BT,C); CHECK(cudaGetLastError());
  CHECK(cudaDeviceSynchronize());
  float*ggpu=(float*)malloc(np*sizeof(float)); CHECK(cudaMemcpy(ggpu,dgr,np*sizeof(float),cudaMemcpyDeviceToHost));
  printf("[g] gpu forward+backward done\n"); fflush(stdout);
  printf("full-model gradient check (GPU vs CPU):\n");
  const char*nm[10]={"tok","rms1g","qkvw","attprojw","rms2g","gatew","upw","downw","rmsfg","headw"};
  size_t off=0; for(int t=0;t<10;t++){ check(nm[t], ggpu+off, cg+off, sz[t]); off+=sz[t]; }
  check(" ALL params", ggpu, cg, np);
  /* release everything: inputs/gradients, the saved CPU activations, and all device buffers */
  free(params);free(ids);free(tgt);free(cg);free(ggpu);
  float*hbufs[]={enc,Lr1,Lr1s,Lqkv,Latty,Lpre,Latt,Lap,Lr2,Lr2n,Lr2s,Lg,Lu,Lh,Lm,Lr3,
                 rmsf,rmsfs,logits,probs,dlog,drms,dres,dHs,dG,dU,datty,dqkv,dpre,datt};
  for(size_t i=0;i<sizeof(hbufs)/sizeof(hbufs[0]);i++) free(hbufs[i]);
  float*dbufs[]={dp,dgr,Genc,GLr1,GLr1s,GLqkv,GLatty,GLap,GLr2,GLr2n,GLr2s,GLg,GLu,GLh,GLm,GLr3,GLlse,
                 Grmsf,Grmsfs,Glog,Gprob,Gdlog,Gdrms,Gdres,GdHs,GdG,GdU,Gdatty,Gdqkv};
  for(size_t i=0;i<sizeof(dbufs)/sizeof(dbufs[0]);i++){ CHECK(cudaFree(dbufs[i])); }
  CHECK(cudaFree(did)); CHECK(cudaFree(dtg));
}
/* ===== byte-level BPE (same scheme as ../nanoeuler.c) + data ===== */
/* vocabulary = 256 raw byte ids plus at most BPE_MERGES learned merge ids, all < VOCAB_MAX */
#define VOCAB_MAX 4096
#define BPE_MERGES (VOCAB_MAX-256)
#define BPE_SAMPLE (32L*1024*1024) /* learn merges on the first ~32MB; merges generalize */
/* longest pretokenized chunk: keeps per-word encoding (O(len^2)) bounded, and sizes the
   int w[BPE_MAXWORD] buffers that the encoders copy each chunk into */
#define BPE_MAXWORD 4096
/* tokenized corpus (filled by load_bpe) and the vocabulary size (256+n_merges after training,
   or read back from a checkpoint) */
static int *data_ids=0; static long data_n=0; static int VOCAB=0;
/* merge table: merge m fuses the pair (bpe_a[m], bpe_b[m]) into token id 256+m */
static int bpe_a[VOCAB_MAX],bpe_b[VOCAB_MAX],n_merges=0;
static int *bpe_id=0; /* dense pair lookup: bpe_id[a*VOCAB_MAX+b] = merged id, or 0 if none (VOCAB_MAX*VOCAB_MAX ints, built by bpe_build_index) */
/* whitespace bytes that delimit pretokenized chunks */
#define IS_SP(c) ((c)==' '||(c)=='\n'||(c)=='\t'||(c)=='\r')
/* Build the pair->id index from the merge table (call after merges are known/loaded).
   Afterwards bpe_id[a*VOCAB_MAX+b] holds 256+m for merge m = (bpe_a[m], bpe_b[m]) and 0 for
   pairs with no merge. The VOCAB_MAX*VOCAB_MAX table is allocated on first use and reused.
   Exits if the table cannot be allocated. Merge entries that would index outside the table
   (e.g. from a corrupt checkpoint) are ignored instead of corrupting memory. */
static void bpe_build_index(void){
  if(!bpe_id){
    bpe_id=(int*)malloc((size_t)VOCAB_MAX*VOCAB_MAX*sizeof(int));
    if(!bpe_id){ fprintf(stderr,"[bpe] out of memory for the pair index\n"); exit(1); }
  }
  memset(bpe_id,0,(size_t)VOCAB_MAX*VOCAB_MAX*sizeof(int));
  for(int m=0;m<n_merges&&m<BPE_MERGES;m++){
    const int a=bpe_a[m],b=bpe_b[m];
    if(a<0||a>=VOCAB_MAX||b<0||b>=VOCAB_MAX) continue; /* out-of-range pair: skip */
    bpe_id[(size_t)a*VOCAB_MAX+b]=256+m;
  }
}
/* Encode one pretokenized word in place. Each round picks the applicable adjacent pair with the
   lowest merged id (the earliest-learned merge), rewrites every left-to-right non-overlapping
   occurrence of that pair, and repeats until no adjacent pair has a merge. Returns the new
   length of w. */
static int bpe_encode_word(int*w,int len){
  int n=len;
  while(1){
    int pick=-1,pick_id=0x7fffffff;
    for(int i=0;i+1<n;i++){
      const int id=bpe_id[(size_t)w[i]*VOCAB_MAX+w[i+1]];
      if(id!=0&&id<pick_id){ pick_id=id; pick=i; }
    }
    if(pick<0) return n; /* nothing left to merge */
    const int left=w[pick],right=w[pick+1];
    int out=0,in=0;
    while(in<n){
      if(in+1<n&&w[in]==left&&w[in+1]==right){ w[out++]=pick_id; in+=2; }
      else w[out++]=w[in++];
    }
    n=out;
  }
}
/* GPT-2-style pretokenization: a chunk attaches a single leading space to the following word,
   so spaces don't become standalone tokens. Returns the (exclusive) end index of the chunk
   starting at i, or i itself when i>=n; for i<n the chunk is never empty.
   Every chunk is at most BPE_MAXWORD bytes: callers copy a chunk into an int[BPE_MAXWORD]
   buffer, so whitespace runs are capped exactly like words (a longer run of blank lines or
   padding is split into several chunks instead of overflowing that buffer). */
static long chunk_end(const unsigned char*b,long n,long i){
  if(i>=n) return i;
  if(IS_SP(b[i])){
    long j=i;
    while(j<n&&IS_SP(b[j])&&j-i<BPE_MAXWORD) j++;
    if(j>i+1 && j<n && !IS_SP(b[j])) return j-1; /* >=2 spaces before a word: keep all but the last */
    if(j<n && !IS_SP(b[j])){ long k=j; while(k<n&&!IS_SP(b[k])&&k-i<BPE_MAXWORD) k++; return k; } /* " word" */
    return j; /* trailing/standalone whitespace, or a capped whitespace run */
  }
  long k=i; while(k<n&&!IS_SP(b[k])&&k-i<BPE_MAXWORD) k++; return k;
}
/* Pretokenize n bytes with chunk_end, BPE-encode every chunk independently, and append the ids
   to out. A chunk never encodes to more ids than it has bytes, so out needs room for n ids.
   Returns the number of ids written. */
static long bpe_encode_bytes(const unsigned char*b,long n,int*out){
  int word[BPE_MAXWORD];
  long written=0;
  for(long pos=0;pos<n;){
    const long end=chunk_end(b,n,pos);
    int wl=0;
    for(long k=pos;k<end;k++) word[wl++]=b[k];
    wl=bpe_encode_word(word,wl);
    for(int k=0;k<wl;k++) out[written++]=word[k];
    pos=end;
  }
  return written;
}
/* Learn byte-level BPE merges on the first BPE_SAMPLE bytes of the file at path, then tokenize
   the whole file with them. Sets bpe_a/bpe_b/n_merges, VOCAB (=256+n_merges), the pair index
   (bpe_build_index), and data_ids/data_n. Exits on any open/seek/read/allocation failure. */
static void load_bpe(const char*path){
  FILE*f=fopen(path,"rb"); if(!f){fprintf(stderr,"cannot open %s\n",path);exit(1);}
  long n=-1;
  if(fseek(f,0,SEEK_END)==0) n=ftell(f);
  if(n<0||fseek(f,0,SEEK_SET)!=0){fprintf(stderr,"cannot determine size of %s\n",path);exit(1);}
  unsigned char*buf=(unsigned char*)malloc(n>0?(size_t)n:1);
  if(!buf){fprintf(stderr,"out of memory reading %s\n",path);exit(1);}
  if(fread(buf,1,n,f)!=(size_t)n){fprintf(stderr,"read error\n");exit(1);} fclose(f);
  /* --- learn merges on a capped sample (dense pair counts, max tracked inline) --- */
  long sn=n<BPE_SAMPLE?n:BPE_SAMPLE;
  /* build the sample with -1 barriers between chunks (same split as the encoder, via chunk_end);
     merges never cross a barrier, so trainer and encoder agree. Each chunk has >=1 byte and
     adds at most one barrier, so 2*sn entries always suffice. */
  int*seq=(int*)malloc((2*sn+8)*sizeof(int)); long len=0;
  if(!seq){fprintf(stderr,"[bpe] out of memory\n");exit(1);}
  for(long i=0;i<sn;){ long e=chunk_end(buf,sn,i); for(long k=i;k<e;k++) seq[len++]=buf[k]; if(e<sn) seq[len++]=-1; i=e; }
  int*cnt=(int*)malloc((size_t)VOCAB_MAX*VOCAB_MAX*sizeof(int)); n_merges=0;
  if(!cnt){fprintf(stderr,"[bpe] out of memory\n");exit(1);}
  printf("[bpe] training merges on %ldMB sample...\n",sn/(1024*1024)); fflush(stdout);
  while(n_merges<BPE_MERGES){
    int Vc=256+n_merges;
    memset(cnt,0,(size_t)Vc*VOCAB_MAX*sizeof(int)); /* only rows for ids < Vc can be non-zero */
    if((n_merges&511)==0&&n_merges){ printf("[bpe] %d/%d merges\n",n_merges,BPE_MERGES); fflush(stdout); }
    long best=1; int ba=0,bb=0;
    for(long i=0;i+1<len;i++){ if(seq[i]<0||seq[i+1]<0) continue; long c=++cnt[(size_t)seq[i]*VOCAB_MAX+seq[i+1]]; if(c>best){best=c;ba=seq[i];bb=seq[i+1];} }
    if(best<=1) break; /* stop once no pair occurs more than once */
    int id=256+n_merges; bpe_a[n_merges]=ba;bpe_b[n_merges]=bb;n_merges++;
    long j=0; for(long i=0;i<len;){ if(i+1<len&&seq[i]==ba&&seq[i+1]==bb){seq[j++]=id;i+=2;}else seq[j++]=seq[i++];} len=j;
  }
  free(cnt); free(seq); VOCAB=256+n_merges; bpe_build_index();
  /* --- tokenize the FULL corpus with the learned merges (per-word greedy, grows as needed) --- */
  long cap=n/3+16; int*ids=(int*)malloc(cap*sizeof(int)); long m=0; int w[BPE_MAXWORD]; long i=0;
  if(!ids){fprintf(stderr,"[data] out of memory\n");exit(1);}
  while(i<n){
    long e=chunk_end(buf,n,i); int wl=(int)(e-i); for(int k=0;k<wl;k++) w[k]=buf[i+k];
    wl=bpe_encode_word(w,wl);
    if(m+wl>cap){
      long ncap=cap*2+wl;
      int*grown=(int*)realloc(ids,ncap*sizeof(int)); /* keep ids valid if realloc fails */
      if(!grown){fprintf(stderr,"[data] out of memory\n");exit(1);}
      ids=grown; cap=ncap;
    }
    for(int k=0;k<wl;k++) ids[m++]=w[k]; i=e;
  }
  free(buf); data_ids=ids; data_n=m;
  printf("[data] %ld bytes -> %ld tokens | vocab=%d\n",n,m,VOCAB);
}
/* Sample B random windows of the tokenized corpus: inp[b*T+t] = data_ids[st+t] and
   tgt[(b*T+t)*K+j] = data_ids[st+t+1+j] (the next K tokens, one per MTP head), with the start
   st drawn from [0, data_n-T-K-1). Exits if the corpus is too short for a single window. */
static void get_batch(int*inp,int*tgt,int B,int T,int K){
  const long span=data_n-T-K-1;
  if(span<=0){ fprintf(stderr,"[data] corpus too small: %ld tokens for T=%d K=%d\n",data_n,T,K); exit(1); }
  for(int b=0;b<B;b++){
    /* rand() yields only RAND_MAX+1 distinct values (32768 on some platforms). Combine draws
       when that cannot reach every start offset; when it can, this is a single rand() call. */
    unsigned long long r=(unsigned long long)rand(),lim=(unsigned long long)RAND_MAX;
    for(int d=0;d<3&&lim+1<(unsigned long long)span;d++){
      r=r*((unsigned long long)RAND_MAX+1)+(unsigned long long)rand();
      lim=lim*((unsigned long long)RAND_MAX+1)+(unsigned long long)RAND_MAX;
    }
    const long st=(long)(r%(unsigned long long)span);
    for(int t=0;t<T;t++){
      inp[b*T+t]=data_ids[st+t];
      for(int j=0;j<K;j++) tgt[((size_t)b*T+t)*K+j]=data_ids[st+t+1+j];
    }
  }
}
/* One standard-normal sample via the Box-Muller transform, consuming two rand() draws.
   u1 is kept strictly positive so logf stays finite; u2 selects the angle. */
static float randn(){
  const float u1=(rand()+1.0f)/(RAND_MAX+2.0f);
  const float u2=(float)rand()/RAND_MAX;
  const float radius=sqrtf(-2.0f*logf(u1));
  return radius*cosf(6.2831853f*u2);
}
/* ===== GPU training (mode "t"). Reuses the validated forward+backward orchestration. ===== */
/* Write a checkpoint in the CPU program's format: Config, VOCAB, n_merges, bpe_a, bpe_b, params.
   Returns 1 on success. On an open/write/close failure it prints a message and returns 0
   instead of crashing, so a long run survives an unwritable checkpoint path. */
static int save_train_ckpt(const char*path,const Cfg*c,const float*params,size_t np){
  FILE*f=fopen(path,"wb");
  if(!f){ fprintf(stderr,"[ckpt] cannot open %s for writing\n",path); return 0; }
  int ok= fwrite(c,sizeof(Cfg),1,f)==1
       && fwrite(&VOCAB,sizeof(int),1,f)==1
       && fwrite(&n_merges,sizeof(int),1,f)==1
       && fwrite(bpe_a,sizeof(int),n_merges,f)==(size_t)n_merges
       && fwrite(bpe_b,sizeof(int),n_merges,f)==(size_t)n_merges
       && fwrite(params,sizeof(float),np,f)==np;
  if(fclose(f)!=0) ok=0;
  if(!ok) fprintf(stderr,"[ckpt] failed to write %s\n",path);
  return ok;
}
/* Pretrain on ../data/pretrain.txt. Learn BPE, initialize (or, if resume!=0, load) the weights,
   then run AdamW with linear warmup and cosine decay. Logs the loss at step 1 and every 50
   steps, and checkpoints to ../nanoeuler.bin every 5000 steps and at the end. On resume only
   the weights are restored: the Adam moments start at zero and the step counter restarts at 1. */
static void run_train(int resume){
  srand(1337); load_bpe("../data/pretrain.txt"); /* cat data/gutenberg.txt data/web.txt > data/pretrain.txt */
  int Cc=768; Cfg c={VOCAB,Cc,12,4,16,512,(8*Cc)/3,4}; int B=16,steps=300000,warm=2000; float lr0=6e-4f;
  int C=c.C,T=c.T,NL=c.NL,NH=c.NH,NKV=c.NKV,V=c.V,H=c.H,K=c.K,QKV=qkvd(c),hs=C/NH;
  size_t BT=(size_t)B*T,BTC=BT*C,BTH=BT*H,BTQ=BT*QKV,LG=BT*K*V;
  size_t sz[10]; size_t np=wsizes(c,sz);
  printf("[model:gpu] C=%d qh=%d kvh=%d L=%d T=%d H=%d MTP=%d | %.2fM params\n",C,NH,NKV,NL,T,H,K,np/1e6);
  printf("[sched] 1 epoch = %ld steps (%d tok/step); steps=%d -> %.1f epochs\n",
         data_n/((long)B*T), B*T, steps, (double)steps*B*T/(double)data_n);
  /* host init (like ../nanoeuler.c model_init) */
  float*params=(float*)malloc(np*sizeof(float)); W w; wset(&w,params,c);
  int resumed=0;
  if(resume){ FILE*rf=fopen("../nanoeuler.bin","rb");
    if(rf){ Cfg rc; int rv=0,rm=0;
      /* every field that shapes the parameter tensors must match (T does not: positions come
         from RoPE and no tensor depends on T), and the stored merge count must be sane */
      if(fread(&rc,sizeof(Cfg),1,rf)==1 && fread(&rv,sizeof(int),1,rf)==1 && fread(&rm,sizeof(int),1,rf)==1
         && rc.C==c.C && rc.NL==c.NL && rc.V==c.V && rc.NH==c.NH && rc.NKV==c.NKV && rc.H==c.H && rc.K==c.K
         && rv==VOCAB && rm>=0 && rm<=BPE_MERGES){
        fseek(rf,(long)2*rm*(long)sizeof(int),SEEK_CUR); /* skip the stored bpe arrays */
        if(fread(params,sizeof(float),np,rf)==np){ resumed=1; printf("[resume] continuing from ../nanoeuler.bin\n"); } }
      fclose(rf); }
    if(!resumed) fprintf(stderr,"[resume] no matching checkpoint; starting fresh\n"); }
  if(!resumed){ float rs=0.02f/sqrtf(2.0f*NL); /* residual-output projections get a depth-scaled std */
    for(size_t i=0;i<sz[0];i++) w.tok[i]=0.02f*randn();
    for(size_t i=0;i<sz[2];i++) w.qkvw[i]=0.02f*randn();
    for(size_t i=0;i<sz[3];i++) w.attprojw[i]=rs*randn();
    for(size_t i=0;i<sz[5];i++) w.gatew[i]=0.02f*randn();
    for(size_t i=0;i<sz[6];i++) w.upw[i]=0.02f*randn();
    for(size_t i=0;i<sz[7];i++) w.downw[i]=rs*randn();
    for(size_t i=0;i<sz[9];i++) w.headw[i]=0.02f*randn();
    for(size_t i=0;i<sz[1];i++) w.rms1g[i]=1.0f;
    for(size_t i=0;i<sz[4];i++) w.rms2g[i]=1.0f;
    for(int i=0;i<C;i++) w.rmsfg[i]=1.0f; }
  /* device params/grads/adam-state */
  float*dp;DM(dp,np*4);CHECK(cudaMemcpy(dp,params,np*4,cudaMemcpyHostToDevice)); W wd;wset(&wd,dp,c);
  float*dgr;DM(dgr,np*4); float*dm;DM(dm,np*4);CHECK(cudaMemset(dm,0,np*4)); float*dv;DM(dv,np*4);CHECK(cudaMemset(dv,0,np*4));
  W gd;wset(&gd,dgr,c);
  int *did,*dtg;DM(did,BT*sizeof(int));DM(dtg,BT*K*sizeof(int));
  int *hid=(int*)malloc(BT*sizeof(int)),*htg=(int*)malloc(BT*K*sizeof(int));
  /* device activations (allocated once) */
  float *Genc,*GLr1,*GLr1s,*GLqkv,*GLatty,*GLap,*GLr2,*GLr2n,*GLr2s,*GLg,*GLu,*GLh,*GLm,*GLr3,*GLlse;
  float *Grmsf,*Grmsfs,*Glog,*Gprob,*Gdlog,*Gdrms,*Gdres,*GdHs,*GdG,*GdU,*Gdatty,*Gdqkv;
  DM(Genc,BTC*4);DM(GLr1,NL*BTC*4);DM(GLr1s,NL*BT*4);DM(GLqkv,NL*BTQ*4);DM(GLatty,NL*BTC*4);
  DM(GLap,NL*BTC*4);DM(GLr2,NL*BTC*4);DM(GLr2n,NL*BTC*4);DM(GLr2s,NL*BT*4);DM(GLg,NL*BTH*4);DM(GLu,NL*BTH*4);DM(GLh,NL*BTH*4);DM(GLm,NL*BTC*4);DM(GLr3,NL*BTC*4);DM(GLlse,(size_t)NL*B*NH*T*4);
  DM(Grmsf,BTC*4);DM(Grmsfs,BT*4);DM(Glog,LG*4);DM(Gprob,LG*4);DM(Gdlog,LG*4);DM(Gdrms,BTC*4);DM(Gdres,BTC*4);DM(GdHs,BTH*4);DM(GdG,BTH*4);DM(GdU,BTH*4);DM(Gdatty,BTC*4);DM(Gdqkv,BTQ*4);
  int blk=128; size_t btc=BTC,bth=BTH; float scale=1.0f/((float)BT*K);
  float*hprob=(float*)malloc(LG*sizeof(float));
  clock_t t0=clock();
  for(int step=1;step<=steps;step++){
    float lr=step<warm?lr0*step/warm:lr0*0.5f*(1+cosf(3.14159265f*(step-warm)/(steps-warm)));
    get_batch(hid,htg,B,T,K);
    CHECK(cudaMemcpy(did,hid,BT*sizeof(int),cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(dtg,htg,BT*K*sizeof(int),cudaMemcpyHostToDevice));
    CHECK(cudaMemset(dgr,0,np*4));
    /* ---- forward ---- */
    encoder_forward_kernel<<<GR(BT),blk>>>(Genc,did,wd.tok,(int)BT,C);
    float*gres=Genc;
    for(int l=0;l<NL;l++){
      float*r1=GLr1+(size_t)l*BTC; rmsnorm_forward_kernel<<<GR(BT),blk>>>(r1,GLr1s+(size_t)l*BT,gres,wd.rms1g+l*C,(int)BT,C);
      float*qk=GLqkv+(size_t)l*BTQ; matmul_forward_cuda(qk,r1,wd.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
      rope_kernel<<<GR((size_t)B*T*NH),blk>>>(qk,0,QKV,NH,hs,B,T,0); rope_kernel<<<GR((size_t)B*T*NKV),blk>>>(qk,C,QKV,NKV,hs,B,T,0);
      float*at=GLatty+(size_t)l*BTC; flash_attn_forward_kernel<<<dim3(B*NH,(T+FA_BR-1)/FA_BR),FA_BR>>>(at,GLlse+(size_t)l*B*NH*T,qk,B,T,C,NH,NKV);
      float*ap=GLap+(size_t)l*BTC; matmul_forward_cuda(ap,at,wd.attprojw+(size_t)l*C*C,(int)BT,C,C);
      float*r2=GLr2+(size_t)l*BTC; residual_add_kernel<<<GR(btc),blk>>>(r2,gres,ap,btc);
      float*r2n=GLr2n+(size_t)l*BTC; rmsnorm_forward_kernel<<<GR(BT),blk>>>(r2n,GLr2s+(size_t)l*BT,r2,wd.rms2g+l*C,(int)BT,C);
      float*ga=GLg+(size_t)l*BTH; matmul_forward_cuda(ga,r2n,wd.gatew+(size_t)l*H*C,(int)BT,C,H);
      float*up=GLu+(size_t)l*BTH; matmul_forward_cuda(up,r2n,wd.upw+(size_t)l*H*C,(int)BT,C,H);
      float*hs2=GLh+(size_t)l*BTH; swiglu_forward_kernel<<<GR(bth),blk>>>(hs2,ga,up,bth);
      float*ml=GLm+(size_t)l*BTC; matmul_forward_cuda(ml,hs2,wd.downw+(size_t)l*C*H,(int)BT,H,C);
      float*r3=GLr3+(size_t)l*BTC; residual_add_kernel<<<GR(btc),blk>>>(r3,r2,ml,btc);
      gres=r3;
    }
    rmsnorm_forward_kernel<<<GR(BT),blk>>>(Grmsf,Grmsfs,gres,wd.rmsfg,(int)BT,C);
    matmul_forward_cuda(Glog,Grmsf,wd.headw,(int)BT,C,K*V);
    softmax_ce_kernel<<<GR(BT*K),blk>>>(Gprob,Gdlog,Glog,dtg,(int)BT,V,K,scale);
    /* ---- backward ---- */
    CHECK(cudaMemset(Gdrms,0,BTC*4)); matmul_backward_dinp_cuda(Gdrms,Gdlog,wd.headw,(int)BT,C,K*V);
    matmul_backward_dw_cuda(gd.headw,Gdlog,Grmsf,(int)BT,C,K*V);
    CHECK(cudaMemset(Gdres,0,BTC*4)); rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rmsfg,Gdrms,GLr3+(size_t)(NL-1)*BTC,wd.rmsfg,Grmsfs,(int)BT,C);
    for(int l=NL-1;l>=0;l--){
      float*res_in=(l==0)?Genc:GLr3+(size_t)(l-1)*BTC;
      CHECK(cudaMemset(GdHs,0,BTH*4)); matmul_backward_dinp_cuda(GdHs,Gdres,wd.downw+(size_t)l*C*H,(int)BT,H,C);
      matmul_backward_dw_cuda(gd.downw+(size_t)l*C*H,Gdres,GLh+(size_t)l*BTH,(int)BT,H,C);
      swiglu_backward_kernel<<<GR(bth),blk>>>(GdG,GdU,GdHs,GLg+(size_t)l*BTH,GLu+(size_t)l*BTH,bth);
      CHECK(cudaMemset(Gdrms,0,BTC*4));
      matmul_backward_dinp_cuda(Gdrms,GdG,wd.gatew+(size_t)l*H*C,(int)BT,C,H);
      matmul_backward_dw_cuda(gd.gatew+(size_t)l*H*C,GdG,GLr2n+(size_t)l*BTC,(int)BT,C,H);
      matmul_backward_dinp_cuda(Gdrms,GdU,wd.upw+(size_t)l*H*C,(int)BT,C,H);
      matmul_backward_dw_cuda(gd.upw+(size_t)l*H*C,GdU,GLr2n+(size_t)l*BTC,(int)BT,C,H);
      rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rms2g+l*C,Gdrms,GLr2+(size_t)l*BTC,wd.rms2g+l*C,GLr2s+(size_t)l*BT,(int)BT,C);
      CHECK(cudaMemset(Gdatty,0,BTC*4)); matmul_backward_dinp_cuda(Gdatty,Gdres,wd.attprojw+(size_t)l*C*C,(int)BT,C,C);
      matmul_backward_dw_cuda(gd.attprojw+(size_t)l*C*C,Gdres,GLatty+(size_t)l*BTC,(int)BT,C,C);
      CHECK(cudaMemset(Gdqkv,0,BTQ*4));
      flash_attn_backward_kernel<<<dim3(B*NH,(T+FA_BR-1)/FA_BR),FA_BR>>>(Gdqkv,GLqkv+(size_t)l*BTQ,GLatty+(size_t)l*BTC,Gdatty,GLlse+(size_t)l*B*NH*T,B,T,C,NH,NKV);
      rope_kernel<<<GR((size_t)B*T*NH),blk>>>(Gdqkv,0,QKV,NH,hs,B,T,1); rope_kernel<<<GR((size_t)B*T*NKV),blk>>>(Gdqkv,C,QKV,NKV,hs,B,T,1);
      CHECK(cudaMemset(Gdrms,0,BTC*4)); matmul_backward_dinp_cuda(Gdrms,Gdqkv,wd.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
      matmul_backward_dw_cuda(gd.qkvw+(size_t)l*QKV*C,Gdqkv,GLr1+(size_t)l*BTC,(int)BT,C,QKV);
      rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rms1g+l*C,Gdrms,res_in,wd.rms1g+l*C,GLr1s+(size_t)l*BT,(int)BT,C);
    }
    encoder_backward_kernel<<<GR(btc),blk>>>(gd.tok,Gdres,did,(int)BT,C);
    /* ---- AdamW ---- */
    float c1=1.0f-powf(0.9f,step),c2=1.0f-powf(0.95f,step); /* bias corrections */
    adamw_kernel<<<GR(np),blk>>>(dp,dgr,dm,dv,np,lr,0.9f,0.95f,1e-8f,0.1f,c1,c2); CHECK(cudaGetLastError());
    if(step%50==0||step==1){ CHECK(cudaDeviceSynchronize()); CHECK(cudaMemcpy(hprob,Gprob,LG*4,cudaMemcpyDeviceToHost));
      float loss=0; for(size_t bt=0;bt<BT;bt++)for(int j=0;j<K;j++){int tt=htg[bt*K+j]; float pr=hprob[bt*K*V+(size_t)j*V+tt]; loss+=-logf(pr>1e-12f?pr:1e-12f);} loss/=(BT*K);
      printf("step %5d | ep %.2f | loss %.4f | lr %.1e | %.1fs\n",step,(double)step*B*T/(double)data_n,loss,lr,(double)(clock()-t0)/CLOCKS_PER_SEC); fflush(stdout); }
    if(step%5000==0){ CHECK(cudaMemcpy(params,dp,np*4,cudaMemcpyDeviceToHost));
      if(save_train_ckpt("../nanoeuler.bin",&c,params,np)){ printf("[ckpt] step %d -> ../nanoeuler.bin\n",step); fflush(stdout); } }
  }
  /* save in the CPU program's format: Config, VOCAB, n_merges, bpe_a, bpe_b, params */
  CHECK(cudaMemcpy(params,dp,np*4,cudaMemcpyDeviceToHost));
  if(save_train_ckpt("../nanoeuler.bin",&c,params,np))
    printf("[saved] ../nanoeuler.bin (load it with the CPU build: ./nanoeuler chat)\n");
  /* release host and device buffers (data_ids stays: it is file-global) */
  free(params); free(hid); free(htg); free(hprob);
  float*dbufs[]={dp,dgr,dm,dv,Genc,GLr1,GLr1s,GLqkv,GLatty,GLap,GLr2,GLr2n,GLr2s,GLg,GLu,GLh,GLm,GLr3,GLlse,
                 Grmsf,Grmsfs,Glog,Gprob,Gdlog,Gdrms,Gdres,GdHs,GdG,GdU,Gdatty,Gdqkv};
  for(size_t i=0;i<sizeof(dbufs)/sizeof(dbufs[0]);i++){ CHECK(cudaFree(dbufs[i])); }
  CHECK(cudaFree(did)); CHECK(cudaFree(dtg));
}
/* ===== BPE decode/encode (for inference; merges loaded from the .bin) ===== */
/* decode table filled by bpe_build_decode: token id stands for the tok_len[id] bytes at
   tok_bytes[id] (raw bytes, not NUL-terminated) */
static unsigned char *tok_bytes[VOCAB_MAX]; static int tok_len[VOCAB_MAX];
/* Fill tok_bytes/tok_len: ids 0..255 map to their single byte, and merge m (id 256+m) maps to
   the bytes of bpe_a[m] followed by the bytes of bpe_b[m]. Merges are processed in order, so
   each part must already be built (true for tables produced by load_bpe, whose parts always
   have smaller ids). */
static void bpe_build_decode(){
  for(int byte=0;byte<256;byte++){
    tok_len[byte]=1;
    tok_bytes[byte]=(unsigned char*)malloc(1);
    tok_bytes[byte][0]=(unsigned char)byte;
  }
  for(int m=0;m<n_merges;m++){
    const int id=256+m,lhs=bpe_a[m],rhs=bpe_b[m];
    tok_len[id]=tok_len[lhs]+tok_len[rhs];
    unsigned char*dst=tok_bytes[id]=(unsigned char*)malloc(tok_len[id]);
    memcpy(dst,tok_bytes[lhs],tok_len[lhs]);
    memcpy(dst+tok_len[lhs],tok_bytes[rhs],tok_len[rhs]);
  }
}
/* Encode the NUL-terminated string t with the current merges; out needs room for strlen(t)
   ids. Returns the number of ids written. */
static int bpe_encode(const char*t,int*out){
  const long nbytes=(long)strlen(t);
  const unsigned char*bytes=(const unsigned char*)t;
  const long nids=bpe_encode_bytes(bytes,nbytes,out);
  return (int)nids;
}
/* Sample a token id from softmax(lg/temp) over V logits, using rand(). Returns V-1 if float
   rounding leaves the running sum short of the drawn threshold. The probability scratch buffer
   is a function-local static that only grows, so the function is not thread-safe. */
static int sample_host(const float*lg,int V,float temp){
  static float*weights=0; static int weights_cap=0;
  float peak=-1e30f;
  for(int i=0;i<V;i++) peak = lg[i]>peak ? lg[i] : peak;
  if(weights_cap<V){ free(weights); weights=(float*)malloc(V*sizeof(float)); weights_cap=V; }
  float total=0;
  for(int i=0;i<V;i++){ weights[i]=expf((lg[i]-peak)/temp); total+=weights[i]; }
  const float target=((float)rand()/RAND_MAX)*total;
  float running=0;
  for(int i=0;i<V;i++){
    running+=weights[i];
    if(running>=target) return i;
  }
  return V-1;
}
/* Top-k sampling with a repetition penalty over the tokens already generated this turn.
   Pure temperature is fragile (high -> tail junk, low -> "| | |" repetition loops); keeping
   only the k most likely tokens and penalizing repeats keeps a small model from looping.
   - the penalty is applied once per occurrence in recent[0..nrec): a positive logit is divided
     by rep and a non-positive one multiplied by rep, so repeated tokens are penalized repeatedly;
   - every logit >= the k-th largest penalized logit is kept (ties may keep more than k);
   - the kept logits are sampled from softmax(logit/temp) with rand(); V-1 is the fallback.
   Scratch buffers are function-local statics that only grow (not thread-safe). */
static int sample_topk(const float*lg,int V,float temp,int topk,const int*recent,int nrec,float rep){
  static float*adj=0;  static int adj_cap=0;   /* penalized logits */
  static float*work=0; static int work_cap=0;  /* copy consumed by the top-k selection */
  static float*prob=0; static int prob_cap=0;  /* unnormalized probabilities of kept tokens */
  if(adj_cap<V){ free(adj); adj=(float*)malloc(V*sizeof(float)); adj_cap=V; }
  if(work_cap<V){ free(work); work=(float*)malloc(V*sizeof(float)); work_cap=V; }
  if(prob_cap<V){ free(prob); prob=(float*)malloc(V*sizeof(float)); prob_cap=V; }
  for(int i=0;i<V;i++) adj[i]=lg[i];
  for(int r=0;r<nrec;r++){
    const int tok=recent[r];
    if(tok<0||tok>=V) continue;
    if(adj[tok]>0) adj[tok]/=rep; else adj[tok]*=rep;
  }
  const int keep=topk>V?V:topk;
  for(int i=0;i<V;i++) work[i]=adj[i];
  float cutoff=-1e30f; /* value of the keep-th largest penalized logit */
  for(int k=0;k<keep;k++){
    int arg=0; float best=-1e30f;
    for(int i=0;i<V;i++) if(work[i]>best){ best=work[i]; arg=i; }
    cutoff=best; work[arg]=-1e30f;
  }
  float top=-1e30f;
  for(int i=0;i<V;i++) if(adj[i]>=cutoff&&adj[i]>top) top=adj[i];
  float total=0;
  for(int i=0;i<V;i++){
    if(adj[i]>=cutoff){ prob[i]=expf((adj[i]-top)/temp); total+=prob[i]; }
    else prob[i]=0;
  }
  const float target=((float)rand()/RAND_MAX)*total;
  float running=0;
  for(int i=0;i<V;i++){
    running+=prob[i];
    if(prob[i]>0&&running>=target) return i;
  }
  return V-1;
}
/* ===== GPU inference (mode "i"): autoregressive generation, forward only ===== */
/* Load ../nanoeuler.bin, echo the prompt, then sample 400 tokens (temperature 0.8, head-0
   logits) and print each as it is produced. Every step recomputes the full forward pass over a
   T-token window holding the most recent context. Exits on a missing or inconsistent file. */
static void run_gen(const char*prompt){
  FILE*f=fopen("../nanoeuler.bin","rb"); if(!f){fprintf(stderr,"train first: ./nanoeuler_cuda t\n");exit(1);}
  Cfg c; if(fread(&c,sizeof(Cfg),1,f)!=1){exit(1);} if(fread(&VOCAB,sizeof(int),1,f)!=1){exit(1);}
  if(fread(&n_merges,sizeof(int),1,f)!=1){exit(1);}
  /* bpe_a/bpe_b and the decode tables are fixed-size: reject a count that would overflow them */
  if(n_merges<0||n_merges>BPE_MERGES){fprintf(stderr,"corrupt checkpoint: n_merges=%d\n",n_merges);exit(1);}
  if(fread(bpe_a,sizeof(int),n_merges,f)!=(size_t)n_merges){exit(1);} if(fread(bpe_b,sizeof(int),n_merges,f)!=(size_t)n_merges){exit(1);}
  /* every id the model can sample needs a decode entry; only ids 0..255+n_merges get one */
  if(c.V<=0||c.V>256+n_merges){fprintf(stderr,"corrupt checkpoint: V=%d with %d merges\n",c.V,n_merges);exit(1);}
  size_t sz[10]; size_t np=wsizes(c,sz); float*params=(float*)malloc(np*sizeof(float));
  if(!params||fread(params,sizeof(float),np,f)!=np){exit(1);} fclose(f); bpe_build_decode(); bpe_build_index();
  int C=c.C,T=c.T,NL=c.NL,NH=c.NH,NKV=c.NKV,V=c.V,H=c.H,K=c.K,QKV=qkvd(c),hs=C/NH;
  size_t TC=(size_t)T*C,ATT=(size_t)NH*T*T,TH=(size_t)T*H,TQ=(size_t)T*QKV,LG=(size_t)T*K*V;
  float*dp;DM(dp,np*4);CHECK(cudaMemcpy(dp,params,np*4,cudaMemcpyHostToDevice)); W wd;wset(&wd,dp,c);
  float *enc,*r1,*r1s,*qkv,*atty,*pre,*att,*ap,*r2,*r2n,*r2s,*gate,*up,*hsil,*mlp,*res3,*rmsf,*rmsfs,*logits; int*did;
  DM(enc,TC*4);DM(r1,TC*4);DM(r1s,T*4);DM(qkv,TQ*4);DM(atty,TC*4);DM(pre,ATT*4);DM(att,ATT*4);DM(ap,TC*4);DM(r2,TC*4);DM(r2n,TC*4);DM(r2s,T*4);
  DM(gate,TH*4);DM(up,TH*4);DM(hsil,TH*4);DM(mlp,TC*4);DM(res3,TC*4);DM(rmsf,TC*4);DM(rmsfs,T*4);DM(logits,LG*4);DM(did,T*sizeof(int));
  int blk=128; int*ctx=(int*)malloc(T*sizeof(int)); int len=0;
  /* bpe_encode emits at most one id per prompt byte, so strlen+1 slots always suffice
     (no fixed-size buffer that a long prompt could overflow) */
  int*penc=(int*)malloc((strlen(prompt)+1)*sizeof(int)); int npc=bpe_encode(prompt,penc); fputs(prompt,stdout);
  /* keep the most recent T prompt tokens */
  for(int i=0;i<npc;i++){ if(len<T) ctx[len++]=penc[i]; else { memmove(ctx,ctx+1,(T-1)*sizeof(int)); ctx[T-1]=penc[i]; } }
  if(len==0){ ctx[len++]=data_ids?data_ids[0]:'\n'; }
  float*hlog=(float*)malloc(V*sizeof(float)); int*inp=(int*)malloc(T*sizeof(int));
  for(int s=0;s<400;s++){
    int curT=len<T?len:T;
    /* Left-align the context: positions 0..curT-1 hold real tokens and the tail is padded with
       token 0; the prediction is read at row curT-1. Assuming causal attention (the model is
       trained on next-token targets), that row never attends to the padding, and positions
       start at 0 exactly as in pretraining windows and in the right-padded SFT batches. */
    for(int i=0;i<T;i++) inp[i]=(i<curT)?ctx[i]:0;
    CHECK(cudaMemcpy(did,inp,T*sizeof(int),cudaMemcpyHostToDevice));
    encoder_forward_kernel<<<GR(T),blk>>>(enc,did,wd.tok,T,C); float*res=enc;
    for(int l=0;l<NL;l++){
      rmsnorm_forward_kernel<<<GR(T),blk>>>(r1,r1s,res,wd.rms1g+l*C,T,C);
      matmul_forward_cuda(qkv,r1,wd.qkvw+(size_t)l*QKV*C,T,C,QKV);
      rope_kernel<<<GR((size_t)T*NH),blk>>>(qkv,0,QKV,NH,hs,1,T,0); rope_kernel<<<GR((size_t)T*NKV),blk>>>(qkv,C,QKV,NKV,hs,1,T,0);
      attention_forward_kernel<<<GR((size_t)NH*T),blk>>>(atty,pre,att,qkv,1,T,C,NH,NKV);
      matmul_forward_cuda(ap,atty,wd.attprojw+(size_t)l*C*C,T,C,C);
      residual_add_kernel<<<GR(TC),blk>>>(r2,res,ap,TC);
      rmsnorm_forward_kernel<<<GR(T),blk>>>(r2n,r2s,r2,wd.rms2g+l*C,T,C);
      matmul_forward_cuda(gate,r2n,wd.gatew+(size_t)l*H*C,T,C,H);
      matmul_forward_cuda(up,r2n,wd.upw+(size_t)l*H*C,T,C,H);
      swiglu_forward_kernel<<<GR(TH),blk>>>(hsil,gate,up,TH);
      matmul_forward_cuda(mlp,hsil,wd.downw+(size_t)l*C*H,T,H,C);
      residual_add_kernel<<<GR(TC),blk>>>(res3,r2,mlp,TC); res=res3;
    }
    rmsnorm_forward_kernel<<<GR(T),blk>>>(rmsf,rmsfs,res,wd.rmsfg,T,C);
    matmul_forward_cuda(logits,rmsf,wd.headw,T,C,K*V);
    CHECK(cudaDeviceSynchronize());
    CHECK(cudaMemcpy(hlog,logits+(size_t)(curT-1)*K*V,V*sizeof(float),cudaMemcpyDeviceToHost)); /* head 0 at the last real position */
    int id=sample_host(hlog,V,0.8f);
    for(int q=0;q<tok_len[id];q++) putchar(tok_bytes[id][q]); fflush(stdout);
    if(len<T) ctx[len++]=id; else { memmove(ctx,ctx+1,(T-1)*sizeof(int)); ctx[T-1]=id; }
  }
  putchar('\n');
  free(params); free(ctx); free(penc); free(hlog); free(inp);
  float*dbufs[]={dp,enc,r1,r1s,qkv,atty,pre,att,ap,r2,r2n,r2s,gate,up,hsil,mlp,res3,rmsf,rmsfs,logits};
  for(size_t i=0;i<sizeof(dbufs)/sizeof(dbufs[0]);i++){ CHECK(cudaFree(dbufs[i])); }
  CHECK(cudaFree(did));
}
/* ===== Supervised fine-tuning (SFT) on Alpaca-style instruction data (modes "s"/"c").
   Each example is rendered with the Alpaca prompt template; the loss is supervised only
   on the response tokens (prompt and padding targets are set to -1 -> zero gradient via
   the masked softmax_ce kernel). The result is saved to ../nanoeuler_chat.bin. ===== */
/* Read the 4 hex digits after the 'u' at p (stops early at the end of input; a non-hex digit
   counts as 0) into *out. Returns the pointer just past the digits read. */
static const char* json_hex4(const char*p,unsigned*out){
  unsigned cp=0;
  for(int k=0;k<4&&p[1];k++){ p++; char h=*p; cp<<=4;
    cp += (h>='0'&&h<='9')?h-'0':(h>='a'&&h<='f')?h-'a'+10:(h>='A'&&h<='F')?h-'A'+10:0; }
  *out=cp;
  return p+1;
}
/* Minimal JSON string reader: from p, skip to the next '"' and decode the string that starts
   there into a malloc'd NUL-terminated UTF-8 string (*dst). Returns the pointer just past the
   closing quote (or the terminating NUL if the input ends first). If no '"' is found, *dst is
   set to 0. It does not check that the next string really is the value of the preceding key.
   Escapes: \n \t \r \b \f, \uXXXX (surrogate pairs are combined into one code point, an
   unpaired surrogate becomes U+FFFD), and any other escaped character stands for itself. */
static const char* json_str(const char*p,char**dst){
  while(*p && *p!='\"') p++;
  if(!*p){*dst=0;return p;}
  p++;
  size_t cap=64,n=0; char*s=(char*)malloc(cap);
  while(*p && *p!='\"'){
    /* keep room for the longest write (4-byte UTF-8 sequence) plus the final NUL */
    if(n+5>=cap){cap=cap*2+5;s=(char*)realloc(s,cap);}
    if(*p!='\\'){ s[n++]=*p++; continue; }
    p++;                 /* skip the backslash */
    if(!*p) break;       /* truncated escape at end of input: never step past the terminator */
    if(*p=='u'){
      unsigned cp; p=json_hex4(p,&cp);
      if(cp>=0xD800&&cp<=0xDBFF&&p[0]=='\\'&&p[1]=='u'){ /* high surrogate: try to pair it */
        unsigned lo; const char*q=json_hex4(p+1,&lo);
        if(lo>=0xDC00&&lo<=0xDFFF){ cp=0x10000+((cp-0xD800)<<10)+(lo-0xDC00); p=q; }
      }
      if(cp>=0xD800&&cp<=0xDFFF) cp=0xFFFD; /* lone surrogate: not encodable as UTF-8 */
      if(cp<0x80) s[n++]=(char)cp;
      else if(cp<0x800){ s[n++]=(char)(0xC0|(cp>>6)); s[n++]=(char)(0x80|(cp&0x3F)); }
      else if(cp<0x10000){ s[n++]=(char)(0xE0|(cp>>12)); s[n++]=(char)(0x80|((cp>>6)&0x3F)); s[n++]=(char)(0x80|(cp&0x3F)); }
      else { s[n++]=(char)(0xF0|(cp>>18)); s[n++]=(char)(0x80|((cp>>12)&0x3F)); s[n++]=(char)(0x80|((cp>>6)&0x3F)); s[n++]=(char)(0x80|(cp&0x3F)); }
      continue;
    }
    char ch;
    switch(*p){ case 'n':ch='\n';break; case 't':ch='\t';break; case 'r':ch='\r';break;
                case 'b':ch='\b';break; case 'f':ch='\f';break; default:ch=*p;break; }
    p++;
    s[n++]=ch;
  }
  if(*p=='\"') p++;
  s[n]='\0'; *dst=s; return p;
}
/* longest rendered example (prompt + response ids) that load_alpaca keeps */
#define SFT_MAXTOK 8192
/* SFT example store (grown by sft_add): example e has sft_len[e] ids at sft_seq[e], the first
   sft_plen[e] of which are the prompt; sft_n examples stored, sft_cap allocated */
static int **sft_seq=0; static int *sft_len=0,*sft_plen=0; static int sft_n=0,sft_cap=0;
/* Append a private copy of seq[0..len) to the SFT store, recording plen leading prompt ids.
   The parallel arrays start at 1024 slots and double whenever they fill up. */
static void sft_add(int*seq,int len,int plen){
  if(sft_n==sft_cap){
    const int grown=(sft_cap==0)?1024:2*sft_cap;
    sft_seq=(int**)realloc(sft_seq,grown*sizeof(int*));
    sft_len=(int*)realloc(sft_len,grown*sizeof(int));
    sft_plen=(int*)realloc(sft_plen,grown*sizeof(int));
    sft_cap=grown;
  }
  const size_t bytes=(size_t)len*sizeof(int);
  int*copy=(int*)malloc(bytes);
  memcpy(copy,seq,bytes);
  sft_seq[sft_n]=copy;
  sft_len[sft_n]=len;
  sft_plen[sft_n]=plen;
  ++sft_n;
}
/* Parse an Alpaca-style JSON file (records with "instruction", "input", "output" strings),
   render each record with the Alpaca prompt template, BPE-encode prompt and response
   ("<output></s>") separately, and store them with sft_add.
   An example is skipped when:
   - the rendering would not fit the text buffer (it would lose the template tail or </s>);
   - either part encodes to nothing;
   - the prompt leaves fewer than two positions for the response in a T-token window;
   - the total length exceeds SFT_MAXTOK.
   Exits if the file cannot be read. */
static void load_alpaca(const char*path,int T){
  FILE*f=fopen(path,"rb"); if(!f){fprintf(stderr,"cannot open %s (run data/get_alpaca.sh)\n",path);exit(1);}
  fseek(f,0,SEEK_END); long n=ftell(f); fseek(f,0,SEEK_SET);
  if(n<0){fprintf(stderr,"cannot determine size of %s\n",path);exit(1);}
  char*buf=(char*)malloc(n+1); if(!buf||fread(buf,1,n,f)!=(size_t)n){exit(1);} buf[n]=0; fclose(f);
  const int TXT=1<<16; /* rendered-text buffer size; bpe_encode emits at most one id per byte */
  int *tp=(int*)malloc(TXT*sizeof(int)),*tf=(int*)malloc(TXT*sizeof(int));
  char *ptxt=(char*)malloc(TXT),*rtxt=(char*)malloc(TXT); static int seq[SFT_MAXTOK];
  const char*p=buf; int kept=0,skipped=0;
  while((p=strstr(p,"\"instruction\""))){
    char*instr=0,*inp=0,*outp=0;
    p+=13; p=json_str(p,&instr); /* 13 == strlen("\"instruction\"") */
    const char*q=strstr(p,"\"input\""); if(!q){free(instr);break;} q+=7; q=json_str(q,&inp);
    const char*r=strstr(q,"\"output\""); if(!r){free(instr);free(inp);break;} r+=8; r=json_str(r,&outp); p=r;
    if(!instr||!outp){free(instr);free(inp);free(outp);continue;}
    int hasin = inp && inp[0];
    int pw,rw;
    if(hasin) pw=snprintf(ptxt,TXT,"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n%s\n\n### Input:\n%s\n\n### Response:\n",instr,inp);
    else pw=snprintf(ptxt,TXT,"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n%s\n\n### Response:\n",instr);
    rw=snprintf(rtxt,TXT,"%s</s>",outp);
    free(instr);free(inp);free(outp);
    /* snprintf returns the untruncated length: drop examples that did not fit rather than
       training on a clipped prompt or a response without its </s> terminator */
    if(pw<0||pw>=TXT||rw<0||rw>=TXT){skipped++;continue;}
    int pl=bpe_encode(ptxt,tp), rl=bpe_encode(rtxt,tf);
    if(pl<=0||rl<=0||pl>=T-2||pl+rl>SFT_MAXTOK){skipped++;continue;}
    for(int i=0;i<pl;i++) seq[i]=tp[i];
    for(int i=0;i<rl;i++) seq[pl+i]=tf[i];
    sft_add(seq,pl+rl,pl); kept++;
  }
  free(buf);free(tp);free(tf);free(ptxt);free(rtxt);
  printf("[sft] %d examples (skipped %d) | vocab=%d\n",kept,skipped,VOCAB);
}
/* Fill one SFT batch of B random examples. Each example is left-aligned in its T-token row of
   hid and right-padded with token 0 (truncated to T tokens if longer). Target (b,t,j) in htg is
   example token t+1+j when that position is inside the truncated example and in its response
   part (>= the prompt length), and -1 (masked) otherwise. *n_sup receives the number of
   supervised targets. Exits if no examples are loaded. */
static void sft_batch(int*hid,int*htg,int B,int T,int K,int*n_sup){
  if(sft_n<=0){ fprintf(stderr,"[sft] no training examples loaded\n"); exit(1); }
  int ns=0;
  for(int b=0;b<B;b++){
    /* rand() yields only RAND_MAX+1 distinct values (32768 on some platforms). Combine draws
       when that cannot reach every example; otherwise this is a single rand() call. */
    unsigned long long rv=(unsigned long long)rand(),lim=(unsigned long long)RAND_MAX;
    for(int d=0;d<3&&lim+1<(unsigned long long)sft_n;d++){
      rv=rv*((unsigned long long)RAND_MAX+1)+(unsigned long long)rand();
      lim=lim*((unsigned long long)RAND_MAX+1)+(unsigned long long)RAND_MAX;
    }
    int e=(int)(rv%(unsigned long long)sft_n);
    int*sq=sft_seq[e]; int L=sft_len[e],pl=sft_plen[e];
    int Lc=L<T?L:T;
    for(int t=0;t<T;t++) hid[b*T+t]=(t<Lc)?sq[t]:0;
    for(int t=0;t<T;t++) for(int j=0;j<K;j++){
      int ti=t+1+j,v=-1;
      if(ti<Lc && ti>=pl){ v=sq[ti]; ns++; }
      htg[((size_t)b*T+t)*K+j]=v;
    }
  }
  *n_sup=ns;
}
/* Read exactly `count` items of `size` bytes from checkpoint f into dst, or exit naming `what`. */
static void sft_fread_exact(void*dst,size_t size,size_t count,FILE*f,const char*what){
  if(fread(dst,size,count,f)!=count){fprintf(stderr,"truncated checkpoint: could not read %s\n",what);exit(1);}
}
/* Write exactly `count` items of `size` bytes from src to f, or exit naming the output `path`. */
static void sft_fwrite_exact(const void*src,size_t size,size_t count,FILE*f,const char*path){
  if(fwrite(src,size,count,f)!=count){fprintf(stderr,"write failed: %s\n",path);exit(1);}
}
/* Supervised fine-tuning driver (mode "s").
   1. Loads the base checkpoint ../nanoeuler.bin, which is read in this order: Cfg, VOCAB,
      n_merges, bpe_a[n_merges], bpe_b[n_merges], then np floats of parameters.
   2. Builds the BPE index and loads Alpaca examples from ../data/alpaca.json.
   3. Runs `steps` AdamW steps. The learning rate uses linear warm-up for `warm` steps, then
      cosine decay; weight decay is 0. The loss is the masked K-head loss whose targets come
      from sft_batch (-1 entries are not supervised), scaled by 1/n_sup.
   4. Writes the tuned parameters in the same layout to ../nanoeuler_chat.bin.
   GL* buffers keep per-layer forward activations for the hand-written backward pass. The
   other G* buffers are per-step scratch. Exits with a message on any I/O failure. */
static void run_sft(void){
  srand(4242);
  /* ---- load the pretrained base model (config + BPE + params) from ../nanoeuler.bin ---- */
  FILE*f=fopen("../nanoeuler.bin","rb"); if(!f){fprintf(stderr,"pretrain first: ./nanoeuler_cuda t\n");exit(1);}
  Cfg c;
  sft_fread_exact(&c,sizeof(Cfg),1,f,"config");
  sft_fread_exact(&VOCAB,sizeof(int),1,f,"vocab size");
  sft_fread_exact(&n_merges,sizeof(int),1,f,"merge count");
  /* a negative count would become a huge size_t in fread and overrun bpe_a/bpe_b */
  if(n_merges<0){fprintf(stderr,"corrupt checkpoint: n_merges=%d\n",n_merges);exit(1);}
  sft_fread_exact(bpe_a,sizeof(int),(size_t)n_merges,f,"BPE merges");
  sft_fread_exact(bpe_b,sizeof(int),(size_t)n_merges,f,"BPE merges");
  size_t sz[10]; size_t np=wsizes(c,sz);
  float*params=(float*)malloc(np*sizeof(float));
  if(!params){fprintf(stderr,"out of memory (%zu params)\n",np);exit(1);}
  sft_fread_exact(params,sizeof(float),np,f,"parameters"); fclose(f);
  int C=c.C,T=c.T,NL=c.NL,NH=c.NH,NKV=c.NKV,V=c.V,H=c.H,K=c.K,QKV=qkvd(c),hs=C/NH;
  int B=16,steps=8000,warm=100; float lr0=1e-4f;
  size_t BT=(size_t)B*T,BTC=BT*C,BTH=BT*H,BTQ=BT*QKV,LG=BT*K*V;
  bpe_build_index(); load_alpaca("../data/alpaca.json",T);
  printf("[model:sft] C=%d L=%d T=%d | %.2fM params | %d steps lr %.0e\n",C,NL,T,np/1e6,steps,lr0);
  /* ---- device state: weights (wd views dp), grads (gd views dgr), AdamW moments zeroed ---- */
  float*dp;DM(dp,np*4);CHECK(cudaMemcpy(dp,params,np*4,cudaMemcpyHostToDevice)); W wd;wset(&wd,dp,c);
  float*dgr;DM(dgr,np*4); float*dm;DM(dm,np*4);CHECK(cudaMemset(dm,0,np*4)); float*dv;DM(dv,np*4);CHECK(cudaMemset(dv,0,np*4));
  W gd;wset(&gd,dgr,c);
  int *did,*dtg;DM(did,BT*sizeof(int));DM(dtg,BT*K*sizeof(int));
  int *hid=(int*)malloc(BT*sizeof(int)),*htg=(int*)malloc(BT*K*sizeof(int));
  float*hprob=(float*)malloc(LG*sizeof(float));
  if(!hid||!htg||!hprob){fprintf(stderr,"out of memory\n");exit(1);}
  /* ---- activation / gradient scratch ---- */
  float *Genc,*GLr1,*GLr1s,*GLqkv,*GLatty,*GLap,*GLr2,*GLr2n,*GLr2s,*GLg,*GLu,*GLh,*GLm,*GLr3,*GLlse;
  float *Grmsf,*Grmsfs,*Glog,*Gprob,*Gdlog,*Gdrms,*Gdres,*GdHs,*GdG,*GdU,*Gdatty,*Gdqkv;
  DM(Genc,BTC*4);DM(GLr1,NL*BTC*4);DM(GLr1s,NL*BT*4);DM(GLqkv,NL*BTQ*4);DM(GLatty,NL*BTC*4);
  DM(GLap,NL*BTC*4);DM(GLr2,NL*BTC*4);DM(GLr2n,NL*BTC*4);DM(GLr2s,NL*BT*4);DM(GLg,NL*BTH*4);DM(GLu,NL*BTH*4);DM(GLh,NL*BTH*4);DM(GLm,NL*BTC*4);DM(GLr3,NL*BTC*4);DM(GLlse,(size_t)NL*B*NH*T*4);
  DM(Grmsf,BTC*4);DM(Grmsfs,BT*4);DM(Glog,LG*4);DM(Gprob,LG*4);DM(Gdlog,LG*4);DM(Gdrms,BTC*4);DM(Gdres,BTC*4);DM(GdHs,BTH*4);DM(GdG,BTH*4);DM(GdU,BTH*4);DM(Gdatty,BTC*4);DM(Gdqkv,BTQ*4);
  int blk=128; size_t btc=BTC,bth=BTH; clock_t t0=clock();
  for(int step=1;step<=steps;step++){
    float lr=step<warm?lr0*step/warm:lr0*0.5f*(1+cosf(3.14159265f*(step-warm)/(steps-warm)));
    int n_sup; sft_batch(hid,htg,B,T,K,&n_sup); float scale=1.0f/(n_sup>0?n_sup:1);
    CHECK(cudaMemcpy(did,hid,BT*sizeof(int),cudaMemcpyHostToDevice));
    CHECK(cudaMemcpy(dtg,htg,BT*K*sizeof(int),cudaMemcpyHostToDevice));
    CHECK(cudaMemset(dgr,0,np*4));
    /* ---- forward ---- */
    encoder_forward_kernel<<<GR(BT),blk>>>(Genc,did,wd.tok,(int)BT,C); float*gres=Genc;
    for(int l=0;l<NL;l++){
      float*r1=GLr1+(size_t)l*BTC; rmsnorm_forward_kernel<<<GR(BT),blk>>>(r1,GLr1s+(size_t)l*BT,gres,wd.rms1g+l*C,(int)BT,C);
      float*qk=GLqkv+(size_t)l*BTQ; matmul_forward_cuda(qk,r1,wd.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
      rope_kernel<<<GR((size_t)B*T*NH),blk>>>(qk,0,QKV,NH,hs,B,T,0); rope_kernel<<<GR((size_t)B*T*NKV),blk>>>(qk,C,QKV,NKV,hs,B,T,0);
      float*at=GLatty+(size_t)l*BTC; flash_attn_forward_kernel<<<dim3(B*NH,(T+FA_BR-1)/FA_BR),FA_BR>>>(at,GLlse+(size_t)l*B*NH*T,qk,B,T,C,NH,NKV);
      float*ap=GLap+(size_t)l*BTC; matmul_forward_cuda(ap,at,wd.attprojw+(size_t)l*C*C,(int)BT,C,C);
      float*r2=GLr2+(size_t)l*BTC; residual_add_kernel<<<GR(btc),blk>>>(r2,gres,ap,btc);
      float*r2n=GLr2n+(size_t)l*BTC; rmsnorm_forward_kernel<<<GR(BT),blk>>>(r2n,GLr2s+(size_t)l*BT,r2,wd.rms2g+l*C,(int)BT,C);
      float*ga=GLg+(size_t)l*BTH; matmul_forward_cuda(ga,r2n,wd.gatew+(size_t)l*H*C,(int)BT,C,H);
      float*up=GLu+(size_t)l*BTH; matmul_forward_cuda(up,r2n,wd.upw+(size_t)l*H*C,(int)BT,C,H);
      float*hs2=GLh+(size_t)l*BTH; swiglu_forward_kernel<<<GR(bth),blk>>>(hs2,ga,up,bth);
      float*ml=GLm+(size_t)l*BTC; matmul_forward_cuda(ml,hs2,wd.downw+(size_t)l*C*H,(int)BT,H,C);
      float*r3=GLr3+(size_t)l*BTC; residual_add_kernel<<<GR(btc),blk>>>(r3,r2,ml,btc); gres=r3;
    }
    rmsnorm_forward_kernel<<<GR(BT),blk>>>(Grmsf,Grmsfs,gres,wd.rmsfg,(int)BT,C);
    matmul_forward_cuda(Glog,Grmsf,wd.headw,(int)BT,C,K*V);
    softmax_ce_kernel<<<GR(BT*K),blk>>>(Gprob,Gdlog,Glog,dtg,(int)BT,V,K,scale);
    /* ---- backward ----
       Gdres is zeroed once here and then carries the residual-stream gradient down through every
       layer. This relies on rmsnorm_backward_kernel adding (+=) into its input-gradient argument;
       NOTE(review): confirm against the kernel. The matmul_backward_dinp_cuda destinations are
       memset first, and Gdrms receives both the gate and up contributions, so those presumably
       accumulate too. */
    CHECK(cudaMemset(Gdrms,0,BTC*4)); matmul_backward_dinp_cuda(Gdrms,Gdlog,wd.headw,(int)BT,C,K*V);
    matmul_backward_dw_cuda(gd.headw,Gdlog,Grmsf,(int)BT,C,K*V);
    CHECK(cudaMemset(Gdres,0,BTC*4)); rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rmsfg,Gdrms,GLr3+(size_t)(NL-1)*BTC,wd.rmsfg,Grmsfs,(int)BT,C);
    for(int l=NL-1;l>=0;l--){
      float*res_in=(l==0)?Genc:GLr3+(size_t)(l-1)*BTC;
      CHECK(cudaMemset(GdHs,0,BTH*4)); matmul_backward_dinp_cuda(GdHs,Gdres,wd.downw+(size_t)l*C*H,(int)BT,H,C);
      matmul_backward_dw_cuda(gd.downw+(size_t)l*C*H,Gdres,GLh+(size_t)l*BTH,(int)BT,H,C);
      swiglu_backward_kernel<<<GR(bth),blk>>>(GdG,GdU,GdHs,GLg+(size_t)l*BTH,GLu+(size_t)l*BTH,bth);
      CHECK(cudaMemset(Gdrms,0,BTC*4));
      matmul_backward_dinp_cuda(Gdrms,GdG,wd.gatew+(size_t)l*H*C,(int)BT,C,H);
      matmul_backward_dw_cuda(gd.gatew+(size_t)l*H*C,GdG,GLr2n+(size_t)l*BTC,(int)BT,C,H);
      matmul_backward_dinp_cuda(Gdrms,GdU,wd.upw+(size_t)l*H*C,(int)BT,C,H);
      matmul_backward_dw_cuda(gd.upw+(size_t)l*H*C,GdU,GLr2n+(size_t)l*BTC,(int)BT,C,H);
      rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rms2g+l*C,Gdrms,GLr2+(size_t)l*BTC,wd.rms2g+l*C,GLr2s+(size_t)l*BT,(int)BT,C);
      CHECK(cudaMemset(Gdatty,0,BTC*4)); matmul_backward_dinp_cuda(Gdatty,Gdres,wd.attprojw+(size_t)l*C*C,(int)BT,C,C);
      matmul_backward_dw_cuda(gd.attprojw+(size_t)l*C*C,Gdres,GLatty+(size_t)l*BTC,(int)BT,C,C);
      CHECK(cudaMemset(Gdqkv,0,BTQ*4));
      flash_attn_backward_kernel<<<dim3(B*NH,(T+FA_BR-1)/FA_BR),FA_BR>>>(Gdqkv,GLqkv+(size_t)l*BTQ,GLatty+(size_t)l*BTC,Gdatty,GLlse+(size_t)l*B*NH*T,B,T,C,NH,NKV);
      rope_kernel<<<GR((size_t)B*T*NH),blk>>>(Gdqkv,0,QKV,NH,hs,B,T,1); rope_kernel<<<GR((size_t)B*T*NKV),blk>>>(Gdqkv,C,QKV,NKV,hs,B,T,1);
      CHECK(cudaMemset(Gdrms,0,BTC*4)); matmul_backward_dinp_cuda(Gdrms,Gdqkv,wd.qkvw+(size_t)l*QKV*C,(int)BT,C,QKV);
      matmul_backward_dw_cuda(gd.qkvw+(size_t)l*QKV*C,Gdqkv,GLr1+(size_t)l*BTC,(int)BT,C,QKV);
      rmsnorm_backward_kernel<<<GR(BT),blk>>>(Gdres,gd.rms1g+l*C,Gdrms,res_in,wd.rms1g+l*C,GLr1s+(size_t)l*BT,(int)BT,C);
    }
    encoder_backward_kernel<<<GR(btc),blk>>>(gd.tok,Gdres,did,(int)BT,C);
    /* ---- AdamW update (beta1=0.9, beta2=0.95, eps=1e-8, no weight decay) with bias corrections ---- */
    float c1=1.0f-powf(0.9f,step),c2=1.0f-powf(0.95f,step);
    adamw_kernel<<<GR(np),blk>>>(dp,dgr,dm,dv,np,lr,0.9f,0.95f,1e-8f,0.0f,c1,c2);
    CHECK(cudaGetLastError()); /* surface launch-configuration errors from this step's kernels */
    if(step%20==0||step==1){ CHECK(cudaDeviceSynchronize()); CHECK(cudaMemcpy(hprob,Gprob,LG*4,cudaMemcpyDeviceToHost));
      /* host-side masked loss for logging: mean -log p(target) over supervised targets only */
      float loss=0; int cnt=0; for(size_t bt=0;bt<BT;bt++)for(int j=0;j<K;j++){int tt=htg[bt*K+j]; if(tt<0)continue;
        float pr=hprob[bt*K*V+(size_t)j*V+tt]; loss+=-logf(pr>1e-12f?pr:1e-12f); cnt++; } loss/=(cnt>0?cnt:1);
      printf("step %4d | loss %.4f | lr %.1e | sup %d | %.1fs\n",step,loss,lr,n_sup,(double)(clock()-t0)/CLOCKS_PER_SEC); fflush(stdout); }
  }
  /* ---- save the fine-tuned model in the same layout as the base checkpoint ---- */
  CHECK(cudaMemcpy(params,dp,np*4,cudaMemcpyDeviceToHost));
  const char*out_path="../nanoeuler_chat.bin";
  FILE*o=fopen(out_path,"wb"); if(!o){fprintf(stderr,"cannot create %s\n",out_path);exit(1);}
  sft_fwrite_exact(&c,sizeof(Cfg),1,o,out_path); sft_fwrite_exact(&VOCAB,sizeof(int),1,o,out_path);
  sft_fwrite_exact(&n_merges,sizeof(int),1,o,out_path);
  sft_fwrite_exact(bpe_a,sizeof(int),(size_t)n_merges,o,out_path); sft_fwrite_exact(bpe_b,sizeof(int),(size_t)n_merges,o,out_path);
  sft_fwrite_exact(params,sizeof(float),np,o,out_path);
  if(fclose(o)!=0){fprintf(stderr,"write failed: %s\n",out_path);exit(1);} /* fclose flushes buffered data */
  printf("[saved] ../nanoeuler_chat.bin (chat with: ./nanoeuler_cuda c)\n");
  /* ---- release device and host buffers ---- */
  float*dbufs[]={dp,dgr,dm,dv,Genc,GLr1,GLr1s,GLqkv,GLatty,GLap,GLr2,GLr2n,GLr2s,GLg,GLu,GLh,GLm,GLr3,GLlse,
                 Grmsf,Grmsfs,Glog,Gprob,Gdlog,Gdrms,Gdres,GdHs,GdG,GdU,Gdatty,Gdqkv};
  for(size_t i=0;i<sizeof(dbufs)/sizeof(dbufs[0]);i++){ CHECK(cudaFree(dbufs[i])); }
  CHECK(cudaFree(did)); CHECK(cudaFree(dtg));
  free(params); free(hid); free(htg); free(hprob);
}
/* ===== Interactive chat (mode "c"): loads the SFT model, wraps each user line in the
Alpaca template, samples a response, stops at the "</s>" end marker. ===== */
static void run_chat(void){
srand((unsigned)time(0));
FILE*f=fopen("../nanoeuler_chat.bin","rb"); if(!f){fprintf(stderr,"fine-tune first: ./nanoeuler_cuda s\n");exit(1);}
Cfg c; if(fread(&c,sizeof(Cfg),1,f)!=1){exit(1);} if(fread(&VOCAB,sizeof(int),1,f)!=1){exit(1);}
if(fread(&n_merges,sizeof(int),1,f)!=1){exit(1);}
if(fread(bpe_a,sizeof(int),n_merges,f)!=(size_t)n_merges){exit(1);} if(fread(bpe_b,sizeof(int),n_merges,f)!=(size_t)n_merges){exit(1);}
size_t sz[10]; size_t np=wsizes(c,sz); float*params=(float*)malloc(np*sizeof(float));