diff --git a/sm2_csr.go b/sm2_csr.go new file mode 100644 index 0000000..5384b43 --- /dev/null +++ b/sm2_csr.go @@ -0,0 +1,390 @@ +/* + * Copyright 2014-2023 The GmSSL Project. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ +/* +build cgo */ + +package gmssl + +/* +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int gmssl_copy_file_to_buffer(FILE *fp, char **out, size_t *outlen) { + long size; + char *buf; + + if (fflush(fp) != 0) { + return -1; + } + if (fseek(fp, 0, SEEK_END) != 0) { + return -1; + } + size = ftell(fp); + if (size < 0) { + return -1; + } + if (fseek(fp, 0, SEEK_SET) != 0) { + return -1; + } + + buf = (char *)malloc((size_t)size + 1); + if (!buf) { + return -1; + } + if (size > 0 && fread(buf, 1, (size_t)size, fp) != (size_t)size) { + free(buf); + return -1; + } + buf[size] = '\0'; + *out = buf; + *outlen = (size_t)size; + return 1; +} + +static int gmssl_sm2_private_key_info_encrypt_to_pem_string(const SM2_KEY *key, const char *pass, char **out, size_t *outlen) { + FILE *fp = tmpfile(); + int ret = -1; + + if (!fp) { + return -1; + } + if (sm2_private_key_info_encrypt_to_pem(key, pass, fp) != 1) { + goto end; + } + if (gmssl_copy_file_to_buffer(fp, out, outlen) != 1) { + goto end; + } + ret = 1; +end: + fclose(fp); + return ret; +} + +static int gmssl_sm2_public_key_info_to_pem_string(const SM2_KEY *key, char **out, size_t *outlen) { + FILE *fp = tmpfile(); + int ret = -1; + + if (!fp) { + return -1; + } + if (sm2_public_key_info_to_pem(key, fp) != 1) { + goto end; + } + if (gmssl_copy_file_to_buffer(fp, out, outlen) != 1) { + goto end; + } + ret = 1; +end: + fclose(fp); + return ret; +} + +static int gmssl_x509_req_to_pem_string(const uint8_t *req, size_t reqlen, char **out, size_t *outlen) { + FILE *fp = tmpfile(); + int ret = -1; + + if (!fp) { + return -1; + } + if (x509_req_to_pem(req, reqlen, fp) != 1) { + goto end; + } + if (gmssl_copy_file_to_buffer(fp, out, outlen) != 1) { + goto end; + } + ret = 1; +end: + fclose(fp); + return ret; +} + +static int gmssl_create_csr_pem( + const SM2_KEY *key, + const char *country, + const char *state, + const char *locality, + const char *org, + const char *org_unit, + const char *common_name, + const char *signer_id, + char **out, + size_t *outlen +) { + uint8_t subject[256]; + size_t subject_len = 0; + uint8_t attrs[1] = {0}; + size_t attrs_len = 0; + uint8_t req[4096]; + uint8_t *p = req; + size_t req_len = 0; + + if (x509_name_set(subject, &subject_len, sizeof(subject), country, state, locality, org, org_unit, common_name) != 1) { + return -1; + } + if (x509_req_sign_to_der( + X509_version_v1, + subject, subject_len, + key, + attrs, attrs_len, + OID_sm2sign_with_sm3, + key, signer_id, strlen(signer_id), + &p, &req_len) != 1) { + return -1; + } + if (x509_req_verify(req, req_len, signer_id, strlen(signer_id)) != 1) { + return -1; + } + return gmssl_x509_req_to_pem_string(req, req_len, out, outlen); +} + +static int gmssl_parse_csr_subject_field(const char *pem, size_t pemlen, int oid, char **out, size_t *outlen) { + FILE *fp = tmpfile(); + uint8_t req[4096]; + size_t reqlen = 0; + int version = 0; + const uint8_t *subject = NULL; + size_t subject_len = 0; + SM2_KEY subject_public_key; + const uint8_t *attributes = NULL; + size_t attributes_len = 0; + int signature_algor = 0; + int tag = 0; + const uint8_t *val = NULL; + size_t vlen = 0; + int ret; + + if (!fp) { + return -1; + } + if (pemlen > 0 && fwrite(pem, 1, pemlen, fp) != pemlen) { + goto err; + } + rewind(fp); + if (x509_req_from_pem(req, &reqlen, sizeof(req), fp) != 1) { + goto err; + } + if (x509_req_get_details( + req, reqlen, + &version, + &subject, &subject_len, + &subject_public_key, + &attributes, &attributes_len, + &signature_algor, + NULL, NULL) != 1) { + goto err; + } + + ret = x509_name_get_value_by_type(subject, subject_len, oid, &tag, &val, &vlen); + if (ret < 0) { + goto err; + } + *out = (char *)malloc(vlen + 1); + if (!*out) { + goto err; + } + if (ret == 1 && vlen > 0) { + memcpy(*out, val, vlen); + } + (*out)[vlen] = '\0'; + *outlen = vlen; + fclose(fp); + return 1; + +err: + fclose(fp); + return -1; +} + +static int gmssl_verify_csr_pem(const char *pem, size_t pemlen, const char *signer_id) { + FILE *fp = tmpfile(); + uint8_t req[4096]; + size_t reqlen = 0; + int ret = -1; + + if (!fp) { + return -1; + } + if (pemlen > 0 && fwrite(pem, 1, pemlen, fp) != pemlen) { + goto end; + } + rewind(fp); + if (x509_req_from_pem(req, &reqlen, sizeof(req), fp) != 1) { + goto end; + } + if (x509_req_verify(req, reqlen, signer_id, strlen(signer_id)) != 1) { + goto end; + } + ret = 1; +end: + fclose(fp); + return ret; +} +*/ +import "C" + +import ( + "errors" + "unsafe" +) + +type CSRSubject struct { + CN string + O string + OU string + C string + L string + ST string +} + +func cOptionalString(value string) (*C.char, func()) { + if value == "" { + return nil, func() {} + } + ptr := C.CString(value) + return ptr, func() { + C.free(unsafe.Pointer(ptr)) + } +} + +func (sm2 *Sm2Key) ExportEncryptedPrivateKeyInfoPemToString(pass string) (string, error) { + if sm2.has_private_key != true { + return "", errors.New("Not private key") + } + + passStr := C.CString(pass) + defer C.free(unsafe.Pointer(passStr)) + + var out *C.char + var outlen C.size_t + if C.gmssl_sm2_private_key_info_encrypt_to_pem_string(&sm2.sm2_key, passStr, &out, &outlen) != 1 { + return "", errors.New("Libgmssl inner error") + } + defer C.free(unsafe.Pointer(out)) + + return C.GoStringN(out, C.int(outlen)), nil +} + +func (sm2 *Sm2Key) ExportPublicKeyInfoPemToString() (string, error) { + var out *C.char + var outlen C.size_t + if C.gmssl_sm2_public_key_info_to_pem_string(&sm2.sm2_key, &out, &outlen) != 1 { + return "", errors.New("Libgmssl inner error") + } + defer C.free(unsafe.Pointer(out)) + + return C.GoStringN(out, C.int(outlen)), nil +} + +func (sm2 *Sm2Key) CreateCSRPEM(subject CSRSubject, signerID string) (string, error) { + if sm2.has_private_key != true { + return "", errors.New("Not private key") + } + if signerID == "" { + signerID = Sm2DefaultId + } + + country := C.CString(subject.C) + defer C.free(unsafe.Pointer(country)) + + state, freeState := cOptionalString(subject.ST) + defer freeState() + + locality, freeLocality := cOptionalString(subject.L) + defer freeLocality() + + org, freeOrg := cOptionalString(subject.O) + defer freeOrg() + + orgUnit, freeOrgUnit := cOptionalString(subject.OU) + defer freeOrgUnit() + + commonName, freeCommonName := cOptionalString(subject.CN) + defer freeCommonName() + + signerIDStr := C.CString(signerID) + defer C.free(unsafe.Pointer(signerIDStr)) + + var out *C.char + var outlen C.size_t + if C.gmssl_create_csr_pem( + &sm2.sm2_key, + country, + state, + locality, + org, + orgUnit, + commonName, + signerIDStr, + &out, + &outlen, + ) != 1 { + return "", errors.New("Libgmssl inner error") + } + defer C.free(unsafe.Pointer(out)) + + return C.GoStringN(out, C.int(outlen)), nil +} + +func VerifyCSRPEM(csrPEM string, signerID string) bool { + if signerID == "" { + signerID = Sm2DefaultId + } + + csr := C.CString(csrPEM) + defer C.free(unsafe.Pointer(csr)) + + id := C.CString(signerID) + defer C.free(unsafe.Pointer(id)) + + return C.gmssl_verify_csr_pem(csr, C.size_t(len(csrPEM)), id) == 1 +} + +func ParseCSRSubjectPEM(csrPEM string) (CSRSubject, error) { + csr := C.CString(csrPEM) + defer C.free(unsafe.Pointer(csr)) + + parseField := func(oid C.int) (string, error) { + var out *C.char + var outlen C.size_t + if C.gmssl_parse_csr_subject_field(csr, C.size_t(len(csrPEM)), oid, &out, &outlen) != 1 { + return "", errors.New("Libgmssl inner error") + } + defer C.free(unsafe.Pointer(out)) + return C.GoStringN(out, C.int(outlen)), nil + } + + subject := CSRSubject{} + var err error + + if subject.CN, err = parseField(C.OID_at_common_name); err != nil { + return CSRSubject{}, err + } + if subject.O, err = parseField(C.OID_at_organization_name); err != nil { + return CSRSubject{}, err + } + if subject.OU, err = parseField(C.OID_at_organizational_unit_name); err != nil { + return CSRSubject{}, err + } + if subject.C, err = parseField(C.OID_at_country_name); err != nil { + return CSRSubject{}, err + } + if subject.L, err = parseField(C.OID_at_locality_name); err != nil { + return CSRSubject{}, err + } + if subject.ST, err = parseField(C.OID_at_state_or_province_name); err != nil { + return CSRSubject{}, err + } + return subject, nil +} diff --git a/sm2_csr_test.go b/sm2_csr_test.go new file mode 100644 index 0000000..a34b029 --- /dev/null +++ b/sm2_csr_test.go @@ -0,0 +1,82 @@ +package gmssl + +import ( + "strings" + "testing" +) + +func TestSm2CSRPEM(t *testing.T) { + sm2, err := GenerateSm2Key() + if err != nil { + t.Fatalf("GenerateSm2Key() error = %v", err) + } + + privateKeyPEM, err := sm2.ExportEncryptedPrivateKeyInfoPemToString("Password123!") + if err != nil { + t.Fatalf("ExportEncryptedPrivateKeyInfoPemToString() error = %v", err) + } + if !strings.HasPrefix(privateKeyPEM, "-----BEGIN ENCRYPTED PRIVATE KEY-----") { + t.Fatalf("unexpected private key pem prefix: %q", privateKeyPEM) + } + + publicKeyPEM, err := sm2.ExportPublicKeyInfoPemToString() + if err != nil { + t.Fatalf("ExportPublicKeyInfoPemToString() error = %v", err) + } + if !strings.HasPrefix(publicKeyPEM, "-----BEGIN PUBLIC KEY-----") { + t.Fatalf("unexpected public key pem prefix: %q", publicKeyPEM) + } + + expected := CSRSubject{ + CN: "China Unicom Root CA", + O: "China United Network Communications Group Co., Ltd.", + OU: "Root CA", + C: "CN", + L: "Beijing", + ST: "Beijing", + } + csrPEM, err := sm2.CreateCSRPEM(expected, Sm2DefaultId) + if err != nil { + t.Fatalf("CreateCSRPEM() error = %v", err) + } + if !strings.HasPrefix(csrPEM, "-----BEGIN CERTIFICATE REQUEST-----") { + t.Fatalf("unexpected csr pem prefix: %q", csrPEM) + } + if !VerifyCSRPEM(csrPEM, Sm2DefaultId) { + t.Fatal("VerifyCSRPEM() returned false") + } + + got, err := ParseCSRSubjectPEM(csrPEM) + if err != nil { + t.Fatalf("ParseCSRSubjectPEM() error = %v", err) + } + if got != expected { + t.Fatalf("ParseCSRSubjectPEM() = %#v, want %#v", got, expected) + } +} + +func TestSm2CSRWithoutOU(t *testing.T) { + sm2, err := GenerateSm2Key() + if err != nil { + t.Fatalf("GenerateSm2Key() error = %v", err) + } + + expected := CSRSubject{ + CN: "Example CSR", + O: "Example Org", + C: "CN", + L: "Shanghai", + ST: "Shanghai", + } + csrPEM, err := sm2.CreateCSRPEM(expected, Sm2DefaultId) + if err != nil { + t.Fatalf("CreateCSRPEM() error = %v", err) + } + got, err := ParseCSRSubjectPEM(csrPEM) + if err != nil { + t.Fatalf("ParseCSRSubjectPEM() error = %v", err) + } + if got != expected { + t.Fatalf("ParseCSRSubjectPEM() = %#v, want %#v", got, expected) + } +}