@@ -480,110 +480,89 @@ void LoweringPreparePass::lowerBinOp(BinOp op) {
480480 op.erase ();
481481}
482482
483- static mlir::Value lowerScalarToComplexCast (MLIRContext &ctx, CastOp op) {
484- CIRBaseBuilderTy builder (ctx);
483+ static mlir::Value lowerScalarToComplexCast (mlir::MLIRContext &ctx,
484+ cir::CastOp op) {
485+ cir::CIRBaseBuilderTy builder (ctx);
485486 builder.setInsertionPoint (op);
486487
487- auto src = op.getSrc ();
488- auto imag = builder.getNullValue (src.getType (), op.getLoc ());
488+ mlir::Value src = op.getSrc ();
489+ mlir::Value imag = builder.getNullValue (src.getType (), op.getLoc ());
489490 return builder.createComplexCreate (op.getLoc (), src, imag);
490491}
491492
492- static mlir::Value lowerComplexToScalarCast (MLIRContext &ctx, CastOp op) {
493- CIRBaseBuilderTy builder (ctx);
493+ static mlir::Value lowerComplexToScalarCast (mlir::MLIRContext &ctx,
494+ cir::CastOp op,
495+ cir::CastKind elemToBoolKind) {
496+ cir::CIRBaseBuilderTy builder (ctx);
494497 builder.setInsertionPoint (op);
495498
496- auto src = op.getSrc ();
497-
499+ mlir::Value src = op.getSrc ();
498500 if (!mlir::isa<cir::BoolType>(op.getType ()))
499501 return builder.createComplexReal (op.getLoc (), src);
500502
501503 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
502- auto srcReal = builder.createComplexReal (op.getLoc (), src);
503- auto srcImag = builder.createComplexImag (op.getLoc (), src);
504-
505- cir::CastKind elemToBoolKind;
506- if (op.getKind () == cir::CastKind::float_complex_to_bool)
507- elemToBoolKind = cir::CastKind::float_to_bool;
508- else if (op.getKind () == cir::CastKind::int_complex_to_bool)
509- elemToBoolKind = cir::CastKind::int_to_bool;
510- else
511- llvm_unreachable (" invalid complex to bool cast kind" );
504+ mlir::Value srcReal = builder.createComplexReal (op.getLoc (), src);
505+ mlir::Value srcImag = builder.createComplexImag (op.getLoc (), src);
512506
513- auto boolTy = builder.getBoolTy ();
514- auto srcRealToBool =
507+ cir::BoolType boolTy = builder.getBoolTy ();
508+ mlir::Value srcRealToBool =
515509 builder.createCast (op.getLoc (), elemToBoolKind, srcReal, boolTy);
516- auto srcImagToBool =
510+ mlir::Value srcImagToBool =
517511 builder.createCast (op.getLoc (), elemToBoolKind, srcImag, boolTy);
518-
519- // srcRealToBool || srcImagToBool
520512 return builder.createLogicalOr (op.getLoc (), srcRealToBool, srcImagToBool);
521513}
522514
523- static mlir::Value lowerComplexToComplexCast (MLIRContext &ctx, CastOp op) {
515+ static mlir::Value lowerComplexToComplexCast (mlir::MLIRContext &ctx,
516+ cir::CastOp op,
517+ cir::CastKind scalarCastKind) {
524518 CIRBaseBuilderTy builder (ctx);
525519 builder.setInsertionPoint (op);
526520
527- auto src = op.getSrc ();
521+ mlir::Value src = op.getSrc ();
528522 auto dstComplexElemTy =
529523 mlir::cast<cir::ComplexType>(op.getType ()).getElementType ();
530524
531- auto srcReal = builder.createComplexReal (op.getLoc (), src);
532- auto srcImag = builder.createComplexReal (op.getLoc (), src);
525+ mlir::Value srcReal = builder.createComplexReal (op.getLoc (), src);
526+ mlir::Value srcImag = builder.createComplexImag (op.getLoc (), src);
533527
534- cir::CastKind scalarCastKind;
535- switch (op.getKind ()) {
536- case cir::CastKind::float_complex:
537- scalarCastKind = cir::CastKind::floating;
538- break ;
539- case cir::CastKind::float_complex_to_int_complex:
540- scalarCastKind = cir::CastKind::float_to_int;
541- break ;
542- case cir::CastKind::int_complex:
543- scalarCastKind = cir::CastKind::integral;
544- break ;
545- case cir::CastKind::int_complex_to_float_complex:
546- scalarCastKind = cir::CastKind::int_to_float;
547- break ;
548- default :
549- llvm_unreachable (" invalid complex to complex cast kind" );
550- }
551-
552- auto dstReal = builder.createCast (op.getLoc (), scalarCastKind, srcReal,
553- dstComplexElemTy);
554- auto dstImag = builder.createCast (op.getLoc (), scalarCastKind, srcImag,
555- dstComplexElemTy);
528+ mlir::Value dstReal = builder.createCast (op.getLoc (), scalarCastKind, srcReal,
529+ dstComplexElemTy);
530+ mlir::Value dstImag = builder.createCast (op.getLoc (), scalarCastKind, srcImag,
531+ dstComplexElemTy);
556532 return builder.createComplexCreate (op.getLoc (), dstReal, dstImag);
557533}
558534
559535void LoweringPreparePass::lowerCastOp (CastOp op) {
560- mlir::Value loweredValue;
561- switch (op.getKind ()) {
562- case cir::CastKind::float_to_complex:
563- case cir::CastKind::int_to_complex:
564- loweredValue = lowerScalarToComplexCast (getContext (), op);
565- break ;
566-
567- case cir::CastKind::float_complex_to_real:
568- case cir::CastKind::int_complex_to_real:
569- case cir::CastKind::float_complex_to_bool:
570- case cir::CastKind::int_complex_to_bool:
571- loweredValue = lowerComplexToScalarCast (getContext (), op);
572- break ;
573-
574- case cir::CastKind::float_complex:
575- case cir::CastKind::float_complex_to_int_complex:
576- case cir::CastKind::int_complex:
577- case cir::CastKind::int_complex_to_float_complex:
578- loweredValue = lowerComplexToComplexCast (getContext (), op);
579- break ;
536+ mlir::MLIRContext &ctx = getContext ();
537+ mlir::Value loweredValue = [&]() -> mlir::Value {
538+ switch (op.getKind ()) {
539+ case cir::CastKind::float_to_complex:
540+ case cir::CastKind::int_to_complex:
541+ return lowerScalarToComplexCast (ctx, op);
542+ case cir::CastKind::float_complex_to_real:
543+ case cir::CastKind::int_complex_to_real:
544+ return lowerComplexToScalarCast (ctx, op, op.getKind ());
545+ case cir::CastKind::float_complex_to_bool:
546+ return lowerComplexToScalarCast (ctx, op, cir::CastKind::float_to_bool);
547+ case cir::CastKind::int_complex_to_bool:
548+ return lowerComplexToScalarCast (ctx, op, cir::CastKind::int_to_bool);
549+ case cir::CastKind::float_complex:
550+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::floating);
551+ case cir::CastKind::float_complex_to_int_complex:
552+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::float_to_int);
553+ case cir::CastKind::int_complex:
554+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::integral);
555+ case cir::CastKind::int_complex_to_float_complex:
556+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::int_to_float);
557+ default :
558+ return nullptr ;
559+ }
560+ }();
580561
581- default :
582- return ;
562+ if (loweredValue) {
563+ op.replaceAllUsesWith (loweredValue);
564+ op.erase ();
583565 }
584-
585- op.replaceAllUsesWith (loweredValue);
586- op.erase ();
587566}
588567
589568static mlir::Value buildComplexBinOpLibCall (
0 commit comments